找到并行的问题了,是kt_for的steal策略,会导致处理的数据的idx小于已经处理过的。保留调试信息,并行结果和串行一致了

This commit is contained in:
zzh 2025-12-30 12:48:59 +08:00
parent 84463ede19
commit d56d926b6e
9 changed files with 305 additions and 129 deletions

View File

@ -64,6 +64,35 @@ void kt_for(int n_threads, void (*func)(void*,long,int), void *data, long n)
} }
} }
static void* ktf_worker_no_steal(void* data) {
ktf_worker_t* w = (ktf_worker_t*)data;
long i;
for (;;) {
i = __sync_fetch_and_add(&w->i, w->t->n_threads);
if (i >= w->t->n)
break;
w->t->func(w->t->data, i, w - w->t->w);
}
pthread_exit(0);
}
/* Drop-in alternative to kt_for() that hands out indices round-robin with
 * NO work stealing: thread k processes k, k + n_threads, k + 2*n_threads, ...
 * This guarantees per-thread indices arrive in increasing order, whereas
 * kt_for's steal policy can give a thread an index smaller than one it has
 * already processed. Degenerates to a plain serial loop for n_threads <= 1.
 *
 *  n_threads - number of worker threads to spawn
 *  func      - callback invoked as func(data, index, thread_id)
 *  data      - opaque pointer forwarded to func
 *  n         - total number of indices [0, n) to process
 */
void kt_for_no_steal(int n_threads, void (*func)(void*, long, int), void* data, long n) {
    if (n_threads <= 1) {
        /* serial fallback: same callback contract, thread id fixed at 0 */
        long i;
        for (i = 0; i < n; ++i) func(data, i, 0);
        return;
    }
    kt_for_t t;
    t.func = func;
    t.data = data;
    t.n_threads = n_threads;
    t.n = n;
    /* stack-allocated like kt_for(); released automatically on return */
    t.w = (ktf_worker_t*)alloca(n_threads * sizeof(ktf_worker_t));
    pthread_t* tid = (pthread_t*)alloca(n_threads * sizeof(pthread_t));
    int k;
    for (k = 0; k < n_threads; ++k) {
        t.w[k].t = &t;
        t.w[k].i = k; /* each worker starts at its own offset */
    }
    for (k = 0; k < n_threads; ++k)
        pthread_create(&tid[k], 0, ktf_worker_no_steal, &t.w[k]);
    for (k = 0; k < n_threads; ++k)
        pthread_join(tid[k], 0);
}
/*************************** /***************************
* kt_for with thread pool * * kt_for with thread pool *
***************************/ ***************************/

View File

