/*
 * gatk-3.8/public/VectorPairHMM/src/main/c++/pairhmm-template-kernel.cc
 * (381 lines, 15 KiB, C++ — source-browser page header and commit log,
 * preserved here as a comment so the file remains compilable)
 *
 * Commit log:
 * Added vectorized PairHMM implementation by Mohammad and Mustafa into the
 * Maven build of GATK. C++ code has PAPI calls for reading hardware counters.
 * Followed Khalid's suggestion for packing libVectorLoglessCaching into the
 * jar file with Maven. Native library part of git repo.
 * 1. Renamed directory structure from public/c++/VectorPairHMM to
 *    public/VectorPairHMM/src/main/c++ as per Khalid's suggestion.
 * 2. Use java.home in public/VectorPairHMM/pom.xml to pass environment
 *    variable JRE_HOME to the make process. This is needed because the
 *    Makefile needs to compile JNI code with the flag -I<JRE_HOME>/../include
 *    (among others). Assuming that the Maven build process uses a JDK (and
 *    not just a JRE), the variable java.home points to the JRE inside maven.
 * 3. Dropped all pretense at cross-platform compatibility. Removed Mac
 *    profile from pom.xml for VectorPairHMM. Moved JNI_README.
 * 1. Added the catch UnsatisfiedLinkError exception in
 *    PairHMMLikelihoodCalculationEngine.java to fall back to LOGLESS_CACHING
 *    in case the native library could not be loaded. Made
 *    VECTOR_LOGLESS_CACHING the default implementation.
 * 2. Updated the README with Mauricio's comments.
 * 3. baseline.cc is used within the library - if the machine supports neither
 *    AVX nor SSE4.1, the native library falls back to un-vectorized C++ in
 *    baseline.cc.
 * 4. pairhmm-1-base.cc: This is not part of the library, but is being heavily
 *    used for debugging/profiling. Can I request that we keep it there for
 *    now? In the next release, we can delete it from the repository.
 * 5. I agree with Mauricio about the ifdefs. I am sure you already know, but
 *    just to reassure you the debug code is not compiled into the library
 *    (because of the ifdefs) and will not affect performance.
 * 1. Changed logger.info to logger.warn in
 *    PairHMMLikelihoodCalculationEngine.java.
 * 2. Committing the right set of files after rebase. Added public license
 *    text to all C++ files. Added license to Makefile. Add package info to
 *    Sandbox.java.
 * 2014-02-26 13:44:20 +08:00
 */
/*Copyright (c) 2012 The Broad Institute
*Permission is hereby granted, free of charge, to any person
*obtaining a copy of this software and associated documentation
*files (the "Software"), to deal in the Software without
*restriction, including without limitation the rights to use,
*copy, modify, merge, publish, distribute, sublicense, and/or sell
*copies of the Software, and to permit persons to whom the
*Software is furnished to do so, subject to the following
*conditions:
*The above copyright notice and this permission notice shall be
*included in all copies or substantial portions of the Software.
*THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
*EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
*OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
*NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
*HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
*WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
*FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
*THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
/* blame timestamp (scrape artifact): 2014-01-15 09:26:55 +08:00 */
#ifdef PRECISION
#include <stdint.h>
#include <assert.h>
#include <stdlib.h>
#include <algorithm>   // std::min (used in compute_full_prob)
2014-01-27 03:36:06 +08:00
void CONCAT(CONCAT(precompute_masks_,SIMD_ENGINE), PRECISION)(const testcase& tc, int COLS, int numMaskVecs, MASK_TYPE (*maskArr)[NUM_DISTINCT_CHARS]) {
2014-01-27 03:36:06 +08:00
const int maskBitCnt = MAIN_TYPE_SIZE ;
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
for (int vi=0; vi < numMaskVecs; ++vi) {
for (int rs=0; rs < NUM_DISTINCT_CHARS; ++rs) {
maskArr[vi][rs] = 0 ;
}
maskArr[vi][AMBIG_CHAR] = MASK_ALL_ONES ;
2014-01-15 09:26:55 +08:00
}
2014-01-27 03:36:06 +08:00
for (int col=1; col < COLS; ++col) {
int mIndex = (col-1) / maskBitCnt ;
int mOffset = (col-1) % maskBitCnt ;
MASK_TYPE bitMask = ((MASK_TYPE)0x1) << (maskBitCnt-1-mOffset) ;
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
char hapChar = ConvertChar::get(tc.hap[col-1]);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
if (hapChar == AMBIG_CHAR) {
for (int ci=0; ci < NUM_DISTINCT_CHARS; ++ci)
maskArr[mIndex][ci] |= bitMask ;
}
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
maskArr[mIndex][hapChar] |= bitMask ;
// bit corresponding to col 1 will be the MSB of the mask 0
// bit corresponding to col 2 will be the MSB-1 of the mask 0
// ...
// bit corresponding to col 32 will be the LSB of the mask 0
// bit corresponding to col 33 will be the MSB of the mask 1
// ...
}
2014-01-15 09:26:55 +08:00
}
2014-01-27 03:36:06 +08:00
/*
 * Prepares per-stripe row state before the anti-diagonal sweep:
 *  - rsArr[ri] receives the converted read base for each of the
 *    numRowsToProcess rows starting at (1-based) beginRowIndex,
 *  - lastMaskShiftOut is cleared; it carries the bits shifted out of each
 *    SIMD lane's mask word between successive mask-word updates
 *    (see SET_MASK_WORD / update_masks_for_cols).
 */
void CONCAT(CONCAT(init_masks_for_row_,SIMD_ENGINE), PRECISION)(const testcase& tc, char* rsArr, MASK_TYPE* lastMaskShiftOut, int beginRowIndex, int numRowsToProcess) {

    // beginRowIndex is 1-based, tc.rs is 0-based: hence the -1.
    for (int ri=0; ri < numRowsToProcess; ++ri) {
        rsArr[ri] = ConvertChar::get(tc.rs[ri+beginRowIndex-1]) ;
    }

    for (int ei=0; ei < AVX_LENGTH; ++ei) {
        lastMaskShiftOut[ei] = 0 ;
    }
}
// SET_MASK_WORD: loads one SIMD lane's mask word, pre-shifted right by
// __shiftBy bits so that the lanes (read rows) are staggered along the
// anti-diagonal.  The low __shiftBy bits of __srcMask are saved, moved to
// the top of __lastShiftOut, so they can be OR-ed into this lane's NEXT
// mask word on the following call.
// NOTE(review): when __shiftBy == 0 the __nextShiftOut expression shifts by
// __maskBitCnt (the full word width), which is formally undefined behavior
// in C++ even though the shifted value is already 0; callers do pass
// ei == 0 for the first lane — worth confirming on the target compilers.
#define SET_MASK_WORD(__dstMask, __srcMask, __lastShiftOut, __shiftBy, __maskBitCnt){ \
MASK_TYPE __bitMask = (((MASK_TYPE)0x1) << __shiftBy) - 1 ; \
MASK_TYPE __nextShiftOut = (__srcMask & __bitMask) << (__maskBitCnt - __shiftBy) ; \
__dstMask = (__srcMask >> __shiftBy) | __lastShiftOut ; \
__lastShiftOut = __nextShiftOut ; \
}
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
void CONCAT(CONCAT(update_masks_for_cols_,SIMD_ENGINE), PRECISION)(int maskIndex, BITMASK_VEC& bitMaskVec, MASK_TYPE (*maskArr) [NUM_DISTINCT_CHARS], char* rsArr, MASK_TYPE* lastMaskShiftOut, int maskBitCnt) {
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
for (int ei=0; ei < AVX_LENGTH/2; ++ei) {
SET_MASK_WORD(bitMaskVec.getLowEntry(ei), maskArr[maskIndex][rsArr[ei]],
lastMaskShiftOut[ei], ei, maskBitCnt) ;
2014-01-27 03:36:06 +08:00
int ei2 = ei + AVX_LENGTH/2 ; // the second entry index
SET_MASK_WORD(bitMaskVec.getHighEntry(ei), maskArr[maskIndex][rsArr[ei2]],
lastMaskShiftOut[ei2], ei2, maskBitCnt) ;
}
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
}
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
/*
 * Selects the per-lane emission probability for the current anti-diagonal:
 * VEC_BLENDV picks _1_distm where the lane's current mask bit is set
 * (read base matches the haplotype base, or either is ambiguous) and distm
 * otherwise, writing the result to distmChosen.  The mask vector is then
 * advanced one bit, ready for the next anti-diagonal.
 */