@ -6,6 +6,7 @@ extern "C" {
#endif #endif
void kt_for(int n_threads, void (*func)(void*,long,int), void *data, long n); void kt_for(int n_threads, void (*func)(void*,long,int), void *data, long n);
void kt_for_no_steal(int n_threads, void (*func)(void*, long, int), void* data, long n);
void kt_pipeline(int n_threads, void *(*func)(void*, int, void*), void *shared_data, int n_steps); void kt_pipeline(int n_threads, void *(*func)(void*, int, void*), void *shared_data, int n_steps);
void *kt_forpool_init(int n_threads); void *kt_forpool_init(int n_threads);

View File

@ -28,7 +28,7 @@ struct AuxVar {
//const static int REF_CONTEXT_PAD = 3; // 需要做一些填充 //const static int REF_CONTEXT_PAD = 3; // 需要做一些填充
//const static int REFERENCE_HALF_WINDOW_LENGTH = 150; // 需要额外多取出一些ref序列防止边界效应 //const static int REFERENCE_HALF_WINDOW_LENGTH = 150; // 需要额外多取出一些ref序列防止边界效应
static constexpr int BAM_BLOCK_NUM = 1000; // 每个线程每次处理1k个bam记录 static constexpr int BAM_BLOCK_NUM = 1; // 每个线程每次处理1k个bam记录
static int64_t processedReads; static int64_t processedReads;
sam_hdr_t* header = nullptr; // bam header sam_hdr_t* header = nullptr; // bam header
@ -36,6 +36,7 @@ struct AuxVar {
char* ref_seq = nullptr; // reference sequence char* ref_seq = nullptr; // reference sequence
int ref_len = 0; // reference sequence length int ref_len = 0; // reference sequence length
int offset = 0; // 在要求的ref序列两边多余取出的碱基数量 int offset = 0; // 在要求的ref序列两边多余取出的碱基数量
int64_t threadProcessedReads = 0; // 该线程已经处理的reads数量
BamArray *bamArr = nullptr; // bam数据数组 BamArray *bamArr = nullptr; // bam数据数组

View File

@ -28,7 +28,7 @@ struct BQSRArg {
int NUM_THREADS = 1; int NUM_THREADS = 1;
size_t MAX_MEM = ((size_t)1) << 30; // // 1G size_t MAX_MEM = ((size_t)1) << 30; // 1G
bool DUPLEX_IO = true; // bool DUPLEX_IO = true; //

View File

@ -110,22 +110,48 @@ void roundTableValues(RecalTables& rt) {
_Foreach4D(rt.cycleTable, val, { _round_val(val); }); _Foreach4D(rt.cycleTable, val, { _round_val(val); });
} }
// 打印recal tables用于调试
static void printRecalTables(const RecalTables& rt) {
_Foreach2D(rt.readGroupTable, val, {
if (val.numObservations > 0) {
fprintf(gf[0], "%ld %f %f\n", val.numObservations, val.numMismatches, val.reportedQuality);
}
});
_Foreach3D(rt.qualityScoreTable, val, {
if (val.numObservations > 0) {
fprintf(gf[1], "%ld %f %f\n", val.numObservations, val.numMismatches, val.reportedQuality);
}
});
_Foreach4D(rt.contextTable, val, {
if (val.numObservations > 0) {
fprintf(gf[2], "%ld %f %f\n", val.numObservations, val.numMismatches, val.reportedQuality);
}
});
_Foreach4D(rt.cycleTable, val, {
if (val.numObservations > 0) {
fprintf(gf[3], "%ld %f %f\n", val.numObservations, val.numMismatches, val.reportedQuality);
}
});
}
// 串行bqsr // 串行bqsr
int SerialBQSR(AuxVar &aux) { int SerialBQSR(AuxVar &aux1) {
BamBufType inBamBuf(nsgv::gBqsrArg.DUPLEX_IO); BamBufType inBamBuf(nsgv::gBqsrArg.DUPLEX_IO);
inBamBuf.Init(nsgv::gInBamFp, nsgv::gInBamHeader, nsgv::gBqsrArg.MAX_MEM, bqsrReadFilterOut); inBamBuf.Init(nsgv::gInBamFp, nsgv::gInBamHeader, nsgv::gBqsrArg.MAX_MEM, bqsrReadFilterOut);
int64_t readNumSum = 0; int64_t readNumSum = 0;
int round = 0; int round = 0;
PerReadCovariateMatrix readCovariates;
CovariateUtils::InitPerReadCovMat(readCovariates);
RecalTables& recalTables = aux.recalTables;
SamData& sd = aux.sd; // PerReadCovariateMatrix &readCovariates = aux.readCovariates;
StableArray<int>&isSNP = aux.isSNP, &isIns = aux.isIns, &isDel = aux.isDel; // 该位置是否是SNP, indel位置0不是1是 // RecalTables& recalTables = aux.recalTables;
StableArray<uint8_t> &baqArray = aux.baqArray; // SamData& sd = aux.sd;
StableArray<double> &snpErrors = aux.snpErrors, &insErrors = aux.insErrors, &delErrors = aux.delErrors; // StableArray<int>&isSNP = aux.isSNP, &isIns = aux.isIns, &isDel = aux.isDel; // 该位置是否是SNP, indel位置0不是1是
StableArray<uint8_t> &skips = aux.skips; // 该位置是否是已知位点 // StableArray<uint8_t> &baqArray = aux.baqArray;
// StableArray<double> &snpErrors = aux.snpErrors, &insErrors = aux.insErrors, &delErrors = aux.delErrors;
// StableArray<uint8_t> &skips = aux.skips; // 该位置是否是已知位点
int numProcessed = 0;
int numthreads = 2;
int BLOCK_NUM = AuxVar::BAM_BLOCK_NUM;
while (true) { while (true) {
++round; ++round;
// 一. 读取bam数据 // 一. 读取bam数据
@ -138,121 +164,175 @@ int SerialBQSR(AuxVar &aux) {
auto bams = inBamBuf.GetBamArr(); auto bams = inBamBuf.GetBamArr();
spdlog::info("{} reads processed in {} round", readNum, round); spdlog::info("{} reads processed in {} round", readNum, round);
int numBLocks = (bams.size() + BLOCK_NUM - 1) / BLOCK_NUM;
int blocksPerThread = (numBLocks + numthreads - 1) / numthreads;
int spanBlocks = numthreads * BLOCK_NUM;
// 二. 遍历每个bamread记录进行处理 // 二. 遍历每个bamread记录进行处理
for (int i = 0; i < bams.size(); ++i) { for (int j = 0; j < numthreads; ++j) { // 模拟多线程
// 1. 对每个read需要检查cigar是否合法即没有两个连续的相同的cigar而且需要将首尾的deletion处理掉目前看好像没啥影响我们忽略这一步 AuxVar &aux = nsgv::gAuxVars[j];
// 2. 对质量分数长度跟碱基长度不匹配的read缺少的质量分数用默认值补齐先忽略后边有需要再处理 PerReadCovariateMatrix& readCovariates = aux.readCovariates;
// 3. 如果bam文件之前做过bqsrtag中包含OQoriginnal RecalTables& recalTables = aux.recalTables;
// quality原始质量分数检查用户参数里是否指定用原始质量分数进行bqsr如果是则将质量分数替换为OQ否则忽略OQ先忽略 spdlog::info("bam SamData& sd = aux.sd;
// idx: {}", i); StableArray<int>&isSNP = aux.isSNP, &isIns = aux.isIns, &isDel = aux.isDel; // 该位置是否是SNP, indel位置0不是1是
BamWrap* bw = bams[i]; StableArray<uint8_t>& baqArray = aux.baqArray;
sd.init(); StableArray<double>&snpErrors = aux.snpErrors, &insErrors = aux.insErrors, &delErrors = aux.delErrors;
sd.parseBasic(bw); StableArray<uint8_t>& skips = aux.skips; // 该位置是否是已知位点
sd.rid = i + readNumSum; for (int k = 0; k < blocksPerThread; ++k)
if (sd.read_len <= 0) for (int m = 0; m < BLOCK_NUM; ++m) {
continue; int i = j * BLOCK_NUM + k * spanBlocks + m;
if (i >= bams.size()) break;
++numProcessed;
// for (int i = 0; i < bams.size(); ++i) {
// if (i % 100 == 0)
// spdlog::info("Processing read idx: {}", i);
// 1.
// 对每个read需要检查cigar是否合法即没有两个连续的相同的cigar而且需要将首尾的deletion处理掉目前看好像没啥影响我们忽略这一步
// 2. 对质量分数长度跟碱基长度不匹配的read缺少的质量分数用默认值补齐先忽略后边有需要再处理
// 3. 如果bam文件之前做过bqsrtag中包含OQoriginnal
// quality原始质量分数检查用户参数里是否指定用原始质量分数进行bqsr如果是则将质量分数替换为OQ否则忽略OQ先忽略
// spdlog::info("bam idx: {}", i);
BamWrap* bw = bams[i];
sd.init();
sd.parseBasic(bw);
sd.rid = i + readNumSum;
if (sd.read_len <= 0)
continue;
PROF_START(clip_read); PROF_START(clip_read);
// 4. 对read的两端进行检测去除hardclipadapter // 4. 对read的两端进行检测去除hardclipadapter
ReadTransformer::hardClipAdaptorSequence(bw, sd); ReadTransformer::hardClipAdaptorSequence(bw, sd);
if (sd.read_len <= 0) if (sd.read_len <= 0)
continue; continue;
// 5. 然后再去除softclip部分 // 5. 然后再去除softclip部分
ReadTransformer::hardClipSoftClippedBases(bw, sd); ReadTransformer::hardClipSoftClippedBases(bw, sd);
if (sd.read_len <= 0) if (sd.read_len <= 0)
continue; continue;
// 应用所有的变换计算samdata的相关信息 // 应用所有的变换计算samdata的相关信息
sd.applyTransformations(); sd.applyTransformations();
PROF_END(gprof[GP_clip_read], clip_read); PROF_END(gprof[GP_clip_read], clip_read);
// 6. 更新每个read的platform信息好像没啥用暂时忽略 // 6. 更新每个read的platform信息好像没啥用暂时忽略
const int nErrors = RecalFuncs::calculateIsSNPOrIndel(aux, sd, isSNP, isIns, isDel); const int nErrors = RecalFuncs::calculateIsSNPOrIndel(aux, sd, isSNP, isIns, isDel);
/*fprintf(gf[0], "%d\t", sd.read_len); /*fprintf(gf[0], "%d\t", sd.read_len);
for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[0], "%d ", isSNP[ii]); for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[0], "%d ", isSNP[ii]);
fprintf(gf[0], "\n"); fprintf(gf[0], "\n");
fprintf(gf[1], "%d\t", sd.read_len); fprintf(gf[1], "%d\t", sd.read_len);
for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[1], "%d ", isIns[ii]); for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[1], "%d ", isIns[ii]);
fprintf(gf[1], "\n"); fprintf(gf[1], "\n");
fprintf(gf[2], "%d\t", sd.read_len); fprintf(gf[2], "%d\t", sd.read_len);
for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[2], "%d ", isDel[ii]); for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[2], "%d ", isDel[ii]);
fprintf(gf[2], "\n"); fprintf(gf[2], "\n");
*/ */
// 7. 计算baqArray // 7. 计算baqArray
// BAQ = base alignment quality // BAQ = base alignment quality
// note for efficiency reasons we don't compute the BAQ array unless we actually have // note for efficiency reasons we don't compute the BAQ array unless we actually have
// some error to marginalize over. For ILMN data ~85% of reads have no error // some error to marginalize over. For ILMN data ~85% of reads have no error
// vector<uint8_t> baqArray; // vector<uint8_t> baqArray;
bool baqCalculated = false; bool baqCalculated = false;
if (nErrors == 0 || !nsgv::gBqsrArg.enableBAQ) { if (nErrors == 0 || !nsgv::gBqsrArg.enableBAQ) {
baqCalculated = BAQ::flatBAQArray(sd, baqArray); baqCalculated = BAQ::flatBAQArray(sd, baqArray);
} else { } else {
// baqCalculated = calculateBAQArray(nsgv::gAuxVars[0], baq, sd, baqArray); // baqCalculated = calculateBAQArray(nsgv::gAuxVars[0], baq, sd, baqArray);
} }
if (!baqCalculated) if (!baqCalculated)
continue; continue;
// 到这里基本的数据都准备好了后续就是进行bqsr的统计了 // 到这里基本的数据都准备好了后续就是进行bqsr的统计了
// 8. 计算这条read对应的协变量 // 8. 计算这条read对应的协变量
PROF_START(covariate); PROF_START(covariate);
CovariateUtils::ComputeCovariates(sd, aux.header, readCovariates, true); CovariateUtils::ComputeCovariates(sd, aux.header, readCovariates, true);
PROF_END(gprof[GP_covariate], covariate); PROF_END(gprof[GP_covariate], covariate);
// fprintf(gf[3], "%ld %ld %d %ld\n", sd.rid, readCovariates.size(), sd.read_len, readCovariates[0][0].size()); // fprintf(gf[3], "%ld %ld %d %ld\n", sd.rid, readCovariates.size(), sd.read_len, readCovariates[0][0].size());
// for (auto &arr1 : readCovariates) { // for (auto &arr1 : readCovariates) {
// for (size_t si = 0; si < sd.read_len; ++si) { // for (size_t si = 0; si < sd.read_len; ++si) {
// for (auto &val : arr1[si]) { // for (auto &val : arr1[si]) {
// fprintf(gf[3], "%d ", val); // fprintf(gf[3], "%d ", val);
// } // }
// } // }
// } // }
// fprintf(gf[3], "\n"); // fprintf(gf[3], "\n");
// fprintf(gf[3], "%ld %d\n", sd.rid, sd.read_len); // fprintf(gf[3], "%ld %d\n", sd.rid, sd.read_len);
// { // {
// auto& arr1 = readCovariates[0]; // auto& arr1 = readCovariates[0];
// { // {
// for (int pos = 0; pos < sd.read_len; ++pos) { // for (int pos = 0; pos < sd.read_len; ++pos) {
// fprintf(gf[3], "%d %d\n", pos, arr1[pos][2]); // fprintf(gf[3], "%d %d\n", pos, arr1[pos][2]);
// } // }
// } // }
// } // }
// fprintf(gf[3], "\n"); // fprintf(gf[3], "\n");
// 9. 计算这条read需要跳过的位置 // 9. 计算这条read需要跳过的位置
PROF_START(read_vcf); PROF_START(read_vcf);
RecalFuncs::calculateKnownSites(sd, aux.vcfArr, aux.header, skips); RecalFuncs::calculateKnownSites(sd, aux.vcfArr, aux.header, skips);
for (int ii = 0; ii < sd.read_len; ++ii) { for (int ii = 0; ii < sd.read_len; ++ii) {
skips[ii] = skips[ii] || (ContextCovariate::baseIndexMap[sd.bases[ii]] == -1) || skips[ii] = skips[ii] || (ContextCovariate::baseIndexMap[sd.bases[ii]] == -1) ||
sd.base_quals[ii] < nsgv::gBqsrArg.PRESERVE_QSCORES_LESS_THAN; sd.base_quals[ii] < nsgv::gBqsrArg.PRESERVE_QSCORES_LESS_THAN;
} }
PROF_GP_END(read_vcf); PROF_GP_END(read_vcf);
// fprintf(gf[0], "%ld %d\t", sd.rid, sd.read_len); #if 1
// for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[0], "%d ", skips[ii] ? 1 : 0); int fidx = 0;
// fprintf(gf[0], "\n"); if (sd.rid % 2 == 0) fidx = 0;
else fidx = 1;
fprintf(gf[fidx], "%ld %d\t", sd.rid, sd.read_len);
for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[fidx], "%d ", skips[ii] ? 1 : 0);
fprintf(gf[fidx], "\n");
#endif
// fprintf(gf[0], "%ld %d\t", sd.rid, sd.read_len);
// for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[0], "%d ", skips[ii] ? 1 : 0);
// fprintf(gf[0], "\n");
// 10. 根据BAQ进一步处理snpindel得到处理后的数据 // 10. 根据BAQ进一步处理snpindel得到处理后的数据
PROF_START(frac_err); PROF_START(frac_err);
RecalFuncs::calculateFractionalErrorArray(isSNP, baqArray, snpErrors); RecalFuncs::calculateFractionalErrorArray(isSNP, baqArray, snpErrors);
RecalFuncs::calculateFractionalErrorArray(isIns, baqArray, insErrors); RecalFuncs::calculateFractionalErrorArray(isIns, baqArray, insErrors);
RecalFuncs::calculateFractionalErrorArray(isDel, baqArray, delErrors); RecalFuncs::calculateFractionalErrorArray(isDel, baqArray, delErrors);
PROF_GP_END(frac_err); PROF_GP_END(frac_err);
// aggregate all of the info into our info object, and update the data // aggregate all of the info into our info object, and update the data
// 11. 合并之前计算的数据得到info并更新bqsr table数据 // 11. 合并之前计算的数据得到info并更新bqsr table数据
ReadRecalInfo info(sd, readCovariates, skips, snpErrors, insErrors, delErrors); ReadRecalInfo info(sd, readCovariates, skips, snpErrors, insErrors, delErrors);
PROF_START(update_info); PROF_START(update_info);
RecalUtils::updateRecalTablesForRead(info, recalTables); RecalUtils::updateRecalTablesForRead(info, recalTables);
PROF_END(gprof[GP_update_info], update_info); PROF_END(gprof[GP_update_info], update_info);
}
} }
readNumSum += readNum; readNumSum += readNum;
inBamBuf.ClearAll(); // inBamBuf.ClearAll(); //
} }
#if 0
printRecalTables(recalTables);
#endif
spdlog::info("read count: {}", readNumSum); spdlog::info("read count: {}", readNumSum);
spdlog::info("processed count: {}", numProcessed);
auto& auxArr = nsgv::gAuxVars;
RecalTables& recalTables = auxArr[0].recalTables;
for (int i = 0; i < numthreads; ++i) spdlog::info("thread {} processed reads {}.", i, auxArr[i].threadProcessedReads);
for (int i = 1; i < numthreads; ++i) {
auxArr[0].threadProcessedReads += auxArr[i].threadProcessedReads;
_Foreach3DK(auxArr[i].recalTables.qualityScoreTable, qualDatum, {
if (qualDatum.numObservations > 0) {
recalTables.qualityScoreTable(k1, k2, k3).increment(qualDatum);
}
});
_Foreach4DK(auxArr[i].recalTables.contextTable, contextDatum, {
if (contextDatum.numObservations > 0) {
recalTables.contextTable(k1, k2, k3, k4).increment(contextDatum);
}
});
_Foreach4DK(auxArr[i].recalTables.cycleTable, cycleDatum, {
if (cycleDatum.numObservations > 0) {
recalTables.cycleTable(k1, k2, k3, k4).increment(cycleDatum);
}
});
}
// 12. 创建总结数据 // 12. 创建总结数据
collapseQualityScoreTableToReadGroupTable(recalTables.readGroupTable, recalTables.qualityScoreTable); collapseQualityScoreTableToReadGroupTable(recalTables.readGroupTable, recalTables.qualityScoreTable);
roundTableValues(recalTables); roundTableValues(recalTables);
@ -266,8 +346,9 @@ int SerialBQSR(AuxVar &aux) {
return 0; return 0;
} }
// 多线程处理bam数据, tmd是乱序的
static void thread_worker(void* data, long idx, int tid) { static void thread_worker(void* data, long idx, int tid) {
AuxVar& aux = ((AuxVar*)data)[tid]; AuxVar& aux = (*(vector<AuxVar>*)data)[tid];
auto& readCovariates = aux.readCovariates; auto& readCovariates = aux.readCovariates;
RecalTables& recalTables = aux.recalTables; RecalTables& recalTables = aux.recalTables;
SamData& sd = aux.sd; SamData& sd = aux.sd;
@ -276,9 +357,18 @@ static void thread_worker(void* data, long idx, int tid) {
StableArray<double>&snpErrors = aux.snpErrors, &insErrors = aux.insErrors, &delErrors = aux.delErrors; StableArray<double>&snpErrors = aux.snpErrors, &insErrors = aux.insErrors, &delErrors = aux.delErrors;
StableArray<uint8_t>& skips = aux.skips; // 该位置是否是已知位点 StableArray<uint8_t>& skips = aux.skips; // 该位置是否是已知位点
auto &bams = *aux.bamArr; auto &bams = *aux.bamArr;
// for (auto& vcf : aux.vcfArr) vcf.knownSites.clear();
#if 1
int startIdx = idx * aux.BAM_BLOCK_NUM; int startIdx = idx * aux.BAM_BLOCK_NUM;
int stopIdx = std::min((size_t)(idx + 1) * aux.BAM_BLOCK_NUM, bams.size()); int stopIdx = std::min((size_t)(idx + 1) * aux.BAM_BLOCK_NUM, bams.size());
#else
int blockReadNums = (bams.size() + nsgv::gAuxVars.size() - 1) / nsgv::gAuxVars.size();
int startIdx = idx * blockReadNums;
int stopIdx = std::min((size_t)(idx + 1) * blockReadNums, bams.size());
#endif
aux.threadProcessedReads += stopIdx - startIdx;
for (int i = startIdx; i < stopIdx; ++i) { for (int i = startIdx; i < stopIdx; ++i) {
// spdlog::info("Thread {} processing read idx: {}", tid, i);
BamWrap* bw = bams[i]; BamWrap* bw = bams[i];
sd.init(); sd.init();
sd.parseBasic(bw); sd.parseBasic(bw);
@ -314,6 +404,14 @@ static void thread_worker(void* data, long idx, int tid) {
skips[ii] || (ContextCovariate::baseIndexMap[sd.bases[ii]] == -1) || sd.base_quals[ii] < nsgv::gBqsrArg.PRESERVE_QSCORES_LESS_THAN; skips[ii] || (ContextCovariate::baseIndexMap[sd.bases[ii]] == -1) || sd.base_quals[ii] < nsgv::gBqsrArg.PRESERVE_QSCORES_LESS_THAN;
} }
// PROF_GP_END(read_vcf); // PROF_GP_END(read_vcf);
#if 1
int fidx = 0 + 2 * tid;
//if (sd.rid % 2 == 0) fidx = 0 + 2 * tid;
//else fidx = 1 + 2 * tid;
fprintf(gf[fidx], "%ld %d\t", sd.rid, sd.read_len);
for (int ii = 0; ii < sd.read_len; ++ii) fprintf(gf[fidx], "%d ", skips[ii] ? 1 : 0);
fprintf(gf[fidx], "\n");
#endif
// PROF_START(frac_err); // PROF_START(frac_err);
RecalFuncs::calculateFractionalErrorArray(isSNP, baqArray, snpErrors); RecalFuncs::calculateFractionalErrorArray(isSNP, baqArray, snpErrors);
@ -348,24 +446,54 @@ int ParallelBQSR(vector<AuxVar>& auxArr) {
}); });
spdlog::info("{} reads processed in {} round", readNum, round); spdlog::info("{} reads processed in {} round", readNum, round);
kt_for(nsgv::gBqsrArg.NUM_THREADS, thread_worker, auxArr.data(), (readNum + AuxVar::BAM_BLOCK_NUM - 1) / AuxVar::BAM_BLOCK_NUM); #if 1
kt_for_no_steal(auxArr.size(), thread_worker, &auxArr, (readNum + AuxVar::BAM_BLOCK_NUM - 1) / AuxVar::BAM_BLOCK_NUM);
#else
kt_for(auxArr.size(), thread_worker, &auxArr, auxArr.size());
#endif
readNumSum += readNum; readNumSum += readNum;
AuxVar::processedReads += readNum; AuxVar::processedReads += readNum;
inBamBuf.ClearAll(); // inBamBuf.ClearAll(); //
} }
spdlog::info("read count: {}", readNumSum); spdlog::info("read count: {}", readNumSum);
// // 12. 创建总结数据
// collapseQualityScoreTableToReadGroupTable(recalTables.readGroupTable, recalTables.qualityScoreTable); // 合并各个线程的结果
// roundTableValues(recalTables); RecalTables& recalTables = auxArr[0].recalTables;
// for (int i = 0; i < auxArr.size(); ++i)
// // 13. 量化质量分数 spdlog::info("thread {} processed reads {}.", i, auxArr[i].threadProcessedReads);
// QuantizationInfo quantInfo(recalTables, nsgv::gBqsrArg.QUANTIZING_LEVELS); for (int i = 1; i < auxArr.size(); ++i) {
// auxArr[0].threadProcessedReads += auxArr[i].threadProcessedReads;
// // 14. 输出结果 _Foreach3DK(auxArr[i].recalTables.qualityScoreTable, qualDatum, {
// RecalUtils::outputRecalibrationReport(nsgv::gBqsrArg, quantInfo, recalTables); if (qualDatum.numObservations > 0) {
// recalTables.qualityScoreTable(k1, k2, k3).increment(qualDatum);
}
});
_Foreach4DK(auxArr[i].recalTables.contextTable, contextDatum, {
if (contextDatum.numObservations > 0) {
recalTables.contextTable(k1, k2, k3, k4).increment(contextDatum);
}
});
_Foreach4DK(auxArr[i].recalTables.cycleTable, cycleDatum, {
if (cycleDatum.numObservations > 0) {
recalTables.cycleTable(k1, k2, k3, k4).increment(cycleDatum);
}
});
}
spdlog::info("All processed reads {}.", auxArr[0].threadProcessedReads);
// printRecalTables(recalTables);
// 创建总结数据
collapseQualityScoreTableToReadGroupTable(recalTables.readGroupTable, recalTables.qualityScoreTable);
roundTableValues(recalTables);
// 量化质量分数
QuantizationInfo quantInfo(recalTables, nsgv::gBqsrArg.QUANTIZING_LEVELS);
// 输出结果
RecalUtils::outputRecalibrationReport(nsgv::gBqsrArg, quantInfo, recalTables);
return 0; return 0;
} }
@ -455,8 +583,10 @@ int BaseRecalibrator() {
PROF_START(whole_process); PROF_START(whole_process);
globalInit(); globalInit();
// ret = SerialBQSR(nsgv::gAuxVars[0]); // 串行处理数据生成recal table //if (nsgv::gBqsrArg.NUM_THREADS == 1)
ret = ParallelBQSR(nsgv::gAuxVars); // 串行处理数据生成recal table // ret = SerialBQSR(nsgv::gAuxVars[0]); // 串行处理数据生成recal table
//else
ret = ParallelBQSR(nsgv::gAuxVars); // 并行处理数据生成recal table
globalDestroy(); globalDestroy();
sam_close(nsgv::gInBamFp); sam_close(nsgv::gInBamFp);
PROF_END(gprof[GP_whole_process], whole_process); PROF_END(gprof[GP_whole_process], whole_process);

View File

@ -110,6 +110,13 @@ struct RecalDatum {
empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY; empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;
} }
/** Fold another datum's counts into this one (used when merging the
 *  per-thread RecalTables after the parallel pass).
 *  Observation and mismatch counts are summed; the cached empirical
 *  quality is invalidated so it is recomputed lazily from the merged
 *  counts. NOTE(review): reportedQuality is overwritten with other's
 *  value rather than combined — this assumes all data merged into one
 *  table cell share the same reported quality; confirm with callers. */