inline void CONCAT(CONCAT(computeDistVec,SIMD_ENGINE), PRECISION) (BITMASK_VEC& bitMaskVec, SIMD_TYPE& distm, SIMD_TYPE& _1_distm, SIMD_TYPE& distmChosen) {

    distmChosen = VEC_BLENDV(distm, _1_distm, bitMaskVec.getCombinedMask()) ;

    bitMaskVec.shift_left_1bit() ;
}
/*
* This function:
* 1- Intializes probability values p_MM, p_XX, P_YY, p_MX, p_GAPM and pack them into vectors (SSE or AVX)
* 2- Precompute parts of "distm" which only depeneds on a row number and pack it into vector
*/
2014-01-27 03:36:06 +08:00
template<class NUMBER> void CONCAT(CONCAT(initializeVectors,SIMD_ENGINE), PRECISION)(int ROWS, int COLS, NUMBER* shiftOutM, NUMBER *shiftOutX, NUMBER *shiftOutY, Context<NUMBER> ctx, testcase *tc, SIMD_TYPE *p_MM, SIMD_TYPE *p_GAPM, SIMD_TYPE *p_MX, SIMD_TYPE *p_XX, SIMD_TYPE *p_MY, SIMD_TYPE *p_YY, SIMD_TYPE *distm1D)
2014-01-15 09:26:55 +08:00
{
2014-01-27 03:36:06 +08:00
NUMBER zero = ctx._(0.0);
NUMBER init_Y = ctx.INITIAL_CONSTANT / (tc->haplen);
for (int s=0;s<ROWS+COLS+AVX_LENGTH;s++)
{
shiftOutM[s] = zero;
shiftOutX[s] = zero;
shiftOutY[s] = init_Y;
}
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
NUMBER *ptr_p_MM = (NUMBER *)p_MM;
NUMBER *ptr_p_XX = (NUMBER *)p_XX;
NUMBER *ptr_p_YY = (NUMBER *)p_YY;
NUMBER *ptr_p_MX = (NUMBER *)p_MX;
NUMBER *ptr_p_MY = (NUMBER *)p_MY;
NUMBER *ptr_p_GAPM = (NUMBER *)p_GAPM;
*ptr_p_MM = ctx._(0.0);
*ptr_p_XX = ctx._(0.0);
*ptr_p_YY = ctx._(0.0);
*ptr_p_MX = ctx._(0.0);
*ptr_p_MY = ctx._(0.0);
*ptr_p_GAPM = ctx._(0.0);
for (int r = 1; r < ROWS; r++)
{
int _i = tc->i[r-1] & 127;
int _d = tc->d[r-1] & 127;
int _c = tc->c[r-1] & 127;
Added vectorized PairHMM implementation by Mohammad and Mustafa into the Maven build of GATK. C++ code has PAPI calls for reading hardware counters Followed Khalid's suggestion for packing libVectorLoglessCaching into the jar file with Maven Native library part of git repo 1. Renamed directory structure from public/c++/VectorPairHMM to public/VectorPairHMM/src/main/c++ as per Khalid's suggestion 2. Use java.home in public/VectorPairHMM/pom.xml to pass environment variable JRE_HOME to the make process. This is needed because the Makefile needs to compile JNI code with the flag -I<JRE_HOME>/../include (among others). Assuming that the Maven build process uses a JDK (and not just a JRE), the variable java.home points to the JRE inside maven. 3. Dropped all pretense at cross-platform compatibility. Removed Mac profile from pom.xml for VectorPairHMM Moved JNI_README 1. Added the catch UnsatisfiedLinkError exception in PairHMMLikelihoodCalculationEngine.java to fall back to LOGLESS_CACHING in case the native library could not be loaded. Made VECTOR_LOGLESS_CACHING as the default implementation. 2. Updated the README with Mauricio's comments 3. baseline.cc is used within the library - if the machine supports neither AVX nor SSE4.1, the native library falls back to un-vectorized C++ in baseline.cc. 4. pairhmm-1-base.cc: This is not part of the library, but is being heavily used for debugging/profiling. Can I request that we keep it there for now? In the next release, we can delete it from the repository. 5. I agree with Mauricio about the ifdefs. I am sure you already know, but just to reassure you the debug code is not compiled into the library (because of the ifdefs) and will not affect performance. 1. Changed logger.info to logger.warn in PairHMMLikelihoodCalculationEngine.java 2. Committing the right set of files after rebase Added public license text to all C++ files Added license to Makefile Add package info to Sandbox.java
2014-02-26 13:44:20 +08:00
//*(ptr_p_MM+r-1) = ctx._(1.0) - ctx.ph2pr[(_i + _d) & 127];
SET_MATCH_TO_MATCH_PROB(*(ptr_p_MM+r-1), _i, _d);
2014-01-27 03:36:06 +08:00
*(ptr_p_GAPM+r-1) = ctx._(1.0) - ctx.ph2pr[_c];
*(ptr_p_MX+r-1) = ctx.ph2pr[_i];
*(ptr_p_XX+r-1) = ctx.ph2pr[_c];
*(ptr_p_MY+r-1) = ctx.ph2pr[_d];
*(ptr_p_YY+r-1) = ctx.ph2pr[_c];
}
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
NUMBER *ptr_distm1D = (NUMBER *)distm1D;
for (int r = 1; r < ROWS; r++)
{
int _q = tc->q[r-1] & 127;
ptr_distm1D[r-1] = ctx.ph2pr[_q];
}
2014-01-15 09:26:55 +08:00
}
/*
* This function handles pre-stripe computation:
* 1- Retrieve probaility vectors from memory
* 2- Initialize M, X, Y vectors with all 0's (for the first stripe) and shifting the last row from previous stripe for the rest
*/
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
template<class NUMBER> inline void CONCAT(CONCAT(stripeINITIALIZATION,SIMD_ENGINE), PRECISION)(
int stripeIdx, Context<NUMBER> ctx, testcase *tc, SIMD_TYPE &pGAPM, SIMD_TYPE &pMM, SIMD_TYPE &pMX, SIMD_TYPE &pXX, SIMD_TYPE &pMY, SIMD_TYPE &pYY,
SIMD_TYPE &rs, UNION_TYPE &rsN, SIMD_TYPE &distm, SIMD_TYPE &_1_distm, SIMD_TYPE *distm1D, SIMD_TYPE N_packed256, SIMD_TYPE *p_MM , SIMD_TYPE *p_GAPM ,
SIMD_TYPE *p_MX, SIMD_TYPE *p_XX , SIMD_TYPE *p_MY, SIMD_TYPE *p_YY, UNION_TYPE &M_t_2, UNION_TYPE &X_t_2, UNION_TYPE &M_t_1, UNION_TYPE &X_t_1,
UNION_TYPE &Y_t_2, UNION_TYPE &Y_t_1, UNION_TYPE &M_t_1_y, NUMBER* shiftOutX, NUMBER* shiftOutM)
2014-01-15 09:26:55 +08:00
{
2014-01-27 03:36:06 +08:00
int i = stripeIdx;
pGAPM = p_GAPM[i];
pMM = p_MM[i];
pMX = p_MX[i];
pXX = p_XX[i];
pMY = p_MY[i];
pYY = p_YY[i];
NUMBER zero = ctx._(0.0);
NUMBER init_Y = ctx.INITIAL_CONSTANT / (tc->haplen);
UNION_TYPE packed1; packed1.d = VEC_SET1_VAL(1.0);
UNION_TYPE packed3; packed3.d = VEC_SET1_VAL(3.0);
distm = distm1D[i];
_1_distm = VEC_SUB(packed1.d, distm);
distm = VEC_DIV(distm, packed3.d);
/* initialize M_t_2, M_t_1, X_t_2, X_t_1, Y_t_2, Y_t_1 */
M_t_2.d = VEC_SET1_VAL(zero);
X_t_2.d = VEC_SET1_VAL(zero);
if (i==0) {
M_t_1.d = VEC_SET1_VAL(zero);
X_t_1.d = VEC_SET1_VAL(zero);
Y_t_2.d = VEC_SET_LSE(init_Y);
Y_t_1.d = VEC_SET1_VAL(zero);
}
else {
X_t_1.d = VEC_SET_LSE(shiftOutX[AVX_LENGTH]);
M_t_1.d = VEC_SET_LSE(shiftOutM[AVX_LENGTH]);
Y_t_2.d = VEC_SET1_VAL(zero);
Y_t_1.d = VEC_SET1_VAL(zero);
}
M_t_1_y = M_t_1;
}
/*
* This function is the main compute kernel to compute M, X and Y
*/
2014-01-27 03:36:06 +08:00
inline void CONCAT(CONCAT(computeMXY,SIMD_ENGINE), PRECISION)(UNION_TYPE &M_t, UNION_TYPE &X_t, UNION_TYPE &Y_t, UNION_TYPE &M_t_y,
UNION_TYPE M_t_2, UNION_TYPE X_t_2, UNION_TYPE Y_t_2, UNION_TYPE M_t_1, UNION_TYPE X_t_1, UNION_TYPE M_t_1_y, UNION_TYPE Y_t_1,
SIMD_TYPE pMM, SIMD_TYPE pGAPM, SIMD_TYPE pMX, SIMD_TYPE pXX, SIMD_TYPE pMY, SIMD_TYPE pYY, SIMD_TYPE distmSel)
2014-01-15 09:26:55 +08:00
{
2014-01-27 03:36:06 +08:00
/* Compute M_t <= distm * (p_MM*M_t_2 + p_GAPM*X_t_2 + p_GAPM*Y_t_2) */
M_t.d = VEC_MUL(VEC_ADD(VEC_ADD(VEC_MUL(M_t_2.d, pMM), VEC_MUL(X_t_2.d, pGAPM)), VEC_MUL(Y_t_2.d, pGAPM)), distmSel);
//M_t.d = VEC_MUL( VEC_ADD(VEC_MUL(M_t_2.d, pMM), VEC_MUL(VEC_ADD(X_t_2.d, Y_t_2.d), pGAPM)), distmSel);
2014-01-27 03:36:06 +08:00
M_t_y = M_t;
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
/* Compute X_t */
X_t.d = VEC_ADD(VEC_MUL(M_t_1.d, pMX) , VEC_MUL(X_t_1.d, pXX));
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
/* Compute Y_t */
Y_t.d = VEC_ADD(VEC_MUL(M_t_1_y.d, pMY) , VEC_MUL(Y_t_1.d, pYY));
2014-01-15 09:26:55 +08:00
}
/*
* This is the main compute function. It operates on the matrix in s stripe manner.
* The stripe height is determined by the SIMD engine type.
* Stripe height: "AVX float": 8, "AVX double": 4, "SSE float": 4, "SSE double": 2
* For each stripe the operations are anti-diagonal based.
* Each anti-diagonal (M_t, Y_t, X_t) depends on the two previous anti-diagonals (M_t_2, X_t_2, Y_t_2, M_t_1, X_t_1, Y_t_1).
* Each stripe (except the fist one) depends on the last row of the previous stripe.
* The last stripe computation handles the addition of the last row of M and X, that's the reason for loop spliting.
*/
2014-01-27 03:36:06 +08:00
template<class NUMBER> NUMBER CONCAT(CONCAT(compute_full_prob_,SIMD_ENGINE), PRECISION) (testcase *tc, NUMBER *before_last_log = NULL)
2014-01-15 09:26:55 +08:00
{
2014-01-27 03:36:06 +08:00
int ROWS = tc->rslen + 1;
int COLS = tc->haplen + 1;
int MAVX_COUNT = (ROWS+AVX_LENGTH-1)/AVX_LENGTH;
/* Probaility arrays */
2014-01-27 03:36:06 +08:00
SIMD_TYPE p_MM [MAVX_COUNT], p_GAPM [MAVX_COUNT], p_MX [MAVX_COUNT];
SIMD_TYPE p_XX [MAVX_COUNT], p_MY [MAVX_COUNT], p_YY [MAVX_COUNT];
/* For distm precomputation */
2014-01-27 03:36:06 +08:00
SIMD_TYPE distm1D[MAVX_COUNT];
/* Carries the values from each stripe to the next stripe */
2014-01-27 03:36:06 +08:00
NUMBER shiftOutM[ROWS+COLS+AVX_LENGTH], shiftOutX[ROWS+COLS+AVX_LENGTH], shiftOutY[ROWS+COLS+AVX_LENGTH];
/* The vector to keep the anti-diagonals of M, X, Y*/
/* Current: M_t, X_t, Y_t */
/* Previous: M_t_1, X_t_1, Y_t_1 */
/* Previous to previous: M_t_2, X_t_2, Y_t_2 */
2014-01-27 03:36:06 +08:00
UNION_TYPE M_t, M_t_1, M_t_2, X_t, X_t_1, X_t_2, Y_t, Y_t_1, Y_t_2, M_t_y, M_t_1_y;
/* Probality vectors */
2014-01-27 03:36:06 +08:00
SIMD_TYPE pGAPM, pMM, pMX, pXX, pMY, pYY;
struct timeval start, end;
NUMBER result_avx2;
Context<NUMBER> ctx;
UNION_TYPE rs , rsN;
HAP_TYPE hap;
SIMD_TYPE distmSel, distmChosen ;
SIMD_TYPE distm, _1_distm;
int r, c;
NUMBER zero = ctx._(0.0);
UNION_TYPE packed1; packed1.d = VEC_SET1_VAL(1.0);
SIMD_TYPE N_packed256 = VEC_POPCVT_CHAR('N');
NUMBER init_Y = ctx.INITIAL_CONSTANT / (tc->haplen);
int remainingRows = (ROWS-1) % AVX_LENGTH;
int stripe_cnt = ((ROWS-1) / AVX_LENGTH) + (remainingRows!=0);
const int maskBitCnt = MAIN_TYPE_SIZE ;
const int numMaskVecs = (COLS+ROWS+maskBitCnt-1)/maskBitCnt ; // ceil function
/* Mask precomputation for distm*/
2014-01-27 03:36:06 +08:00
MASK_TYPE maskArr[numMaskVecs][NUM_DISTINCT_CHARS] ;
CONCAT(CONCAT(precompute_masks_,SIMD_ENGINE), PRECISION)(*tc, COLS, numMaskVecs, maskArr) ;
char rsArr[AVX_LENGTH] ;
MASK_TYPE lastMaskShiftOut[AVX_LENGTH] ;
/* Precompute initialization for probabilities and shift vector*/
2014-01-27 03:36:06 +08:00
CONCAT(CONCAT(initializeVectors,SIMD_ENGINE), PRECISION)<NUMBER>(ROWS, COLS, shiftOutM, shiftOutX, shiftOutY,
ctx, tc, p_MM, p_GAPM, p_MX, p_XX, p_MY, p_YY, distm1D);
for (int i=0;i<stripe_cnt-1;i++)
{
//STRIPE_INITIALIZATION
CONCAT(CONCAT(stripeINITIALIZATION,SIMD_ENGINE), PRECISION)(i, ctx, tc, pGAPM, pMM, pMX, pXX, pMY, pYY, rs.d, rsN, distm, _1_distm, distm1D, N_packed256, p_MM , p_GAPM ,
p_MX, p_XX , p_MY, p_YY, M_t_2, X_t_2, M_t_1, X_t_1, Y_t_2, Y_t_1, M_t_1_y, shiftOutX, shiftOutM);
CONCAT(CONCAT(init_masks_for_row_,SIMD_ENGINE), PRECISION)(*tc, rsArr, lastMaskShiftOut, i*AVX_LENGTH+1, AVX_LENGTH) ;
// Since there are no shift intrinsics in AVX, keep the masks in 2 SSE vectors
BITMASK_VEC bitMaskVec ;
for (int begin_d=1;begin_d<COLS+AVX_LENGTH;begin_d+=MAIN_TYPE_SIZE)
2014-01-15 09:26:55 +08:00
{
2014-01-27 03:36:06 +08:00
int numMaskBitsToProcess = std::min(MAIN_TYPE_SIZE, COLS+AVX_LENGTH-begin_d) ;
CONCAT(CONCAT(update_masks_for_cols_,SIMD_ENGINE), PRECISION)((begin_d-1)/MAIN_TYPE_SIZE, bitMaskVec, maskArr, rsArr, lastMaskShiftOut, maskBitCnt) ;
2014-01-27 03:36:06 +08:00
for (int mbi=0; mbi < numMaskBitsToProcess; ++mbi) {
CONCAT(CONCAT(computeDistVec,SIMD_ENGINE), PRECISION) (bitMaskVec, distm, _1_distm, distmChosen) ;
int ShiftIdx = begin_d + mbi + AVX_LENGTH;
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
CONCAT(CONCAT(computeMXY,SIMD_ENGINE), PRECISION)(M_t, X_t, Y_t, M_t_y, M_t_2, X_t_2, Y_t_2, M_t_1, X_t_1, M_t_1_y, Y_t_1,
pMM, pGAPM, pMX, pXX, pMY, pYY, distmChosen);
2014-01-27 03:36:06 +08:00
CONCAT(CONCAT(_vector_shift,SIMD_ENGINE), PRECISION)(M_t, shiftOutM[ShiftIdx], shiftOutM[begin_d+mbi]);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
CONCAT(CONCAT(_vector_shift,SIMD_ENGINE), PRECISION)(X_t, shiftOutX[ShiftIdx], shiftOutX[begin_d+mbi]);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
CONCAT(CONCAT(_vector_shift,SIMD_ENGINE), PRECISION)(Y_t_1, shiftOutY[ShiftIdx], shiftOutY[begin_d+mbi]);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
M_t_2 = M_t_1; M_t_1 = M_t; X_t_2 = X_t_1; X_t_1 = X_t;
Y_t_2 = Y_t_1; Y_t_1 = Y_t; M_t_1_y = M_t_y;
}
}
}
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
int i = stripe_cnt-1;
{
//STRIPE_INITIALIZATION
CONCAT(CONCAT(stripeINITIALIZATION,SIMD_ENGINE), PRECISION)(i, ctx, tc, pGAPM, pMM, pMX, pXX, pMY, pYY, rs.d, rsN, distm, _1_distm, distm1D, N_packed256, p_MM , p_GAPM ,
p_MX, p_XX , p_MY, p_YY, M_t_2, X_t_2, M_t_1, X_t_1, Y_t_2, Y_t_1, M_t_1_y, shiftOutX, shiftOutM);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
if (remainingRows==0) remainingRows=AVX_LENGTH;
CONCAT(CONCAT(init_masks_for_row_,SIMD_ENGINE), PRECISION)(*tc, rsArr, lastMaskShiftOut, i*AVX_LENGTH+1, remainingRows) ;
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
SIMD_TYPE sumM, sumX;
sumM = VEC_SET1_VAL(zero);
sumX = VEC_SET1_VAL(zero);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
// Since there are no shift intrinsics in AVX, keep the masks in 2 SSE vectors
BITMASK_VEC bitMaskVec ;
2014-01-27 03:36:06 +08:00
for (int begin_d=1;begin_d<COLS+remainingRows-1;begin_d+=MAIN_TYPE_SIZE)
{
int numMaskBitsToProcess = std::min(MAIN_TYPE_SIZE, COLS+remainingRows-1-begin_d) ;
CONCAT(CONCAT(update_masks_for_cols_,SIMD_ENGINE),PRECISION)((begin_d-1)/MAIN_TYPE_SIZE, bitMaskVec, maskArr, rsArr, lastMaskShiftOut, maskBitCnt) ;
2014-01-27 03:36:06 +08:00
for (int mbi=0; mbi < numMaskBitsToProcess; ++mbi) {
2014-01-27 03:36:06 +08:00
CONCAT(CONCAT(computeDistVec,SIMD_ENGINE), PRECISION) (bitMaskVec, distm, _1_distm, distmChosen) ;
int ShiftIdx = begin_d + mbi +AVX_LENGTH;
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
CONCAT(CONCAT(computeMXY,SIMD_ENGINE), PRECISION)(M_t, X_t, Y_t, M_t_y, M_t_2, X_t_2, Y_t_2, M_t_1, X_t_1, M_t_1_y, Y_t_1,
pMM, pGAPM, pMX, pXX, pMY, pYY, distmChosen);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
sumM = VEC_ADD(sumM, M_t.d);
CONCAT(CONCAT(_vector_shift_last,SIMD_ENGINE), PRECISION)(M_t, shiftOutM[ShiftIdx]);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
sumX = VEC_ADD(sumX, X_t.d);
CONCAT(CONCAT(_vector_shift_last,SIMD_ENGINE), PRECISION)(X_t, shiftOutX[ShiftIdx]);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
CONCAT(CONCAT(_vector_shift_last,SIMD_ENGINE), PRECISION)(Y_t_1, shiftOutY[ShiftIdx]);
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
M_t_2 = M_t_1; M_t_1 = M_t; X_t_2 = X_t_1; X_t_1 = X_t;
Y_t_2 = Y_t_1; Y_t_1 = Y_t; M_t_1_y = M_t_y;
2014-01-15 09:26:55 +08:00
2014-01-27 03:36:06 +08:00
}
}
UNION_TYPE sumMX;
sumMX.d = VEC_ADD(sumM, sumX);
result_avx2 = sumMX.f[remainingRows-1];
}
return result_avx2;
2014-01-15 09:26:55 +08:00
}
#endif