inline void increment(const RecalDatum& other) {
    empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;
    reportedQuality = other.reportedQuality;
    numMismatches += other.numMismatches;
    numObservations += other.numObservations;
}
/** /**
* Add in all of the data from other into this object, updating the reported quality from the expected * Add in all of the data from other into this object, updating the reported quality from the expected
* error rate implied by the two reported qualities. * error rate implied by the two reported qualities.

View File

@ -110,9 +110,11 @@ struct RecalFuncs {
int64_t startPos = bw->start_pos(); // 闭区间 int64_t startPos = bw->start_pos(); // 闭区间
int64_t endPos = bw->end_pos(); // 闭区间 int64_t endPos = bw->end_pos(); // 闭区间
knownSites.resize_fill(sd.read_len, 0); knownSites.resize_fill(sd.read_len, 0);
// return;
// update vcfs // update vcfs
for (auto& vcf : vcfs) { for (auto& vcf : vcfs) {
#if 1
// 清理旧的interval // 清理旧的interval
while (!vcf.knownSites.empty()) { while (!vcf.knownSites.empty()) {
auto& intv = vcf.knownSites.front(); auto& intv = vcf.knownSites.front();
@ -122,9 +124,11 @@ struct RecalFuncs {
else else
break; break;
} }
if (!vcf.knownSites.empty() && vcf.knownSites.back().left > endPos) // #endif
if (!vcf.knownSites.empty() && vcf.knownSites.back().left > endPos) // 此时vcf的区域包含bam不需要读取
continue; continue;
#endif
vcf.knownSites.clear();
// 读取新的interval // 读取新的interval
int64_t fpos, flen; int64_t fpos, flen;
endPos = std::max(startPos + MAX_SITES_INTERVAL, endPos); endPos = std::max(startPos + MAX_SITES_INTERVAL, endPos);
@ -151,7 +155,7 @@ struct RecalFuncs {
tid = sam_hdr_name2tid(samHdr, stid.c_str()); tid = sam_hdr_name2tid(samHdr, stid.c_str());
int64_t varStart = BamWrap::bam_global_pos(tid, pos); int64_t varStart = BamWrap::bam_global_pos(tid, pos);
Interval varIntv(varStart, varStart + ref.size() - 1); Interval varIntv(varStart, varStart + ref.size() - 1);
if (readIntv.overlaps(varIntv)) { if (varIntv.right >= readIntv.left) {
vcf.knownSites.push_back(Interval(tid, pos - 1, pos - 1 + ref.size() - 1)); // 闭区间 vcf.knownSites.push_back(Interval(tid, pos - 1, pos - 1 + ref.size() - 1)); // 闭区间
} }
get_line_from_buf(buf, flen, &cur, &line); get_line_from_buf(buf, flen, &cur, &line);

View File

@ -118,6 +118,10 @@ int main_BaseRecalibrator(int argc, char *argv[]) {
nsgv::gBqsrArg.INPUT_FILE = program.get("--input"); nsgv::gBqsrArg.INPUT_FILE = program.get("--input");
nsgv::gBqsrArg.OUTPUT_FILE = program.get("--output"); nsgv::gBqsrArg.OUTPUT_FILE = program.get("--output");
nsgv::gBqsrArg.NUM_THREADS = program.get<int>("--num-threads"); nsgv::gBqsrArg.NUM_THREADS = program.get<int>("--num-threads");
if (nsgv::gBqsrArg.NUM_THREADS < 1) {
spdlog::error("num-threads must be positive.");
exit(1);
}
nsgv::gBqsrArg.CREATE_INDEX = program.get<bool>("--create-index"); nsgv::gBqsrArg.CREATE_INDEX = program.get<bool>("--create-index");
nsgv::gBqsrArg.REFERENCE_FILE = program.get<string>("--reference"); nsgv::gBqsrArg.REFERENCE_FILE = program.get<string>("--reference");
nsgv::gBqsrArg.KNOWN_SITES_VCFS = program.get<std::vector<string>>("--known-sites"); nsgv::gBqsrArg.KNOWN_SITES_VCFS = program.get<std::vector<string>>("--known-sites");

View File

@ -29,7 +29,7 @@ typedef std::vector<Interval> IntervalArr;
*/ */
struct Interval { struct Interval {
// const常量 // const常量
const static int CONTIG_SHIFT = 40; const static int CONTIG_SHIFT = BamWrap::MAX_CONTIG_LEN_SHIFT;
const static uint64_t POS_MASK = (1LL << CONTIG_SHIFT) - 1; const static uint64_t POS_MASK = (1LL << CONTIG_SHIFT) - 1;
// 类变量 // 类变量