bqsr第一阶段完成了,结果还有点错误,得调试一下

This commit is contained in:
zzh 2025-12-28 14:33:45 +08:00
parent 25f079b936
commit 146055fc01
37 changed files with 4860 additions and 52 deletions

2
.gitignore vendored
View File

@ -3,7 +3,7 @@
*.d *.d
/.vscode /.vscode
/build /build
/text /test
build.sh build.sh
run.sh run.sh
test.sh test.sh

View File

@ -4,6 +4,7 @@ set(EXECUTABLE_OUTPUT_PATH "${PROJECT_BINARY_DIR}/bin")
# source codes path # source codes path
aux_source_directory(${PROJECT_SOURCE_DIR}/src MAIN_SRC) aux_source_directory(${PROJECT_SOURCE_DIR}/src MAIN_SRC)
aux_source_directory(${PROJECT_SOURCE_DIR}/src/util UTIL_SRC) aux_source_directory(${PROJECT_SOURCE_DIR}/src/util UTIL_SRC)
aux_source_directory(${PROJECT_SOURCE_DIR}/src/util/math UTIL_MATH_SRC)
aux_source_directory(${PROJECT_SOURCE_DIR}/src/bqsr BQSR_SRC) aux_source_directory(${PROJECT_SOURCE_DIR}/src/bqsr BQSR_SRC)
set(KTHREAD_FILE ${PROJECT_SOURCE_DIR}/ext/klib/kthread.c) set(KTHREAD_FILE ${PROJECT_SOURCE_DIR}/ext/klib/kthread.c)
@ -20,7 +21,12 @@ link_directories("${PROJECT_SOURCE_DIR}/ext/htslib")
set(PG_NAME "fastbqsr") set(PG_NAME "fastbqsr")
# dependency files # dependency files
add_executable(${PG_NAME} ${MAIN_SRC} ${UTIL_SRC} ${BQSR_SRC} ${KTHREAD_FILE}) add_executable(${PG_NAME}
${MAIN_SRC}
${UTIL_SRC}
${UTIL_MATH_SRC}
${BQSR_SRC}
${KTHREAD_FILE})
# link htslib # link htslib
target_link_libraries(${PG_NAME} libhts.a) target_link_libraries(${PG_NAME} libhts.a)

View File

@ -56,6 +56,26 @@ struct BQSRArg {
// end of common parameters // end of common parameters
// We always use the same covariates. The field is retained for compatibility with GATK3 reports.
bool DO_NOT_USE_STANDARD_COVARIATES = false;
// It makes no sense to run BQSR without sites, so we remove this option.
bool RUN_WITHOUT_DBSNP = false;
// We don't support SOLID. The field is retained for compatibility with GATK3 reports.
string SOLID_RECAL_MODE = "SET_Q_ZERO";
string SOLID_NOCALL_STRATEGY = "THROW_EXCEPTION";
//@Hidden @Argument(fullName = "default-platform", optional = true,
// doc = "If a read has no platform then default to the provided String. Valid options are illumina, 454, and solid.") public String
string DEFAULT_PLATFORM = "";
// @Hidden @Argument(fullName = "force-platform", optional = true,
// doc = "If provided, the platform of EVERY read will be forced to be the provided String. Valid options are illumina, 454, and solid.")
string FORCE_PLATFORM = "";
string existingRecalibrationReport = "";
/** /**
* The context covariate will use a context of this size to calculate its covariate value for base mismatches. Must be * The context covariate will use a context of this size to calculate its covariate value for base mismatches. Must be
* between 1 and 13 (inclusive). Note that higher values will increase runtime and required java heap size. * between 1 and 13 (inclusive). Note that higher values will increase runtime and required java heap size.

View File

@ -7,18 +7,18 @@ Copyright : All right reserved by ICT
Author : Zhang Zhonghai Author : Zhang Zhonghai
Date : 2023/10/23 Date : 2023/10/23
*/ */
#include <header.h>
#include <htslib/faidx.h> #include <htslib/faidx.h>
#include <htslib/kstring.h> #include <htslib/kstring.h>
#include <htslib/sam.h> #include <htslib/sam.h>
#include <htslib/synced_bcf_reader.h> #include <htslib/synced_bcf_reader.h>
#include <htslib/thread_pool.h> #include <htslib/thread_pool.h>
#include <header.h>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <iomanip> #include <iomanip>
#include <numeric> #include <numeric>
#include <vector>
#include <queue> #include <queue>
#include <vector>
#include "baq.h" #include "baq.h"
#include "bqsr_args.h" #include "bqsr_args.h"
@ -28,10 +28,16 @@ Date : 2023/10/23
#include "dup_metrics.h" #include "dup_metrics.h"
#include "fastbqsr_version.h" #include "fastbqsr_version.h"
#include "read_name_parser.h" #include "read_name_parser.h"
#include "read_recal_info.h"
#include "recal_datum.h"
#include "recal_tables.h"
#include "recal_utils.h"
#include "util/interval.h" #include "util/interval.h"
#include "util/linear_index.h"
#include "util/profiling.h" #include "util/profiling.h"
#include "util/utils.h" #include "util/utils.h"
#include "util/linear_index.h" #include "util/math/math_utils.h"
#include "quant_info.h"
using std::deque; using std::deque;
@ -89,7 +95,8 @@ BQSRArg gBqsrArg; // bqsr arguments
samFile* gInBamFp; // input BAM file pointer samFile* gInBamFp; // input BAM file pointer
sam_hdr_t* gInBamHeader; // input BAM header sam_hdr_t* gInBamHeader; // input BAM header
vector<AuxVar> gAuxVars; // auxiliary variables保存一些文件数据等每个线程对应一个 vector<AuxVar> gAuxVars; // auxiliary variables保存一些文件数据等每个线程对应一个
RecalTables gRecalTables; // 记录bqsr所有的数据输出table结果
vector<EventTypeValue> gEventTypes; // 需要真正计算的eventtype
// 下面是需要删除或修改的变量 // 下面是需要删除或修改的变量
std::vector<ReadNameParser> gNameParsers; // read name parser std::vector<ReadNameParser> gNameParsers; // read name parser
@ -258,12 +265,12 @@ void calculateRefOffset(BamWrap *bw, SamData &ad) {
// 计算clip处理之后剩余的碱基 // 计算clip处理之后剩余的碱基
void calculateReadBases(BamWrap* bw, SamData& ad) { void calculateReadBases(BamWrap* bw, SamData& ad) {
ad.bases.resize(ad.read_len); ad.bases.resize(ad.read_len);
ad.quals.resize(ad.read_len); ad.base_quals.resize(ad.read_len);
uint8_t* seq = bam_get_seq(bw->b); uint8_t* seq = bam_get_seq(bw->b);
uint8_t* quals = bam_get_qual(bw->b); uint8_t* quals = bam_get_qual(bw->b);
for (int i = 0; i < ad.read_len; ++i) { for (int i = 0; i < ad.read_len; ++i) {
ad.bases[i] = cBaseToChar[bam_seqi(seq, i + ad.left_clip)]; ad.bases[i] = cBaseToChar[bam_seqi(seq, i + ad.left_clip)];
ad.quals[i] = quals[i]; ad.base_quals[i] = quals[i];
} }
} }
@ -402,6 +409,7 @@ int calculateIsSNPOrIndel(AuxVar& aux, BamWrap *bw, SamData &ad, vector<int> &is
for (int j = 0; j < len; ++j) { for (int j = 0; j < len; ++j) {
// 按位置将read和ref碱基进行比较不同则是snp注意read起始位置要加上left_clip // 按位置将read和ref碱基进行比较不同则是snp注意read起始位置要加上left_clip
int snpInt = cBaseToChar[bam_seqi(seq, readPos + ad.left_clip)] == refBases[refPos] ? 0 : 1; int snpInt = cBaseToChar[bam_seqi(seq, readPos + ad.left_clip)] == refBases[refPos] ? 0 : 1;
// if (snpInt > 0) { spdlog::info("snp {}, readpos: {}", snpInt, readPos); }
isSNP[readPos] = snpInt; isSNP[readPos] = snpInt;
nEvents += snpInt; nEvents += snpInt;
readPos++; readPos++;
@ -429,6 +437,7 @@ int calculateIsSNPOrIndel(AuxVar& aux, BamWrap *bw, SamData &ad, vector<int> &is
} }
} }
nEvents += std::accumulate(isIns.begin(), isIns.end(), 0) + std::accumulate(isDel.begin(), isDel.end(), 0); nEvents += std::accumulate(isIns.begin(), isIns.end(), 0) + std::accumulate(isDel.begin(), isDel.end(), 0);
// spdlog::info("nEvents: {}", nEvents);
//spdlog::info("SNPs: {}, Ins: {}, Del: {}, total events: {}", std::accumulate(isSNP.begin(), isSNP.end(), 0), //spdlog::info("SNPs: {}, Ins: {}, Del: {}, total events: {}", std::accumulate(isSNP.begin(), isSNP.end(), 0),
// std::accumulate(isIns.begin(), isIns.end(), 0), std::accumulate(isDel.begin(), isDel.end(), 0), nEvents); // std::accumulate(isIns.begin(), isIns.end(), 0), std::accumulate(isDel.begin(), isDel.end(), 0), nEvents);
@ -438,13 +447,13 @@ int calculateIsSNPOrIndel(AuxVar& aux, BamWrap *bw, SamData &ad, vector<int> &is
} }
// 简单计算baq数组就是全部赋值为'@' (64) // 简单计算baq数组就是全部赋值为'@' (64)
bool flatBAQArray(BamWrap* bw, SamData& ad, vector<int>& baqArray) { bool flatBAQArray(BamWrap* bw, SamData& ad, vector<uint8_t>& baqArray) {
baqArray.resize(ad.read_len, (int)'@'); baqArray.resize(ad.read_len, (uint8_t)'@');
return true; return true;
} }
// 计算真实的baq数组耗时更多好像enable-baq参数默认是关闭的那就先不实现这个了 // 计算真实的baq数组耗时更多好像enable-baq参数默认是关闭的那就先不实现这个了
bool calculateBAQArray(AuxVar& aux, BAQ& baq, BamWrap* bw, SamData& ad, vector<int>& baqArray) { bool calculateBAQArray(AuxVar& aux, BAQ& baq, BamWrap* bw, SamData& ad, vector<uint8_t>& baqArray) {
baqArray.resize(ad.read_len, 0); baqArray.resize(ad.read_len, 0);
return true; return true;
} }
@ -563,7 +572,8 @@ static void calculateAndStoreErrorsInBlock(int i, int blockStartIndex, vector<in
} }
// 应该是用来处理BAQ的把不等于特定BAQ分数的碱基作为一段数据统一处理 // 应该是用来处理BAQ的把不等于特定BAQ分数的碱基作为一段数据统一处理
void calculateFractionalErrorArray(vector<int>& errorArr, vector<int>& baqArr, vector<double>& fracErrs) { void calculateFractionalErrorArray(vector<int>& errorArr, vector<uint8_t>& baqArr, vector<double>& fracErrs) {
// for (auto val : errorArr) { if (val > 0) spdlog::info("snp err val: {}", val); }
fracErrs.resize(baqArr.size()); fracErrs.resize(baqArr.size());
// errorArray和baqArray必须长度相同 // errorArray和baqArray必须长度相同
const int BLOCK_START_UNSET = -1; const int BLOCK_START_UNSET = -1;
@ -592,6 +602,95 @@ void calculateFractionalErrorArray(vector<int>& errorArr, vector<int>& baqArr, v
} }
} }
/**
 * Update the recalibration statistics using the information in recalInfo.
 *
 * Implementation detail: we only populate the quality score table and the optional tables.
 * The read group table will be populated later by collapsing the quality score table.
 *
 * @param info data structure holding information about the recalibration values for a single read
 */
void updateRecalTablesForRead(ReadRecalInfo &info) {
  SamData &read = info.read;
  PerReadCovariateMatrix& readCovars = info.covariates;
  Array3D<RecalDatum>& qualityScoreTable = nsgv::gRecalTables.qualityScoreTable;
  Array4D<RecalDatum>& contextTable = nsgv::gRecalTables.contextTable;
  Array4D<RecalDatum>& cycleTable = nsgv::gRecalTables.cycleTable;
  int readLength = read.read_len;
  for (int offset = 0; offset < readLength; ++offset) {
    if (!info.skips[offset]) { // only accumulate positions that are not skipped
      for (int idx = 0; idx < nsgv::gEventTypes.size(); ++idx) {
        // Fetch the four covariate values: read group / quality score / context / cycle
        EventTypeValue& event = nsgv::gEventTypes[idx];
        vector<int>& covariatesAtOffset = readCovars[event.index][offset];
        uint8_t qual = info.getQual(event, offset);
        double isError = info.getErrorFraction(event, offset);
        int readGroup = covariatesAtOffset[ReadGroupCovariate::index];
        int baseQuality = covariatesAtOffset[BaseQualityCovariate::index];
        // Base quality score covariate.
        // NOTE(review): `qual` is computed above but never used — all three
        // increments below pass `baseQuality` as the reported quality. Confirm
        // this is intended (GATK's table update passes the per-event qual here);
        // it may be the source of the result mismatch mentioned in the commit.
        // RecalUtils::IncrementDatum3keys(qualityScoreTable, qual, isError, readGroup, baseQuality, event.index);
        qualityScoreTable[readGroup][baseQuality][event.index].increment(1, isError, baseQuality);
        auto &d = qualityScoreTable[readGroup][baseQuality][event.index];  // only read by the debug log below
        // spdlog::info("isError {} : {}, mis {}, obs {}", isError, info.snp_errs[offset], d.numMismatches, d.numObservations);
        // Context covariate — only recorded when the key is non-negative.
        int contextCovariate = covariatesAtOffset[ContextCovariate::index];
        if (contextCovariate >= 0)
          contextTable[readGroup][baseQuality][contextCovariate][event.index].increment(1, isError, baseQuality);
        // RecalUtils::IncrementDatum4keys(nsgv::gRecalTables.contextTable, qual, isError, readGroup, baseQuality, contextCovariate,
        //                                 event.index);
        // Cycle covariate — only recorded when the key is non-negative.
        int cycleCovariate = covariatesAtOffset[CycleCovariate::index];
        if (cycleCovariate >= 0)
          cycleTable[readGroup][baseQuality][cycleCovariate][event.index].increment(1, isError, baseQuality);
        // RecalUtils::IncrementDatum4keys(nsgv::gRecalTables.cycleTable, qual, isError, readGroup, baseQuality, cycleCovariate, event.index);
      }
    }
  }
}
// Collapse the per-quality table into the per-read-group summary table:
// every populated (read group, qual, event) datum is merged into the
// corresponding (read group, event) cell of the read-group table.
void collapseQualityScoreTableToReadGroupTable(Array2D<RecalDatum> &byReadGroupTable, Array3D<RecalDatum> &byQualTable) {
  for (size_t rgKey = 0; rgKey < byQualTable.data.size(); ++rgKey) {
    auto& perQual = byQualTable[rgKey];
    for (size_t qual = 0; qual < perQual.size(); ++qual) {
      auto& perEvent = perQual[qual];
      for (size_t eventIndex = 0; eventIndex < perEvent.size(); ++eventIndex) {
        auto& qualDatum = perEvent[eventIndex];
        // Cells that never saw an observation contribute nothing.
        if (qualDatum.numObservations > 0) {
          byReadGroupTable[rgKey][eventIndex].combine(qualDatum);
        }
      }
    }
  }
}
/**
 * To replicate the results of BQSR whether or not we save tables to disk (which we need in Spark),
 * we need to trim the numbers to a few decimal places (that's what writing and reading does).
 *
 * @param rt recalibration tables whose populated entries are rounded in place
 */
void roundTableValues(RecalTables& rt) {
// Round one datum in place; only entries that saw observations are touched.
#define _round_val(val)                                                                                                         \
  do {                                                                                                                          \
    if (val.numObservations > 0) {                                                                                              \
      val.numMismatches = MathUtils::RoundToNDecimalPlaces(val.numMismatches, RecalUtils::NUMBER_ERRORS_DECIMAL_PLACES);        \
      val.reportedQuality = MathUtils::RoundToNDecimalPlaces(val.reportedQuality, RecalUtils::REPORTED_QUALITY_DECIMAL_PLACES); \
    }                                                                                                                           \
  } while (0)
  _Foreach2D(rt.readGroupTable, val, { _round_val(val); });
  _Foreach3D(rt.qualityScoreTable, val, { _round_val(val); });
  _Foreach4D(rt.contextTable, val, { _round_val(val); });
  _Foreach4D(rt.cycleTable, val, { _round_val(val); });
// Undefine the helper so the macro does not leak into the rest of the
// translation unit (previously it stayed defined past this function).
#undef _round_val
}
// 串行bqsr // 串行bqsr
int SerialBQSR() { int SerialBQSR() {
int round = 0; int round = 0;
@ -601,6 +700,9 @@ int SerialBQSR() {
int64_t readNumSum = 0; int64_t readNumSum = 0;
// 0. 初始化一些全局数据 // 0. 初始化一些全局数据
// BAQ baq{BAQ::DEFAULT_GOP}; // BAQ baq{BAQ::DEFAULT_GOP};
RecalDatum::StaticInit();
QualityUtils::StaticInit();
MathUtils::StaticInit();
// 1. 协变量数据相关初始化 // 1. 协变量数据相关初始化
PerReadCovariateMatrix readCovariates; PerReadCovariateMatrix readCovariates;
@ -608,6 +710,14 @@ int SerialBQSR() {
ContextCovariate::InitContextCovariate(nsgv::gBqsrArg); ContextCovariate::InitContextCovariate(nsgv::gBqsrArg);
CycleCovariate::InitCycleCovariate(nsgv::gBqsrArg); CycleCovariate::InitCycleCovariate(nsgv::gBqsrArg);
// 注意初始化顺序,这个必须在协变量初始化之后
nsgv::gRecalTables.init(nsgv::gInBamHeader->hrecs->nrg);
nsgv::gEventTypes.push_back(EventType::BASE_SUBSTITUTION);
if (nsgv::gBqsrArg.computeIndelBQSRTables) {
nsgv::gEventTypes.push_back(EventType::BASE_INSERTION);
nsgv::gEventTypes.push_back(EventType::BASE_DELETION);
}
// 2. 读取bam的read group // 2. 读取bam的read group
if (nsgv::gInBamHeader->hrecs->nrg == 0) { if (nsgv::gInBamHeader->hrecs->nrg == 0) {
spdlog::error("No RG tag found in the header!"); spdlog::error("No RG tag found in the header!");
@ -641,8 +751,10 @@ int SerialBQSR() {
// 2. 对质量分数长度跟碱基长度不匹配的read缺少的质量分数用默认值补齐先忽略后边有需要再处理 // 2. 对质量分数长度跟碱基长度不匹配的read缺少的质量分数用默认值补齐先忽略后边有需要再处理
// 3. 如果bam文件之前做过bqsrtag中包含OQoriginnal quality原始质量分数检查用户参数里是否指定用原始质量分数进行bqsr如果是则将质量分数替换为OQ否则忽略OQ先忽略 // 3. 如果bam文件之前做过bqsrtag中包含OQoriginnal quality原始质量分数检查用户参数里是否指定用原始质量分数进行bqsr如果是则将质量分数替换为OQ否则忽略OQ先忽略
// 4. 对read的两端进行检测去除hardclipadapter // 4. 对read的两端进行检测去除hardclipadapter
BamWrap *bw = bams[i]; // spdlog::info("bam idx: {}", i);
BamWrap* bw = bams[i];
SamData ad; SamData ad;
ad.bw = bw;
ad.read_len = BamWrap::BamEffectiveLength(bw->b); ad.read_len = BamWrap::BamEffectiveLength(bw->b);
ad.cigar_end = bw->b->core.n_cigar; ad.cigar_end = bw->b->core.n_cigar;
if (ad.read_len <= 0) continue; if (ad.read_len <= 0) continue;
@ -676,12 +788,16 @@ int SerialBQSR() {
vector<int> isIns(ad.read_len, 0); // 该位置是否是插入位置0不是1是 vector<int> isIns(ad.read_len, 0); // 该位置是否是插入位置0不是1是
vector<int> isDel(ad.read_len, 0); // 该位置是否是删除位置0不是1是 vector<int> isDel(ad.read_len, 0); // 该位置是否是删除位置0不是1是
const int nErrors = calculateIsSNPOrIndel(nsgv::gAuxVars[0], bw, ad, isSNP, isIns, isDel); const int nErrors = calculateIsSNPOrIndel(nsgv::gAuxVars[0], bw, ad, isSNP, isIns, isDel);
// spdlog::info("nErrors: {}", nErrors);
// for (auto val : isSNP) { if (val > 0) spdlog::info("snp val: {}", val); }
//exit(0);
// 7. 计算baqArray // 7. 计算baqArray
// BAQ = base alignment quality // BAQ = base alignment quality
// note for efficiency reasons we don't compute the BAQ array unless we actually have // note for efficiency reasons we don't compute the BAQ array unless we actually have
// some error to marginalize over. For ILMN data ~85% of reads have no error // some error to marginalize over. For ILMN data ~85% of reads have no error
vector<int> baqArray; vector<uint8_t> baqArray;
bool baqCalculated = false; bool baqCalculated = false;
if (nErrors == 0 || !nsgv::gBqsrArg.enableBAQ) { if (nErrors == 0 || !nsgv::gBqsrArg.enableBAQ) {
baqCalculated = flatBAQArray(bw, ad, baqArray); baqCalculated = flatBAQArray(bw, ad, baqArray);
@ -699,28 +815,49 @@ int SerialBQSR() {
int end_pos = bw->contig_end_pos(); int end_pos = bw->contig_end_pos();
//spdlog::info("adapter: {}, read: {}, {}, strand: {}", adapter_boundary, bw->contig_pos(), end_pos, //spdlog::info("adapter: {}, read: {}, {}, strand: {}", adapter_boundary, bw->contig_pos(), end_pos,
// bw->GetReadNegativeStrandFlag() ? "reverse" : "forward"); // bw->GetReadNegativeStrandFlag() ? "reverse" : "forward");
// for (auto val : isSNP) { if (val > 0) spdlog::info("snp err val-1: {}", val); }
// 9. 计算这条read需要跳过的位置 // 9. 计算这条read需要跳过的位置
vector<bool> skip(ad.read_len, 0); vector<bool> skips(ad.read_len, 0);
PROF_START(known_sites); PROF_START(known_sites);
calculateKnownSites(bw, ad, nsgv::gAuxVars[0].vcfArr, skip); calculateKnownSites(bw, ad, nsgv::gAuxVars[0].vcfArr, skips);
for (int ii = 0; ii < ad.read_len; ++ii) { for (int ii = 0; ii < ad.read_len; ++ii) {
skip[ii] = skips[ii] = skips[ii] || (ContextCovariate::baseIndexMap[ad.bases[ii]] == -1) ||
skip[ii] || (ContextCovariate::baseIndexMap[ad.bases[ii]] == -1) || ad.quals[ii] < nsgv::gBqsrArg.PRESERVE_QSCORES_LESS_THAN; ad.base_quals[ii] < nsgv::gBqsrArg.PRESERVE_QSCORES_LESS_THAN;
} }
PROF_END(gprof[GP_read_vcf], known_sites); PROF_END(gprof[GP_read_vcf], known_sites);
// 10. 根据BAQ进一步处理snpindel得到处理后的数据 // 10. 根据BAQ进一步处理snpindel得到处理后的数据
vector<double> snpErrors, insErrors, delErrors; vector<double> snpErrors, insErrors, delErrors;
// for (auto val : isSNP) { if (val > 0) spdlog::info("snp err val-2: {}", val); }
calculateFractionalErrorArray(isSNP, baqArray, snpErrors); calculateFractionalErrorArray(isSNP, baqArray, snpErrors);
calculateFractionalErrorArray(isIns, baqArray, insErrors); calculateFractionalErrorArray(isIns, baqArray, insErrors);
calculateFractionalErrorArray(isDel, baqArray, delErrors); calculateFractionalErrorArray(isDel, baqArray, delErrors);
// for (auto val : isSNP) { if (val > 0) spdlog::info("snp val: {}", val); }
//spdlog::info("snp errors size: {}, read len: {}", snpErrors.size(), ad.read_len);
//for (auto val : snpErrors) { if (val > 0) spdlog::info("snp err val: {}", val); }
// aggregate all of the info into our info object, and update the data // aggregate all of the info into our info object, and update the data
// 11. 合并之前计算的数据得到info并更新bqsr table数据 // 11. 合并之前计算的数据得到info并更新bqsr table数据
ReadRecalInfo info(ad, readCovariates, skips, snpErrors, insErrors, delErrors);
int m = 0;
// for (auto err : snpErrors) { if (isSNP[m] > 0 || err > 0) spdlog::info("snp err: {} : {}", isSNP[m++], err); }
//exit(0);
PROF_START(update_info);
updateRecalTablesForRead(info);
PROF_END(gprof[GP_update_info], update_info);
} }
// exit(0);
// 12. 创建总结数据
collapseQualityScoreTableToReadGroupTable(nsgv::gRecalTables.readGroupTable, nsgv::gRecalTables.qualityScoreTable);
roundTableValues(nsgv::gRecalTables);
// 13. 量化质量分数
QuantizationInfo quantInfo(nsgv::gRecalTables, nsgv::gBqsrArg.QUANTIZING_LEVELS);
// 14. 输出结果
RecalUtils::outputRecalibrationReport(nsgv::gBqsrArg, quantInfo, nsgv::gRecalTables);
#if 0 #if 0
// spdlog::info("region: {} - {}", bams[0]->global_softclip_start(), bams.back()->global_softclip_end()); // spdlog::info("region: {} - {}", bams[0]->global_softclip_start(), bams.back()->global_softclip_end());
// 1. 获取bams数组覆盖的region范围 // 1. 获取bams数组覆盖的region范围

View File

@ -4,6 +4,7 @@
EventTypeValue EventType::BASE_SUBSTITUTION = {0, 'M', "Base Substitution"}; EventTypeValue EventType::BASE_SUBSTITUTION = {0, 'M', "Base Substitution"};
EventTypeValue EventType::BASE_INSERTION = {1, 'I', "Base Insertion"}; EventTypeValue EventType::BASE_INSERTION = {1, 'I', "Base Insertion"};
EventTypeValue EventType::BASE_DELETION = {2, 'D', "Base Deletion"}; EventTypeValue EventType::BASE_DELETION = {2, 'D', "Base Deletion"};
vector<EventTypeValue> EventType::EVENTS = {BASE_SUBSTITUTION, BASE_INSERTION, BASE_DELETION};
// static变量 for ContextCovariate // static变量 for ContextCovariate
int ContextCovariate::mismatchesContextSize; int ContextCovariate::mismatchesContextSize;

View File

@ -25,7 +25,7 @@ using std::string;
using std::vector; using std::vector;
/** /**
* This is where we store the pre-read covariates, also indexed by (event type) and (read position). * This is where we store the per-read covariates, also indexed by (event type) and (read position).
* Thus the array has shape { event type } x { read position (aka cycle) } x { covariate }. * Thus the array has shape { event type } x { read position (aka cycle) } x { covariate }.
* For instance, { covariate } is by default 4-dimensional (read group, base quality, context, cycle). * For instance, { covariate } is by default 4-dimensional (read group, base quality, context, cycle).
*/ */
@ -36,6 +36,7 @@ struct EventTypeValue {
int index; // 在协变量数组中对应的索引 int index; // 在协变量数组中对应的索引
char representation; char representation;
string longRepresentation; string longRepresentation;
bool operator==(const EventTypeValue& a) const { return a.index == index; }
}; };
struct EventType { struct EventType {
@ -43,6 +44,7 @@ struct EventType {
static EventTypeValue BASE_SUBSTITUTION; static EventTypeValue BASE_SUBSTITUTION;
static EventTypeValue BASE_INSERTION; static EventTypeValue BASE_INSERTION;
static EventTypeValue BASE_DELETION; static EventTypeValue BASE_DELETION;
static vector<EventTypeValue> EVENTS;
}; };
// 协变量相关的工具类 // 协变量相关的工具类
@ -140,6 +142,17 @@ struct ContextCovariate {
baseIndexMap['t'] = 3; baseIndexMap['t'] = 3;
} }
static int MaximumKeyValue() {
int length = max(mismatchesContextSize, indelsContextSize);
int key = length;
int bitOffset = LENGTH_BITS;
for (int i = 0; i < length; ++i) {
key |= (3 << bitOffset);
bitOffset += 2;
}
return key;
}
static int CreateMask(int contextSize) { static int CreateMask(int contextSize) {
int mask = 0; int mask = 0;
// create 2*contextSize worth of bits // create 2*contextSize worth of bits
@ -161,6 +174,41 @@ struct ContextCovariate {
return isNegativeStrand ? (readLength - offset - 1) : offset; return isNegativeStrand ? (readLength - offset - 1) : offset;
} }
static char baseIndexToSimpleBase(const int baseIndex) {
switch (baseIndex) {
case 0:
return 'A';
case 1:
return 'C';
case 2:
return 'G';
case 3:
return 'T';
default:
return '.';
}
}
/**
* Converts a key into the dna string representation.
*
* @param key the key representing the dna sequence
* @return the dna sequence represented by the key
*/
static string ContextFromKey(const int key) {
int length = key & LENGTH_MASK; // the first bits represent the length (in bp) of the context
int mask = 48; // use the mask to pull out bases
int offset = LENGTH_BITS;
string dna;
for (int i = 0; i < length; i++) {
int baseIndex = (key & mask) >> offset;
dna.push_back(baseIndexToSimpleBase(baseIndex));
mask <<= 2; // move the mask over to the next 2 bits
offset += 2;
}
return dna;
}
// 获取去除低质量分数碱基之后的read碱基序列将低质量分数的碱基变成N // 获取去除低质量分数碱基之后的read碱基序列将低质量分数的碱基变成N
static void GetStrandedClippedBytes(BamWrap* bw, SamData& ad, string& clippedBases, uint8_t lowQTail); static void GetStrandedClippedBytes(BamWrap* bw, SamData& ad, string& clippedBases, uint8_t lowQTail);
// Creates a int representation of a given dna string. // Creates a int representation of a given dna string.
@ -180,6 +228,8 @@ struct CycleCovariate {
static void InitCycleCovariate(BQSRArg& p) { MAXIMUM_CYCLE_VALUE = p.MAXIMUM_CYCLE_VALUE; } static void InitCycleCovariate(BQSRArg& p) { MAXIMUM_CYCLE_VALUE = p.MAXIMUM_CYCLE_VALUE; }
static int MaximumKeyValue() { return (MAXIMUM_CYCLE_VALUE << 1) + 1; }
/** /**
* Encodes the cycle number as a key. * Encodes the cycle number as a key.
*/ */
@ -200,10 +250,27 @@ struct CycleCovariate {
result++; // negative cycles get the lower-most bit set result++; // negative cycles get the lower-most bit set
} }
return result; return result;
} }
/**
* Decodes the cycle number from the key.
*/
static int CycleFromKey(const int key) {
int cycle = key >> 1; // shift so we can remove the "sign" bit
if ((key & 1) != 0) { // is the last bit set?
cycle *= -1; // then the cycle is negative
}
return cycle;
}
// Computes the encoded value of CycleCovariate's key for the given position at the read. // Computes the encoded value of CycleCovariate's key for the given position at the read.
static int CycleKey(BamWrap* bw, SamData& ad, const int baseNumber, const bool indel, const int maxCycle); static int CycleKey(BamWrap* bw, SamData& ad, const int baseNumber, const bool indel, const int maxCycle);
static void RecordValues(BamWrap* bw, SamData& ad, sam_hdr_t* header, PerReadCovariateMatrix& values, bool recordIndelValues); static void RecordValues(BamWrap* bw, SamData& ad, sam_hdr_t* header, PerReadCovariateMatrix& values, bool recordIndelValues);
}; };
// 好像不需要
struct StandardCovariateList {
ReadGroupCovariate readGroupCovariate;
BaseQualityCovariate qualityScoreCovariate;
};

View File

@ -0,0 +1,194 @@
/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/24
*/
#pragma once
#include <vector>
#include <assert.h>
#include "spdlog/spdlog.h"
using std::vector;
// 2-D array backed by nested std::vector, addressed as [k1][k2].
template <class T>
struct Array2D {
  vector<vector<T>> data;

  Array2D() {}
  Array2D(int dim1, int dim2) { init(dim1, dim2); }

  // Size (or resize) the array; existing elements inside the new bounds are kept.
  void init(int dim1, int dim2) {
    data.resize(dim1);
    for (auto& v : data) v.resize(dim2);
  }

  // Element access by keys (const overload added for const-correctness,
  // matching the existing const operator[]).
  inline T& get(int k1, int k2) { return data[k1][k2]; }
  inline const T& get(int k1, int k2) const { return data[k1][k2]; }
  // Store `value` at the given keys.
  inline void put(const T& value, int k1, int k2) { data[k1][k2] = value; }

  inline vector<T>& operator[](size_t idx) { return data[idx]; }
  inline const vector<T>& operator[](size_t idx) const { return data[idx]; }

// Visit every element; `valName` names the element inside `codes`.
#define _Foreach2D(array, valName, codes) \
  for (auto& arr1 : array.data) {         \
    for (auto& valName : arr1) {          \
      codes;                              \
    }                                     \
  }
// Same, but also exposes running indices k1/k2 to `codes`.
#define _Foreach2DK(array, valName, codes) \
  do {                                     \
    int k1 = 0;                            \
    for (auto& arr1 : array.data) {        \
      int k2 = 0;                          \
      for (auto& valName : arr1) {         \
        codes;                             \
        ++k2;                              \
      }                                    \
      ++k1;                                \
    }                                      \
  } while (0)
};
// 3-D array backed by nested std::vector, addressed as [k1][k2][k3].
template <class T>
struct Array3D {
  vector<vector<vector<T>>> data;

  Array3D() {}
  Array3D(int dim1, int dim2, int dim3) { init(dim1, dim2, dim3); }

  // Size (or resize) the array; existing elements inside the new bounds are kept.
  void init(int dim1, int dim2, int dim3) {
    data.resize(dim1);
    for (auto& v1 : data) {
      v1.resize(dim2);
      for (auto& v2 : v1) v2.resize(dim3);
    }
  }

  // Element access by keys (const overload added for const-correctness,
  // consistent with Array2D).
  inline T& get(int k1, int k2, int k3) { return data[k1][k2][k3]; }
  inline const T& get(int k1, int k2, int k3) const { return data[k1][k2][k3]; }
  // Store `value` at the given keys.
  inline void put(const T& value, int k1, int k2, int k3) { data[k1][k2][k3] = value; }

  inline vector<vector<T>>& operator[](size_t idx) { return data[idx]; }
  // const operator[] added — Array2D already had one; Array3D was inconsistent.
  inline const vector<vector<T>>& operator[](size_t idx) const { return data[idx]; }

// Visit every element; `valName` names the element inside `codes`.
#define _Foreach3D(array, valName, codes) \
  for (auto& arr1 : array.data) {         \
    for (auto& arr2 : arr1) {             \
      for (auto& valName : arr2) {        \
        codes;                            \
      }                                   \
    }                                     \
  }
// Same, but also exposes running indices k1/k2/k3 to `codes`.
#define _Foreach3DK(array, valName, codes) \
  do {                                     \
    int k1 = 0;                            \
    for (auto& arr1 : array.data) {        \
      int k2 = 0;                          \
      for (auto& arr2 : arr1) {            \
        int k3 = 0;                        \
        for (auto& valName : arr2) {       \
          codes;                           \
          ++k3;                            \
        }                                  \
        ++k2;                              \
      }                                    \
      ++k1;                                \
    }                                      \
  } while (0)
};
// 4-D array backed by nested std::vector, addressed as [k1][k2][k3][k4].
template <class T>
struct Array4D {
  vector<vector<vector<vector<T>>>> data;

  Array4D() {}
  Array4D(int dim1, int dim2, int dim3, int dim4) { init(dim1, dim2, dim3, dim4); }

  // Size (or resize) the array; existing elements inside the new bounds are kept.
  void init(int dim1, int dim2, int dim3, int dim4) {
    data.resize(dim1);
    for (auto& v1 : data) {
      v1.resize(dim2);
      for (auto& v2 : v1) {
        v2.resize(dim3);
        for (auto& v3 : v2) v3.resize(dim4);
      }
    }
  }

  // Element access by keys (const overload added for const-correctness,
  // consistent with Array2D).
  inline T& get(int k1, int k2, int k3, int k4) { return data[k1][k2][k3][k4]; }
  inline const T& get(int k1, int k2, int k3, int k4) const { return data[k1][k2][k3][k4]; }
  // Store `value` at the given keys.
  inline void put(const T& value, int k1, int k2, int k3, int k4) { data[k1][k2][k3][k4] = value; }

  inline vector<vector<vector<T>>>& operator[](size_t idx) { return data[idx]; }
  // const operator[] added — Array2D already had one; Array4D was inconsistent.
  inline const vector<vector<vector<T>>>& operator[](size_t idx) const { return data[idx]; }

// Visit every element; `valName` names the element inside `codes`.
#define _Foreach4D(array, valName, codes) \
  for (auto& arr1 : array.data) {         \
    for (auto& arr2 : arr1) {             \
      for (auto& arr3 : arr2) {           \
        for (auto& valName : arr3) {      \
          codes;                          \
        }                                 \
      }                                   \
    }                                     \
  }
// Same, but also exposes running indices k1/k2/k3/k4 to `codes`.
#define _Foreach4DK(array, valName, codes) \
  do {                                     \
    int k1 = 0;                            \
    for (auto& arr1 : array.data) {        \
      int k2 = 0;                          \
      for (auto& arr2 : arr1) {            \
        int k3 = 0;                        \
        for (auto& arr3 : arr2) {          \
          int k4 = 0;                      \
          for (auto& valName : arr3) {     \
            codes;                         \
            ++k4;                          \
          }                                \
          ++k3;                            \
        }                                  \
        ++k2;                              \
      }                                    \
      ++k1;                                \
    }                                      \
  } while (0)
};
// Tensor-like N-dimensional array stored flat in row-major order.
template <class T>
struct NestedArray {
  vector<T> data;          // flat storage, row-major
  vector<int> dimensions;  // extent of each dimension
  vector<int> dim_offset;  // row-major stride (in elements) of each dimension

  NestedArray() {}

  template <typename... Args>
  NestedArray(Args... dims) {
    init(dims...);
  }

  // (Re)initialize with the given extents. Fix: calling init() a second time
  // now resets the previous shape instead of appending to `dimensions`.
  template <typename... Args>
  void init(Args... dims) {
    static_assert(sizeof...(dims) > 0, "NestedArray requires at least one dimension");
    dimensions.clear();
    dim_offset.clear();
    (dimensions.emplace_back(dims), ...);
    dim_offset.resize(dimensions.size(), 1);
    // Each stride is the product of all faster-varying dimensions below it.
    for (int i = (int)dimensions.size() - 2; i >= 0; --i) {
      dim_offset[i] = dim_offset[i + 1] * dimensions[i + 1];
    }
    data.assign((size_t)dimensions[0] * dim_offset[0], T());
  }

  // Row-major flat index of the given keys; arity must match init().
  // Computed with a fold expression — the old implementation heap-allocated a
  // temporary vector<int> on every single get/put call.
  template <typename... Args>
  inline size_t flatIndex(Args... keys) const {
    assert(sizeof...(keys) == dimensions.size());
    size_t dim = 0;
    size_t idx = 0;
    ((idx += (size_t)keys * (size_t)dim_offset[dim++]), ...);
    return idx;
  }

  // Fetch the element stored at the given keys.
  template <typename... Args>
  T& get(Args... keys) {
    return data[flatIndex(keys...)];
  }

  // Store `value` at the given keys.
  template <typename... Args>
  void put(T value, Args... keys) {
    data[flatIndex(keys...)] = std::move(value);
  }
};

View File

@ -0,0 +1,315 @@
/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/26
*/
#pragma once
#include <set>
#include <vector>
#include <cstdint>
#include "qual_utils.h"
using std::set;
using std::vector;
/**
* A general algorithm for quantizing quality score distributions to use a specific number of levels
*
* Takes a histogram of quality scores and a desired number of levels and produces a
* map from original quality scores -> quantized quality scores.
*
* Note that this data structure is fairly heavy-weight, holding lots of debugging and
* calculation information. If you want to use it efficiently at scale with lots of
* read groups the right way to do this:
*
* Map<ReadGroup, List<Byte>> map
* for each read group rg:
* hist = getQualHist(rg)
* QualQuantizer qq = new QualQuantizer(hist, nLevels, minInterestingQual)
* map.set(rg, qq.getOriginalToQuantizedMap())
*
* This map would then be used to look up the appropriate original -> quantized
* quals for each read as it comes in.
*/
struct QualQuantizer {
    /**
     * Represents a contiguous interval of quality scores.
     *
     * qStart and qEnd are inclusive, so qStart = qEnd = 2 is the quality score bin of 2
     */
    struct QualInterval {
        int qStart, qEnd, fixedQual, level;
        int64_t nObservations, nErrors;
        set<QualInterval> subIntervals;
        /** for debugging / visualization. When was this interval created? */
        int mergeOrder;
        /** (Re)initialize all fields. Also resets mergeOrder so no field is ever left indeterminate. */
        void init(const int _qStart, const int _qEnd, const int64_t _nObservations, const int64_t _nErrors, const int _level, const int _fixedQual) {
            qStart = _qStart;
            qEnd = _qEnd;
            nObservations = _nObservations;
            nErrors = _nErrors;
            level = _level;
            fixedQual = _fixedQual;
            mergeOrder = 0;  // BUGFIX: previously left uninitialized by the non-default constructors, then read in mergeLowestPenaltyIntervals()
        }
        /** Sentinel interval (qStart == -1), used as the "no candidate yet" marker in mergeLowestPenaltyIntervals(). */
        QualInterval() {
            qStart = -1;
            qEnd = -1;
            nObservations = -1;
            nErrors = -1;
            fixedQual = -1;
            level = -1;
            mergeOrder = 0;
        }
        QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level) {
            init(qStart, qEnd, nObservations, nErrors, level, -1);
        }
        QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level,
                     const set<QualInterval>& _subIntervals) {
            init(qStart, qEnd, nObservations, nErrors, level, -1);
            subIntervals = _subIntervals;
        }
        QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level, const int fixedQual) {
            init(qStart, qEnd, nObservations, nErrors, level, fixedQual);
        }
        QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level, const int fixedQual,
                     const set<QualInterval>& _subIntervals) {
            init(qStart, qEnd, nObservations, nErrors, level, fixedQual);
            subIntervals = _subIntervals;
        }
        /**
         * @return Human readable name of this interval: e.g., 10-12
         */
        string getName() const { return std::to_string(qStart) + "-" + std::to_string(qEnd); }
        string toString() const {
            return "QQ:" + getName();
        }
        /**
         * @return true if this bin is using a fixed qual
         */
        bool hasFixedQual() const { return fixedQual != -1; }
        /**
         * @return the error rate (in real space) of this interval, or 0 if there are no observations
         */
        double getErrorRate() const {
            if (hasFixedQual())
                return QualityUtils::qualToErrorProb((uint8_t)fixedQual);
            else if (nObservations == 0)
                return 0.0;
            else
                return (nErrors + 1) / (1.0 * (nObservations + 1));  // +1/+1 smoothing keeps the rate strictly positive
        }
        /**
         * @return the QUAL of the error rate of this interval, or the fixed qual if this interval was created with a fixed qual.
         */
        uint8_t getQual() const {
            if (!hasFixedQual())
                return QualityUtils::errorProbToQual(getErrorRate());
            else
                return (uint8_t)fixedQual;
        }
        /** Three-way comparison by qStart (mirror of Java's compareTo). */
        int compareTo(const QualInterval& qi) const { return qStart < qi.qStart ? -1 : (qStart == qi.qStart ? 0 : 1); }
        /**
         * Create a interval representing the merge of this interval and toMerge
         *
         * Errors and observations are combined
         * Subintervals updated in order of left to right (determined by qStart)
         * Level is 1 + highest level of this and toMerge
         * Order must be updated elsewhere
         *
         * @param toMerge
         * @return newly created merged QualInterval
         */
        QualInterval merge(const QualInterval& toMerge) const {
            const QualInterval& left = this->compareTo(toMerge) < 0 ? *this : toMerge;
            const QualInterval& right = this->compareTo(toMerge) < 0 ? toMerge : *this;
            if (left.qEnd + 1 != right.qStart) {
                // fatal: the quantizer only ever merges adjacent bins, so this indicates corrupted state
                std::cerr << "Attempting to merge non-contiguous intervals: left = " + left.toString() + " right = " + right.toString() << std::endl;
                exit(1);
            }
            const int64_t nCombinedObs = left.nObservations + right.nObservations;
            const int64_t nCombinedErr = left.nErrors + right.nErrors;
            const int level = std::max(left.level, right.level) + 1;
            set<QualInterval> subIntervals;
            subIntervals.insert(left);
            subIntervals.insert(right);
            QualInterval merged(left.qStart, right.qEnd, nCombinedObs, nCombinedErr, level, subIntervals);
            return merged;
        }
        double getPenalty(const int minInterestingQual) const { return calcPenalty(getErrorRate(), minInterestingQual); }
        /**
         * Calculate the penalty of this interval, given the overall error rate for the interval
         *
         * If the globalErrorRate is e, this value is:
         *
         * sum_i |log10(e_i) - log10(e)| * nObservations_i
         *
         * each the index i applies to all leaves of the tree accessible from this interval
         * (found recursively from subIntervals as necessary)
         *
         * @param globalErrorRate overall error rate in real space against which we calculate the penalty
         * @return the cost of approximating the bins in this interval with the globalErrorRate
         */
        double calcPenalty(const double globalErrorRate, const int minInterestingQual) const {
            if (globalErrorRate == 0.0) // there were no observations, so there's no penalty
                return 0.0;
            if (subIntervals.empty()) {
                // this is leave node
                if (this->qEnd <= minInterestingQual)
                    // It's free to merge up quality scores below the smallest interesting one
                    return 0;
                else {
                    return (std::abs(std::log10(getErrorRate()) - std::log10(globalErrorRate))) * nObservations;
                }
            } else {
                double sum = 0;
                // BUGFIX(perf): iterate by const reference -- copying a QualInterval deep-copies its
                // recursive set of subintervals on every penalty evaluation
                for (const QualInterval& interval : subIntervals) sum += interval.calcPenalty(globalErrorRate, minInterestingQual);
                return sum;
            }
        }
        bool operator<(const QualInterval& o) const {
            return qStart < o.qStart;
        }
        QualInterval& operator=(const QualInterval& o) {
            if (this == &o) return *this;
            init(o.qStart, o.qEnd, o.nObservations, o.nErrors, o.level, o.fixedQual);
            mergeOrder = o.mergeOrder;
            subIntervals = o.subIntervals;  // set copy-assignment replaces the former element-wise insert loop
            return *this;
        }
    };
    /**
     * Inputs to the QualQuantizer
     */
    const int nLevels, minInterestingQual;
    vector<int64_t>& nObservationsPerQual;
    // NOTE: init-list follows the member declaration order above (nLevels, minInterestingQual, then the reference)
    QualQuantizer(vector<int64_t>& _nObservationsPerQual, const int _nLevels, const int _minInterestingQual)
        : nLevels(_nLevels), minInterestingQual(_minInterestingQual), nObservationsPerQual(_nObservationsPerQual) {
        quantize();
    }
    /** Sorted set of qual intervals.
     *
     * After quantize() this data structure contains only the top-level qual intervals
     */
    set<QualInterval> quantizedIntervals;
    /**
     * Fill quantMap so that quantMap[q] is the quantized quality for original quality q.
     *
     * Each top-level interval maps all qualities in [qStart, qEnd] to its single representative qual.
     */
    void getOriginalToQuantizedMap(vector<uint8_t>& quantMap) {
        quantMap.resize(getNQualsInHistogram(), UINT8_MAX);
        for (auto& interval : quantizedIntervals) {
            for (int q = interval.qStart; q <= interval.qEnd; q++) {
                quantMap[q] = interval.getQual();
            }
        }
        // if (Collections.min(map) == Byte.MIN_VALUE) throw new GATKException("quantized quality score map contains an un-initialized value");
    }
    int getNQualsInHistogram() const { return (int)nObservationsPerQual.size(); }
    /**
     * Main method for computing the quantization intervals.
     *
     * Invoked in the constructor after all input variables are initialized. Walks
     * over the inputs and builds the min. penalty forest of intervals with exactly nLevel
     * root nodes. Finds this min. penalty forest via greedy search, so is not guarenteed
     * to find the optimal combination.
     *
     * TODO: develop a smarter algorithm
     *
     * @return the forest of intervals with size == nLevels
     */
    void quantize() {
        // create intervals for each qual individually
        auto& intervals = quantizedIntervals;
        for (int qStart = 0; qStart < getNQualsInHistogram(); qStart++) {
            const int64_t nObs = nObservationsPerQual.at(qStart);
            const double errorRate = QualityUtils::qualToErrorProb((uint8_t)qStart);
            const double nErrors = nObs * errorRate;
            const QualInterval qi(qStart, qStart, nObs, (int)std::floor(nErrors), 0, (uint8_t)qStart);
            intervals.insert(qi);
        }
        // greedy algorithm:
        // while ( n intervals >= nLevels ):
        //   find intervals to merge with least penalty
        //   merge it
        while (intervals.size() > (size_t)nLevels) {  // explicit cast avoids signed/unsigned comparison
            mergeLowestPenaltyIntervals(intervals);
        }
    }
    /**
     * Helper function that finds and merges together the lowest penalty pair of intervals
     * @param intervals
     */
    void mergeLowestPenaltyIntervals(set<QualInterval>& intervals) {
        // setup the iterators
        auto it1 = intervals.begin();
        auto it1p = intervals.begin();
        ++it1p;  // skip one
        // walk over the pairs of left and right, keeping track of the pair with the lowest merge penalty
        QualInterval minMerge;
        double minPenalty = 0.0;  // valid only once minMerge.qStart != -1; cached so the penalty is not recomputed per comparison
        int lastMergeOrder = 0;
        while (it1p != intervals.end()) {
            const QualInterval& left = *it1;
            const QualInterval& right = *it1p;
            const QualInterval merged = left.merge(right);
            lastMergeOrder = std::max(std::max(lastMergeOrder, left.mergeOrder), right.mergeOrder);
            const double penalty = merged.getPenalty(minInterestingQual);
            if (minMerge.qStart == -1 || penalty < minPenalty) {
                minMerge = merged;  // merge two bins that when merged incur the lowest "penalty"
                minPenalty = penalty;
            }
            ++it1;
            ++it1p;
        }
        // now actually go ahead and merge the minMerge pair
        minMerge.mergeOrder = lastMergeOrder + 1;
        // replace the two source intervals by the merged one (set ordering keys on qStart)
        for (auto& itr : minMerge.subIntervals) {
            intervals.erase(itr);
        }
        intervals.insert(minMerge);
    }
};

View File

@ -0,0 +1,12 @@
#include "qual_utils.h"
// Q = -10*log10(p)  =>  ln(p) = Q * (-ln(10)/10); the two multipliers are exact inverses.
const double QualityUtils::PHRED_TO_LOG_PROB_MULTIPLIER = -std::log(10) / 10.0;
const double QualityUtils::LOG_PROB_TO_PHRED_MULTIPLIER = 1 / PHRED_TO_LOG_PROB_MULTIPLIER;
// Smallest representable log10-scaled qual and its phred-scaled counterpart (DBL_MIN from <cfloat>).
// NOTE: within-TU order matters -- MIN_PHRED_SCALED_QUAL reads MIN_LOG10_SCALED_QUAL above it.
const double QualityUtils::MIN_LOG10_SCALED_QUAL = std::log10(DBL_MIN);
const double QualityUtils::MIN_PHRED_SCALED_QUAL = -10.0 * MIN_LOG10_SCALED_QUAL;
// Storage for the lookup tables; populated by QualityUtils::StaticInit().
double QualityUtils::qualToErrorProbCache[MAX_QUAL + 1];
double QualityUtils::qualToProbLog10Cache[MAX_QUAL + 1];

View File

@ -0,0 +1,406 @@
/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/25
*/
#pragma once
#include <cassert>
#include <cfloat>
#include <climits>
#include <cmath>
#include <cstdint>
#include <vector>
#include "util/math/math_utils.h"
using std::vector;
struct QualityUtils {
    /** One-time initialization of the qual->probability lookup tables; call before any cached lookup. */
    static void StaticInit() {
        for (int i = 0; i <= MAX_QUAL; i++) {
            qualToErrorProbCache[i] = qualToErrorProb((double)i);
            qualToProbLog10Cache[i] = std::log10(1.0 - qualToErrorProbCache[i]);
        }
    }
    /**
     * Maximum quality score that can be encoded in a SAM/BAM file
     */
    static constexpr uint8_t MAX_SAM_QUAL_SCORE = 93;
    /**
     * bams containing quals above this value are extremely suspicious and we should warn the user
     */
    static constexpr uint8_t MAX_REASONABLE_Q_SCORE = 60;
    /**
     * conversion factor from phred scaled quality to log error probability and vice versa
     */
    static const double PHRED_TO_LOG_PROB_MULTIPLIER;
    static const double LOG_PROB_TO_PHRED_MULTIPLIER;
    static const double MIN_LOG10_SCALED_QUAL;
    static const double MIN_PHRED_SCALED_QUAL;
    /**
     * The lowest quality score for a base that is considered reasonable for statistical analysis. This is
     * because Q 6 => you stand a 25% of being right, which means all bases are equally likely
     */
    static constexpr uint8_t MIN_USABLE_Q_SCORE = 6;
    static constexpr int MAPPING_QUALITY_UNAVAILABLE = 255;
    /**
     * Maximum sense quality value.
     */
    static constexpr int MAX_QUAL = 254;
    /**
     * Cached values for qual as uint8_t calculations so they are very fast
     */
    static double qualToErrorProbCache[MAX_QUAL + 1];
    static double qualToProbLog10Cache[MAX_QUAL + 1];
    // ----------------------------------------------------------------------
    //
    // These are all functions to convert a phred-scaled quality score to a probability
    //
    // ----------------------------------------------------------------------
    /**
     * Convert a phred-scaled quality score to its probability of being true (Q30 => 0.999)
     *
     * This is the Phred-style conversion, *not* the Illumina-style conversion.
     *
     * Because the input is a discretized byte value, this function uses a cache so is very efficient
     *
     * WARNING -- because this function takes a byte for maxQual, you must be careful in converting
     * integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
     *
     * @param qual a quality score (0-255)
     * @return a probability (0.0-1.0)
     */
    static double qualToProb(const uint8_t qual) { return 1.0 - qualToErrorProb(qual); }
    /**
     * Convert a phred-scaled quality score to its probability of being true (Q30 => 0.999)
     *
     * This is the Phred-style conversion, *not* the Illumina-style conversion.
     *
     * Because the input is a double value, this function must call Math.pow so can be quite expensive
     *
     * @param qual a phred-scaled quality score encoded as a double. Can be non-integer values (30.5)
     * @return a probability (0.0-1.0)
     */
    static double qualToProb(const double qual) {
        // Utils.validateArg(qual >= 0.0, ()->"qual must be >= 0.0 but got " + qual);
        return 1.0 - qualToErrorProb(qual);
    }
    /**
     * Convert a phred-scaled quality score to its log10 probability of being true (Q30 => log10(0.999))
     *
     * This is the Phred-style conversion, *not* the Illumina-style conversion.
     *
     * WARNING -- because this function takes a byte for maxQual, you must be careful in converting
     * integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
     *
     * @param qual a phred-scaled quality score encoded as a double. Can be non-integer values (30.5)
     * @return a probability (0.0-1.0)
     */
    static double qualToProbLog10(const uint8_t qual) {
        return qualToProbLog10Cache[(int)qual & 0xff]; // Map: 127 -> 127; -128 -> 128; -1 -> 255; etc.
    }
    /**
     * Convert a log-probability to a phred-scaled value ( log(0.001) => 30 )
     *
     * @param prob a log-probability
     * @return a phred-scaled value, not necessarily integral or bounded by {@code MAX_QUAL}
     */
    static double logProbToPhred(const double prob) { return prob * LOG_PROB_TO_PHRED_MULTIPLIER; }
    /**
     * Convert a phred-scaled quality score to its probability of being wrong (Q30 => 0.001)
     *
     * This is the Phred-style conversion, *not* the Illumina-style conversion.
     *
     * Because the input is a double value, this function must call Math.pow so can be quite expensive
     *
     * @param qual a phred-scaled quality score encoded as a double. Can be non-integer values (30.5)
     * @return a probability (0.0-1.0)
     */
    static double qualToErrorProb(const double qual) {
        assert(qual >= 0.0);
        return std::pow(10.0, qual / -10.0);
    }
    /**
     * Convert a phred-scaled quality score to its probability of being wrong (Q30 => 0.001)
     *
     * This is the Phred-style conversion, *not* the Illumina-style conversion.
     *
     * Because the input is a byte value, this function uses a cache so is very efficient
     *
     * WARNING -- because this function takes a byte for maxQual, you must be careful in converting
     * integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
     *
     * @param qual a phred-scaled quality score encoded as a byte
     * @return a probability (0.0-1.0)
     */
    static double qualToErrorProb(const uint8_t qual) {
        return qualToErrorProbCache[(int)qual & 0xff]; // Map: 127 -> 127; -128 -> 128; -1 -> 255; etc.
    }
    /**
     * Convert a phred-scaled quality score to its log10 probability of being wrong (Q30 => log10(0.001))
     *
     * This is the Phred-style conversion, *not* the Illumina-style conversion.
     *
     * The calculation is extremely efficient
     *
     * WARNING -- because this function takes a byte for maxQual, you must be careful in converting
     * integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
     *
     * @param qual a phred-scaled quality score encoded as a byte
     * @return a probability (0.0-1.0)
     */
    static double qualToErrorProbLog10(const uint8_t qual) { return qualToErrorProbLog10((double)(qual & 0xFF)); }
    /**
     * Convert a phred-scaled quality score to its log10 probability of being wrong (Q30 => log10(0.001))
     *
     * This is the Phred-style conversion, *not* the Illumina-style conversion.
     *
     * The calculation is extremely efficient
     *
     * @param qual a phred-scaled quality score encoded as a double
     * @return log of probability (0.0-1.0)
     */
    static double qualToErrorProbLog10(const double qual) {
        // Utils.validateArg(qual >= 0.0, ()->"qual must be >= 0.0 but got " + qual);
        return qual * -0.1;
    }
    // ----------------------------------------------------------------------
    //
    // Functions to convert a probability to a phred-scaled quality score
    //
    // ----------------------------------------------------------------------
    /**
     * Convert a probability of being wrong to a phred-scaled quality score (0.01 => 20).
     *
     * Note, this function caps the resulting quality score by the public static value MAX_SAM_QUAL_SCORE
     * and by 1 at the low-end.
     *
     * @param errorRate a probability (0.0-1.0) of being wrong (i.e., 0.01 is 1% change of being wrong)
     * @return a quality score (0-MAX_SAM_QUAL_SCORE)
     */
    static uint8_t errorProbToQual(const double errorRate) { return errorProbToQual(errorRate, MAX_SAM_QUAL_SCORE); }
    /**
     * Convert a probability of being wrong to a phred-scaled quality score (0.01 => 20).
     *
     * Note, this function caps the resulting quality score by the public static value MIN_REASONABLE_ERROR
     * and by 1 at the low-end.
     *
     * WARNING -- because this function takes a byte for maxQual, you must be careful in converting
     * integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
     *
     * @param errorRate a probability (0.0-1.0) of being wrong (i.e., 0.01 is 1% change of being wrong)
     * @return a quality score (0-maxQual)
     */
    static uint8_t errorProbToQual(const double errorRate, const uint8_t maxQual) {
        // Utils.validateArg(MathUtils.isValidProbability(errorRate), ()->"errorRate must be good probability but got " + errorRate);
        const double d = std::round(-10.0 * std::log10(errorRate));
        return boundQual((int)d, maxQual);
    }
    /**
     * @see #errorProbToQual(double, byte) with proper conversion of maxQual integer to a byte
     */
    static uint8_t errorProbToQual(const double prob, const int maxQual) {
        // Utils.validateArg(maxQual >= 0 && maxQual <= 255, ()->"maxQual must be between 0-255 but got " + maxQual);
        return errorProbToQual(prob, (uint8_t)(maxQual & 0xFF));
    }
    /**
     * Convert a probability of being right to a phred-scaled quality score (0.99 => 20).
     *
     * Note, this function caps the resulting quality score by the public static value MAX_SAM_QUAL_SCORE
     * and by 1 at the low-end.
     *
     * @param prob a probability (0.0-1.0) of being right
     * @return a quality score (0-MAX_SAM_QUAL_SCORE)
     */
    static uint8_t trueProbToQual(const double prob) { return trueProbToQual(prob, MAX_SAM_QUAL_SCORE); }
    /**
     * Convert a probability of being right to a phred-scaled quality score (0.99 => 20).
     *
     * Note, this function caps the resulting quality score by the min probability allowed (EPS).
     * So for example, if prob is 1e-6, which would imply a Q-score of 60, and EPS is 1e-4,
     * the result of this function is actually Q40.
     *
     * Note that the resulting quality score, regardless of EPS, is capped by MAX_SAM_QUAL_SCORE and
     * bounded on the low-side by 1.
     *
     * WARNING -- because this function takes a byte for maxQual, you must be careful in converting
     * integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
     *
     * @param trueProb a probability (0.0-1.0) of being right
     * @param maxQual the maximum quality score we are allowed to emit here, regardless of the error rate
     * @return a phred-scaled quality score (0-maxQualScore) as a byte
     */
    static uint8_t trueProbToQual(const double trueProb, const uint8_t maxQual) {
        // Utils.validateArg(MathUtils.isValidProbability(trueProb), ()->"trueProb must be good probability but got " + trueProb);
        const double lp = std::round(-10.0 * MathUtils::log10OneMinusX(trueProb));
        return boundQual((int)lp, maxQual);
    }
    /**
     * @see #trueProbToQual(double, byte) with proper conversion of maxQual to a byte
     */
    static uint8_t trueProbToQual(const double prob, const int maxQual) {
        // Utils.validateArg(maxQual >= 0 && maxQual <= 255, ()->"maxQual must be between 0-255 but got " + maxQual);
        return trueProbToQual(prob, (uint8_t)(maxQual & 0xFF));
    }
    /**
     * Convert a probability of being right to a phred-scaled quality score of being wrong as a double
     *
     * This is a very generic method, that simply computes a phred-scaled double quality
     * score given an error rate. It has the same precision as a normal double operation
     *
     * @param trueRate the probability of being right (0.0-1.0)
     * @return a phred-scaled version of the error rate implied by trueRate
     */
    static double phredScaleCorrectRate(const double trueRate) { return phredScaleLog10ErrorRate(MathUtils::log10OneMinusX(trueRate)); }
    /**
     * Convert a probability of being wrong to a phred-scaled quality score as a double
     *
     * This is a very generic method, that simply computes a phred-scaled double quality
     * score given an error rate. It has the same precision as a normal double operation
     *
     * @param errorRate the probability of being wrong (0.0-1.0)
     * @return a phred-scaled version of the error rate
     */
    static double phredScaleErrorRate(const double errorRate) { return phredScaleLog10ErrorRate(std::log10(errorRate)); }
    /**
     * Convert a log10 probability of being wrong to a phred-scaled quality score as a double
     *
     * This is a very generic method, that simply computes a phred-scaled double quality
     * score given an error rate. It has the same precision as a normal double operation
     *
     * @param errorRateLog10 the log10 probability of being wrong (0.0-1.0). Can be -Infinity, in which case
     *                       the result is MIN_PHRED_SCALED_QUAL
     * @return a phred-scaled version of the error rate
     */
    static double phredScaleLog10ErrorRate(const double errorRateLog10) {
        // Utils.validateArg(MathUtils.isValidLog10Probability(errorRateLog10), ()->"errorRateLog10 must be good probability but got " + errorRateLog10);
        // abs is necessary for edge base with errorRateLog10 = 0 producing -0.0 doubles
        return std::abs(-10.0 * std::max(errorRateLog10, MIN_LOG10_SCALED_QUAL));
    }
    /**
     * Convert a log10 probability of being right to a phred-scaled quality score of being wrong as a double
     *
     * This is a very generic method, that simply computes a phred-scaled double quality
     * score given an error rate. It has the same precision as a normal double operation
     *
     * @param trueRateLog10 the log10 probability of being right (0.0-1.0). Can be -Infinity to indicate
     *                      that the result is impossible in which MIN_PHRED_SCALED_QUAL is returned
     * @return a phred-scaled version of the error rate implied by trueRate
     */
    static double phredScaleLog10CorrectRate(const double trueRateLog10) { return phredScaleCorrectRate(std::pow(10.0, trueRateLog10)); }
    // ----------------------------------------------------------------------
    //
    // Routines to bound a quality score to a reasonable range
    //
    // ----------------------------------------------------------------------
    /**
     * Return a quality score that bounds qual by MAX_SAM_QUAL_SCORE and 1
     *
     * @param qual the uncapped quality score as an integer
     * @return the bounded quality score
     */
    static uint8_t boundQual(int qual) { return boundQual(qual, MAX_SAM_QUAL_SCORE); }
    /**
     * Return a quality score that bounds qual by maxQual and 1
     *
     * WARNING -- because this function takes a byte for maxQual, you must be careful in converting
     * integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
     *
     * @param qual the uncapped quality score as an integer. Can be < 0 (which may indicate an error in the
     *             client code), which will be brought back to 1, but this isn't an error, as some
     *             routines may use this functionality (BaseRecalibrator, for example)
     * @param maxQual the maximum quality score, must be less < 255
     * @return the bounded quality score
     */
    static uint8_t boundQual(const int qual, const uint8_t maxQual) { return (uint8_t)(std::max(std::min(qual, maxQual & 0xFF), 1) & 0xFF); }
    /**
     * Calculate the sum of phred scores.
     * @param phreds the phred score values.
     * @return the sum.
     */
    static double phredSum(const vector<double>& phreds) {
        switch (phreds.size()) {
            case 0:
                return DBL_MAX;  // empty sum of probabilities is 0, i.e. an "infinite" phred score (GATK returns Double.MAX_VALUE)
            case 1:
                return phreds[0];
            case 2:
                return phredSum(phreds[0], phreds[1]);
            case 3:
                return phredSum(phreds[0], phreds[1], phreds[2]);
            default: {  // braces scope the local vector so no case label crosses its initialization
                vector<double> log10Vals(phreds.size()); // TODO: optimize/accelerate
                for (size_t i = 0; i < log10Vals.size(); i++) {
                    log10Vals[i] = phreds[i] * -0.1;
                }
                return -10 * MathUtils::log10SumLog10(log10Vals);
            }
        }
    }
    /**
     * Calculate the sum of two phred scores.
     * <p>
     * As any sum, this operation is commutative.
     * </p>
     * @param a first phred score value.
     * @param b second phred score value.
     * @return the sum.
     */
    static double phredSum(const double a, const double b) { return -10 * MathUtils::log10SumLog10(a * -.1, b * -.1); }
    /**
     * Calculate the sum of three phred scores.
     * <p>
     * As any sum, this operation is commutative.
     * </p>
     * @param a first phred score value.
     * @param b second phred score value.
     * @param c thrid phred score value.
     * @return the sum.
     */
    static double phredSum(const double a, const double b, const double c) { return -10 * MathUtils::log10SumLog10(a * -.1, b * -.1, c * -.1); }
};

View File

@ -0,0 +1,41 @@
/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/24
*/
#pragma once
#include <vector>
#include <stdint.h>
#include "recal_tables.h"
#include "qual_utils.h"
#include "util/math/math_utils.h"
#include "qual_quantizer.h"
using std::vector;
// Builds the original->quantized quality map from the filled recalibration tables
// (port of GATK's QuantizationInfo / quantizeQualityScores step).
struct QuantizationInfo {
    // quantizedQuals[q] is the quantized quality for original quality q.
    vector<uint8_t> quantizedQuals;
    vector<int64_t> empiricalQualCounts; // histogram: observations per empirical quality score
    int quantizationLevels; // number of quantization levels requested
    QuantizationInfo() {}
    // Builds the histogram over the quality-score table, then greedily quantizes it into `levels` bins.
    // NOTE(review): the histogram is sized MAX_SAM_QUAL_SCORE + 1 (94); indices stay in range because
    // RecalDatum::getEmpiricalQuality() is capped at MAX_RECALIBRATED_Q_SCORE (93) -- confirm this holds.
    QuantizationInfo(RecalTables &recalTables, const int levels) {
        quantizationLevels = levels;
        empiricalQualCounts.resize(QualityUtils::MAX_SAM_QUAL_SCORE + 1, 0);
        auto& qualTable = recalTables.qualityScoreTable;
        // _Foreach3D iterates every datum `val` of the 3-D quality-score table (project macro).
        _Foreach3D(qualTable, val, {
            // convert the empirical quality to an integer ( it is already capped by MAX_QUAL )
            const int empiricalQual = MathUtils::fastRound(val.getEmpiricalQuality());
            empiricalQualCounts[empiricalQual] += val.getNumObservations(); // add the number of observations for every key
        });
        // quantizeQualityScores
        QualQuantizer quantizer(empiricalQualCounts, levels, QualityUtils::MIN_USABLE_Q_SCORE);
        quantizer.getOriginalToQuantizedMap(quantizedQuals);
    }
};

View File

@ -0,0 +1,76 @@
/*
Description: sam
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/24
*/
#pragma once
#include "util/bam_wrap.h"
#include "covariate.h"
struct ReadRecalInfo {
    // All members are non-owning references into the read and its per-read scratch buffers.
    SamData& read;
    int length;
    PerReadCovariateMatrix& covariates;
    vector<bool>& skips;
    FastArray<uint8_t>&base_quals, &ins_quals, &del_quals;
    vector<double>&snp_errs, &ins_errs, &del_errs;
    ReadRecalInfo(SamData& _read, PerReadCovariateMatrix& _covariates, vector<bool>& _skips, vector<double>& _snp_errs, vector<double>& _ins_errs,
                  vector<double>& _del_errs)
        : read(_read),
          length(_read.read_len),
          covariates(_covariates),
          skips(_skips),
          base_quals(_read.base_quals),
          ins_quals(_read.ins_quals),
          del_quals(_read.del_quals),
          snp_errs(_snp_errs),
          ins_errs(_ins_errs),
          del_errs(_del_errs) {}
    /**
     * Quality score for an event type at a read offset.
     *
     * Selects the substitution, insertion, or deletion quality track for the event.
     *
     * @param event the type of event we want the qual for
     * @param offset the offset into this read for the qual
     * @return a valid quality score for event at offset
     */
    uint8_t getQual(const EventTypeValue& event, const int offset) {
        if (event == EventType::BASE_SUBSTITUTION) return base_quals[offset];
        return event == EventType::BASE_INSERTION ? ins_quals[offset] : del_quals[offset];
    }
    /**
     * Fractional error weight for an event type at a read offset.
     *
     * A value in [0, 1]: 1.0 means the error definitely occurs at this site, 0.0 means it
     * definitely does not, and 0.5 means half of the error's weight belongs here.
     *
     * @param event the type of event we want the error fraction for
     * @param offset the offset into this read
     * @return a fractional weight for an error at this offset
     */
    double getErrorFraction(const EventTypeValue& event, const int offset) {
        if (event == EventType::BASE_SUBSTITUTION) return snp_errs[offset];
        return event == EventType::BASE_INSERTION ? ins_errs[offset] : del_errs[offset];
    }
};

View File

@ -0,0 +1,3 @@
#include "recal_datum.h"
// Out-of-class definition of the static prior cache; filled once by RecalDatum::StaticInit().
double RecalDatum::logPriorCache[MAX_GATK_USABLE_Q_SCORE + 1];

View File

@ -0,0 +1,246 @@
/*
Description: bqsr
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/24
*/
#pragma once
#include <stdint.h>
#include <float.h>
#include "qual_utils.h"
#include "util/math/normal_dist.h"
#include "util/math/math_utils.h"
/**
* The container for the 4-tuple
*
* ( reported quality, empirical quality, num observations, num mismatches/errors )
*
* for a given set of covariates.
*
*/
struct RecalDatum {
static constexpr uint8_t MAX_RECALIBRATED_Q_SCORE = 93; // SAMUtils.MAX_PHRED_SCORE;
static constexpr int UNINITIALIZED_EMPIRICAL_QUALITY = -1;
static constexpr double MULTIPLIER = 100000.0; // See discussion in numMismatches about what the multiplier is.
/**
* used when calculating empirical qualities to avoid division by zero
*/
static constexpr int SMOOTHING_CONSTANT = 1;
static constexpr uint64_t MAX_NUMBER_OF_OBSERVATIONS = INT_MAX - 1;
/**
* Quals above this value should be capped down to this value (because they are too high)
* in the base quality score recalibrator
*/
static constexpr uint8_t MAX_GATK_USABLE_Q_SCORE = 40;
static double logPriorCache[MAX_GATK_USABLE_Q_SCORE + 1];
    /** One-time fill of logPriorCache; must run before any empirical-quality computation. */
    static void StaticInit() {
        // normal distribution describing P(empiricalQuality - reportedQuality). Its mean is zero because a priori we expect
        // no systematic bias in the reported quality score
        const double mean = 0.0;
        const double sigma = 0.5; // with these parameters, deltas can shift at most ~20 Q points
        const NormalDistribution gaussian(mean, sigma);
        for (int i = 0; i <= MAX_GATK_USABLE_Q_SCORE; i++) {
            // NOTE(review): caches NormalDistribution::logDensity(i) -- confirm its log base (natural vs
            // log10) matches what bayesianEstimateOfEmpiricalQuality expects when reading logPriorCache.
            logPriorCache[i] = gaussian.logDensity(i);
        }
    }
/**
* Estimated reported quality score based on combined data's individual q-reporteds and number of observations.
* The estimating occurs when collapsing counts across different reported qualities.
*/
// 测序仪给出的原始质量分数
double reportedQuality = 0.0;
/**
* The empirical quality for datums that have been collapsed together (by read group and reported quality, for example).
*
* This variable was historically a double, but {@link #bayesianEstimateOfEmpiricalQuality} has always returned an integer qual score.
* Thus the type has been changed to integer in February 2025 to highlight this implementation detail. It does not change the output.
*/
// 计算出来的真实质量分数
int empiricalQuality = 0;
/**
* Number of bases seen in total
*/
// 这个字段也用来判断当前datum的有效性只有到numObservations > 0时这个datum才有效因为如果为0说明这个datum都没有出现过
uint64_t numObservations = 0;
/**
* Number of bases seen that didn't match the reference
* (actually sum of the error weights - so not necessarily a whole number)
* Stored with an internal multiplier to keep it closer to the floating-point sweet spot and avoid numerical error
* (see https://github.com/broadinstitute/gatk/wiki/Numerical-errors ).
* However, the value of the multiplier influences the results.
* For example, you get different results for 1000.0 and 10000.0
* See MathUtilsUnitTest.testAddDoubles for a demonstration.
* The value of the MULTIPLIER that we found to give consistent results insensitive to sorting is 10000.0;
*/
double numMismatches = 0.0;
    // NOTE(review): the default ctor leaves empiricalQuality at its in-class default of 0, not
    // UNINITIALIZED_EMPIRICAL_QUALITY (-1) -- confirm default-constructed datums are only read after
    // increment()/combine() (both reset the cache), per the numObservations validity comment above.
    RecalDatum() {}
    /**
     * @param _numObservations total number of bases observed for this covariate bin
     * @param _numMismatches   (possibly fractional) error weight, in natural units; scaled by MULTIPLIER internally
     * @param _reportedQuality machine-reported Phred quality for this bin
     */
    RecalDatum(const uint64_t _numObservations, const double _numMismatches, const uint8_t _reportedQuality) {
        numObservations = _numObservations;
        numMismatches = _numMismatches * MULTIPLIER;  // stored scaled; see the numMismatches field comment
        reportedQuality = _reportedQuality;
        empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;  // computed lazily on first use
    }
void increment(const uint64_t incObservations, const double incMismatches) {
numObservations += incObservations;
numMismatches += (incMismatches * MULTIPLIER); // the multiplier used to avoid underflow, or something like that.
empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;
}
void increment(const uint64_t incObservations, const double incMismatches, int baseQuality) {
numObservations += incObservations;
numMismatches += (incMismatches * MULTIPLIER); // the multiplier used to avoid underflow, or something like that.
reportedQuality = baseQuality;
empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;
}
    /**
     * Add in all of the data from other into this object, updating the reported quality from the expected
     * error rate implied by the two reported qualities.
     *
     * For example (the only example?), this method is called when collapsing the counts across reported quality scores within
     * the same read group.
     *
     * @param other RecalDatum to combine
     */
    void combine(const RecalDatum& other) {
        // this is the *expected* (or theoretical) number of errors given the reported qualities and the number of observations.
        // NOTE: must be computed BEFORE increment() below, since calcExpectedErrors() reads this datum's
        // current numObservations and reportedQuality.
        double expectedNumErrors = this->calcExpectedErrors() + other.calcExpectedErrors();
        // increment the counts
        increment(other.getNumObservations(), other.getNumMismatches());
        // we use the theoretical count above to compute the "estimated" reported quality
        // after combining two datums with different reported qualities.
        reportedQuality = -10 * log10(expectedNumErrors / getNumObservations());
        empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;
    }
    /**
     * calculate the expected number of errors given the estimated Q reported and the number of observations
     * in this datum.
     *
     * @return a positive (potentially fractional) estimate of the number of errors
     */
    inline double calcExpectedErrors() const { return numObservations * QualityUtils::qualToErrorProb(reportedQuality); }
    // Mismatch weight in natural units (internal storage is scaled by MULTIPLIER).
    inline double getNumMismatches() const { return numMismatches / MULTIPLIER; }
    // Total number of bases observed.
    inline uint64_t getNumObservations() const { return numObservations; }
    // Machine-reported (or estimated, after combine()) Phred quality.
    inline double getReportedQuality() const { return reportedQuality; }
/**
 * Computes the empirical quality of the datum, using the reported quality as the prior.
 * @see getEmpiricalQuality(double) below.
 */
double getEmpiricalQuality() { return getEmpiricalQuality(getReportedQuality()); }
/**
 * Computes the empirical base quality (roughly (num errors)/(num observations)) from the
 * counts stored in this datum.  The value is computed lazily and cached; increment() and
 * combine() reset the cache.
 */
double getEmpiricalQuality(const double priorQualityScore) {
    if (empiricalQuality == UNINITIALIZED_EMPIRICAL_QUALITY) {
        // cache miss: run the (expensive) Bayesian estimate and store the result
        calcEmpiricalQuality(priorQualityScore);
    }
    return empiricalQuality;
}
/**
 * Calculate and cache the empirical quality score from mismatches and observations
 * (expensive operation).
 *
 * The fractional mismatch count is rounded to the nearest integer — this is what the
 * original opaque "+ 0.5 then truncate" was doing (identical for non-negative values);
 * std::round makes the intent explicit and resolves the old TODO.  Smoothing then adds
 * one error and one non-error pseudo-observation.
 */
void calcEmpiricalQuality(const double priorQualityScore) {
    // round the fractional mismatch count to the nearest whole error, then smooth
    const uint64_t mismatches = (uint64_t)std::round(getNumMismatches()) + SMOOTHING_CONSTANT;
    // smoothing adds one error and one non-error observation, hence two extra observations
    const uint64_t observations = getNumObservations() + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT;
    const int empiricalQual = bayesianEstimateOfEmpiricalQuality(observations, mismatches, priorQualityScore);
    empiricalQuality = std::min(empiricalQual, (int)MAX_RECALIBRATED_Q_SCORE);
}
/**
 * Compute the maximum a posteriori (MAP) estimate of the probability of sequencing error
 * under the following model.
 *
 * Let
 *   X = number of sequencing errors,
 *   n = number of observations,
 *   theta = probability of sequencing error as a quality score,
 *   theta_rep = probability of sequencing error reported by the sequencing machine as a quality score.
 *
 * The prior and the likelihood are:
 *
 *   P(theta|theta_rep) ~ Gaussian(theta - theta_rep | 0, 0.5)   (note this is done in log space)
 *   P(X|n, theta)      ~ Binom(X|n, theta)
 *
 * Note the prior is equivalent to
 *
 *   P(theta|theta_rep) ~ Gaussian(theta | theta_rep, 0.5)
 *
 * TODO: use beta prior to do away with the search.
 *
 * @param nObservations n in the model above.
 * @param nErrors the observed number of sequencing errors.
 * @param priorMeanQualityScore the prior quality score, often the reported quality score.
 *
 * @return phredScale quality score that maximizes the posterior probability.
 */
static int bayesianEstimateOfEmpiricalQuality(const uint64_t nObservations, const uint64_t nErrors, const double priorMeanQualityScore) {
    const int numQualityScoreBins = (QualityUtils::MAX_REASONABLE_Q_SCORE + 1);
    // grid search: evaluate the (unnormalized) log posterior at every integer quality score
    double logPosteriors[numQualityScoreBins];
    for (int i = 0; i < numQualityScoreBins; ++i) {
        logPosteriors[i] = getLogPrior(i, priorMeanQualityScore) + getLogBinomialLikelihood(i, nObservations, nErrors);
    }
    // the MAP estimate is the bin with the highest posterior
    return MathUtils::maxElementIndex(logPosteriors, 0, numQualityScoreBins);
}
/**
 * Log of the Gaussian prior P(theta | theta_rep): looked up from logPriorCache, indexed
 * by the integer distance between the candidate and prior quality scores, clamped to
 * MAX_GATK_USABLE_Q_SCORE.
 */
static double getLogPrior(const double qualityScore, const double priorQualityScore) {
    int difference = std::abs((int)(qualityScore - priorQualityScore));
    if (difference > (int)MAX_GATK_USABLE_Q_SCORE) {
        difference = (int)MAX_GATK_USABLE_Q_SCORE;
    }
    return logPriorCache[difference];
}
/**
 * Given:
 *  - n, the number of observations,
 *  - k, the number of sequencing errors,
 *  - p, the probability of error, encoded as the quality score,
 *
 * return the log binomial probability Bin(k|n,p).
 *
 * The method handles the case when the counts (of type uint64_t) exceed the maximum
 * allowed integer value, since the binomial library code expects (and caches on) ints.
 */
static double getLogBinomialLikelihood(const double qualityScore, uint64_t nObservations, uint64_t nErrors) {
    if (nObservations == 0)
        return 0.0;
    // The binomial code requires ints as input (because it does caching).  There is plenty
    // of precision in 2^31 observations, but we must avoid overflow before casting down,
    // so shrink both counts by the same factor — the ratio k/n is essentially preserved.
    if (nObservations > MAX_NUMBER_OF_OBSERVATIONS) {
        const double shrinkFactor = (double)MAX_NUMBER_OF_OBSERVATIONS / (double)nObservations;
        nErrors = std::round((double)nErrors * shrinkFactor);
        nObservations = MAX_NUMBER_OF_OBSERVATIONS;
    }
    // straight binomial PDF in log space
    const double logLikelihood = MathUtils::logBinomialProbability((int)nObservations, (int)nErrors, QualityUtils::qualToErrorProb(qualityScore));
    // keep -inf/NaN from poisoning the posterior sum
    return std::isfinite(logLikelihood) ? logLikelihood : -DBL_MAX;
}
};

View File

@ -0,0 +1,44 @@
/*
Description: bqsrbqsr table
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/24
*/
#pragma once
#include "covariate.h"
#include "nested_array.h"
#include "recal_datum.h"
/*
 * Aggregate of all recalibration count tables, one cell per covariate-key combination.
 * All tables are pre-sized in init(), so every cell holds a (possibly empty) RecalDatum.
 */
struct RecalTables {
    int qualDimension = 94; // MAX_SAM_QUAL_SCORE(93) + 1
    int eventDimension = EventType::EVENT_SIZE;
    int numReadGroups;
    // These two tables are special (always emitted, as RecalTable0 / RecalTable1)
    Array2D<RecalDatum> readGroupTable;
    Array3D<RecalDatum> qualityScoreTable;
    // additional tables (context and cycle covariates, emitted as RecalTable2)
    Array4D<RecalDatum> contextTable;
    Array4D<RecalDatum> cycleTable;
    // NestedArray<int> testArr;
    RecalTables() {}
    RecalTables(int _numReadGroups) { init(_numReadGroups); }
    // Size every table for the given number of read groups.
    void init(int _numReadGroups) {
        numReadGroups = _numReadGroups;
        // size the read-group and quality-score tables
        readGroupTable.init(numReadGroups, eventDimension);
        qualityScoreTable.init(numReadGroups, qualDimension, eventDimension);
        // size the context and cycle tables
        contextTable.init(numReadGroups, qualDimension, ContextCovariate::MaximumKeyValue() + 1, eventDimension);
        cycleTable.init(numReadGroups, qualDimension, CycleCovariate::MaximumKeyValue() + 1, eventDimension);
    }
};

View File

@ -0,0 +1,56 @@
/*
Description: bqsr
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/24
*/
#include "recal_utils.h"
/**
* Increments the RecalDatum at the specified position in the specified table, or put a new item there
* if there isn't already one.
*
* Note: we intentionally do not use varargs here to avoid the performance cost of allocating an array on every call. It showed on the profiler.
*
* @param table the table that holds/will hold our item
* @param qual qual for this event
* @param isError error value for this event
* @param key0, key1, key2 location in table of our item
*/
void RecalUtils::IncrementDatum3keys(Array3D<RecalDatum> table, uint8_t qual, double isError, int key0, int key1, int key2) {
    // NOTE(review): `table` is taken BY VALUE.  This only updates the caller's data if
    // Array3D is a non-owning view over shared storage; if it owns its buffer, every
    // update below mutates a temporary copy and is silently lost.  Verify Array3D's
    // semantics, or change this (together with the declaration in recal_utils.h) to
    // take Array3D<RecalDatum>&.
    RecalDatum &existingDatum = table.get(key0, key1, key2);
    if (existingDatum.numObservations == 0) {
        // No existing item, put a new one
        table.put(RecalDatum(1, qual, isError), key0, key1, key2);
    } else {
        // Easy case: already an item here, so increment it
        existingDatum.increment(1, isError);
    }
}
/**
* Increments the RecalDatum at the specified position in the specified table, or put a new item there
* if there isn't already one.
*
* Note: we intentionally do not use varargs here to avoid the performance cost of allocating an array on every call. It showed on the profiler.
*
* @param table the table that holds/will hold our item
* @param qual qual for this event
* @param isError error value for this event
* @param key0, key1, key2, key3 location in table of our item
*/
void RecalUtils::IncrementDatum4keys(Array4D<RecalDatum> table, uint8_t qual, double isError, int key0, int key1, int key2, int key3) {
    // NOTE(review): `table` is taken BY VALUE.  This only updates the caller's data if
    // Array4D is a non-owning view over shared storage; if it owns its buffer, every
    // update below mutates a temporary copy and is silently lost.  Verify Array4D's
    // semantics, or change this (together with the declaration in recal_utils.h) to
    // take Array4D<RecalDatum>&.
    RecalDatum& existingDatum = table.get(key0, key1, key2, key3);
    if (existingDatum.numObservations == 0) {
        // No existing item, put a new one
        table.put(RecalDatum(1, qual, isError), key0, key1, key2, key3);
    } else {
        // Easy case: already an item here, so increment it
        existingDatum.increment(1, isError);
    }
}

View File

@ -0,0 +1,207 @@
/*
Description: bqsr
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/24
*/
#pragma once
#include <spdlog/spdlog.h>
#include <cstdio>
#include "bqsr_args.h"
#include "nested_array.h"
#include "quant_info.h"
#include "recal_datum.h"
#include "recal_tables.h"
#include "util/report_table.h"
#include "util/utils.h"
#include "covariate.h"
/**
 * Static helpers for accumulating recalibration counts and writing the GATK3-style
 * recalibration report.
 *
 * Fix: removed the leftover per-datum debug spdlog::info() calls inside
 * outputReadGroupTable's emission loop — debug residue that logged every datum while
 * writing the report.
 */
struct RecalUtils {
    static constexpr int EMPIRICAL_QUAL_DECIMAL_PLACES = 4;
    static constexpr int REPORTED_QUALITY_DECIMAL_PLACES = 4;
    static constexpr int NUMBER_ERRORS_DECIMAL_PLACES = 2;
    // Add each read's event, keyed by its covariates, into the corresponding recal table.
    // NOTE(review): the table parameters are taken by value — correct only if the ArrayND
    // types are non-owning views; verify, or switch to references here and in the .cpp.
    static void IncrementDatum3keys(Array3D<RecalDatum> table, uint8_t qual, double isError, int key0, int key1, int key2);
    static void IncrementDatum4keys(Array4D<RecalDatum> table, uint8_t qual, double isError, int key0, int key1, int key2, int key3);
    // Write the full BQSR recalibration report: version header, arguments, quantization
    // map, then the covariate tables.
    static void outputRecalibrationReport(const BQSRArg& RAC, const QuantizationInfo& quantInfo, const RecalTables& recalTables) {
        spdlog::info("output tables");
        // open the output file
        FILE* fpout = fopen(RAC.OUTPUT_FILE.c_str(), "w");
        if (!fpout) {
            fprintf(stderr, "Failed to open file %s.\n", RAC.OUTPUT_FILE.c_str());
            exit(1);
        }
        // report version header
        fprintf(fpout, "%s\n", REPORT_HEADER_VERSION);
        // arguments used for this run
        outputArgsTable(RAC, fpout);
        // quantized quality score map
        outputQuantTable(quantInfo, fpout);
        // covariate count tables
        outputCovariateTable(recalTables, fpout);
        fclose(fpout);
    }
    // Emit the "Arguments" table describing the parameters used in this run.
    static void outputArgsTable(const BQSRArg& p, FILE* fpout) {
        ReportTable table("Arguments", "Recalibration argument collection values used in this run");
        table.addColumn({"Argument", "%s"});
        table.addColumn({"Value", ""});
        // one row per argument, mirroring the GATK3 report layout
        table.addRowData({"covariate", "ReadGroupCovariate,QualityScoreCovariate,ContextCovariate,CycleCovariate"});
        table.addRowData({"no_standard_covs", ReportUtil::ToString(p.DO_NOT_USE_STANDARD_COVARIATES)});
        table.addRowData({"run_without_dbsnp", ReportUtil::ToString(p.RUN_WITHOUT_DBSNP)});
        table.addRowData({"solid_recal_mode", ReportUtil::ToString(p.SOLID_RECAL_MODE)});
        table.addRowData({"solid_nocall_strategy", ReportUtil::ToString(p.SOLID_NOCALL_STRATEGY)});
        table.addRowData({"mismatches_context_size", ReportUtil::ToString(p.MISMATCHES_CONTEXT_SIZE)});
        table.addRowData({"indels_context_size", ReportUtil::ToString(p.INDELS_CONTEXT_SIZE)});
        table.addRowData({"mismatches_default_quality", ReportUtil::ToString(p.MISMATCHES_DEFAULT_QUALITY)});
        table.addRowData({"deletions_default_quality", ReportUtil::ToString(p.DELETIONS_DEFAULT_QUALITY)});
        table.addRowData({"insertions_default_quality", ReportUtil::ToString(p.INSERTIONS_DEFAULT_QUALITY)});
        table.addRowData({"maximum_cycle_value", ReportUtil::ToString(p.MAXIMUM_CYCLE_VALUE)});
        table.addRowData({"low_quality_tail", ReportUtil::ToString(p.LOW_QUAL_TAIL)});
        table.addRowData({"default_platform", ReportUtil::ToString(p.DEFAULT_PLATFORM)});
        table.addRowData({"force_platform", ReportUtil::ToString(p.FORCE_PLATFORM)});
        table.addRowData({"quantizing_levels", ReportUtil::ToString(p.QUANTIZING_LEVELS)});
        table.addRowData({"recalibration_report", ReportUtil::ToString(p.existingRecalibrationReport)});
        table.addRowData({"binary_tag_name", ReportUtil::ToString(p.BINARY_TAG_NAME)});
        // write to file
        table.write(fpout);
    }
    // Emit the "Quantized" quality-score quantization table.
    static void outputQuantTable(const QuantizationInfo& q, FILE* fpout) {
        ReportTable table("Quantized", "Quality quantization map");
        table.addColumn({"QualityScore", "%d"});
        table.addColumn({"Count", "%d"});
        table.addColumn({"QuantizedScore", "%d"});
        for (int qual = 0; qual <= QualityUtils::MAX_SAM_QUAL_SCORE; ++qual) {
            table.addRowData({ReportUtil::ToString(qual), ReportUtil::ToString(q.empiricalQualCounts[qual]), ReportUtil::ToString(q.quantizedQuals[qual])});
        }
        table.write(fpout);
    }
    // Emit all covariate tables (RecalTable0/1/2).
    static void outputCovariateTable(const RecalTables& r, FILE* fpout) {
        // 1. read group covariates
        outputReadGroupTable(r, fpout);
        // 2. quality score covariates
        outputQualityScoreTable(r, fpout);
        // 3. context and cycle covariates
        outputContextCycleTable(r, fpout);
    }
    // RecalTable0: per-read-group aggregate counts.
    static void outputReadGroupTable(const RecalTables& r, FILE* fpout) {
        ReportTable table("RecalTable0");
        table.addColumn({"ReadGroup", "%s"});
        table.addColumn({"EventType", "%s"});
        table.addColumn({"EmpiricalQuality", "%.4f"});
        table.addColumn({"EstimatedQReported", "%.4f"});
        table.addColumn({"Observations", "%d"});
        table.addColumn({"Errors", "%.2f"});
        _Foreach2DK(r.readGroupTable, datum, {
            // const_cast: getEmpiricalQuality() caches its result lazily and is non-const
            RecalDatum &dat = const_cast<RecalDatum&>(datum);
            if (dat.getNumObservations() > 0) {
                table.addRowData({
                    ReadGroupCovariate::IdToRg[k1],
                    ReportUtil::ToString(EventType::EVENTS[k2].representation),
                    ReportUtil::ToString(dat.getEmpiricalQuality(), 4),
                    ReportUtil::ToString(dat.getReportedQuality(), 4),
                    ReportUtil::ToString(dat.getNumObservations()),
                    ReportUtil::ToString(dat.getNumMismatches(), 2)
                });
            }
        });
        table.write(fpout);
    }
    // RecalTable1: counts keyed by read group and reported quality score.
    static void outputQualityScoreTable(const RecalTables& r, FILE* fpout) {
        ReportTable table("RecalTable1");
        table.addColumn({"ReadGroup", "%s"});
        table.addColumn({"QualityScore", "%d"});
        table.addColumn({"EventType", "%s"});
        table.addColumn({"EmpiricalQuality", "%.4f"});
        table.addColumn({"Observations", "%d"});
        table.addColumn({"Errors", "%.2f"});
        _Foreach3DK(r.qualityScoreTable, datum, {
            RecalDatum &dat = const_cast<RecalDatum&>(datum);
            if (dat.getNumObservations() > 0) {
                table.addRowData({
                    ReadGroupCovariate::IdToRg[k1],
                    ReportUtil::ToString(k2),
                    ReportUtil::ToString(EventType::EVENTS[k3].representation),
                    ReportUtil::ToString(dat.getEmpiricalQuality(), 4),
                    ReportUtil::ToString(dat.getNumObservations()),
                    ReportUtil::ToString(dat.getNumMismatches(), 2)
                });
            }
        });
        table.write(fpout);
    }
    // RecalTable2: counts keyed by read group, quality score, and context/cycle covariate.
    static void outputContextCycleTable(const RecalTables& r, FILE* fpout) {
        ReportTable table("RecalTable2");
        table.addColumn({"ReadGroup", "%s"});
        table.addColumn({"QualityScore", "%d"});
        table.addColumn({"CovariateValue", "%s"});
        table.addColumn({"CovariateName", "%s"});
        table.addColumn({"EventType", "%s"});
        table.addColumn({"EmpiricalQuality", "%.4f"});
        table.addColumn({"Observations", "%d"});
        table.addColumn({"Errors", "%.2f"});
        _Foreach4DK(r.contextTable, datum, {
            RecalDatum &dat = const_cast<RecalDatum&>(datum);
            if (dat.getNumObservations() > 0) {
                table.addRowData({
                    ReadGroupCovariate::IdToRg[k1],
                    ReportUtil::ToString(k2),
                    ReportUtil::ToString(ContextCovariate::ContextFromKey(k3)),
                    "Context",
                    ReportUtil::ToString(EventType::EVENTS[k4].representation),
                    ReportUtil::ToString(dat.getEmpiricalQuality(), 4),
                    ReportUtil::ToString(dat.getNumObservations()),
                    ReportUtil::ToString(dat.getNumMismatches(), 2)
                });
            }
        });
        _Foreach4DK(r.cycleTable, datum, {
            RecalDatum &dat = const_cast<RecalDatum&>(datum);
            if (dat.getNumObservations() > 0) {
                table.addRowData({
                    ReadGroupCovariate::IdToRg[k1],
                    ReportUtil::ToString(k2),
                    ReportUtil::ToString(CycleCovariate::CycleFromKey(k3)),
                    "Cycle",
                    ReportUtil::ToString(EventType::EVENTS[k4].representation),
                    ReportUtil::ToString(dat.getEmpiricalQuality(), 4),
                    ReportUtil::ToString(dat.getNumObservations()),
                    ReportUtil::ToString(dat.getNumMismatches(), 2)
                });
            }
        });
        table.write(fpout);
    }
};

View File

@ -42,6 +42,12 @@ struct FastArray {
size_t idx; size_t idx;
void clear() { idx = 0; } void clear() { idx = 0; }
size_t size() { return idx; } size_t size() { return idx; }
bool empty() { return idx == 0; }
void reserve(size_t _size) { arr.reserve(_size); }
void resize(size_t _size) {
arr.resize(_size);
idx = _size;
}
void push_back(const T& val) { void push_back(const T& val) {
if (idx < arr.size()) { if (idx < arr.size()) {
arr[idx++] = val; arr[idx++] = val;
@ -50,6 +56,7 @@ struct FastArray {
idx++; idx++;
} }
} }
inline T& operator[](size_t pos) { return arr[pos]; }
struct iterator { struct iterator {
typename std::vector<T>::iterator it; typename std::vector<T>::iterator it;
iterator(typename std::vector<T>::iterator _it) : it(_it) {} iterator(typename std::vector<T>::iterator _it) : it(_it) {}
@ -65,6 +72,7 @@ struct FastArray {
}; };
// 对原始bam数据的补充比如对两端进行hardclip等 // 对原始bam数据的补充比如对两端进行hardclip等
class BamWrap;
struct SamData { struct SamData {
int read_len = 0; // read长度各种clip之后的长度 int read_len = 0; // read长度各种clip之后的长度
int cigar_start = 0; // cigar起始位置闭区间 int cigar_start = 0; // cigar起始位置闭区间
@ -77,10 +85,15 @@ struct SamData {
// 记录一下bqsr运算过程中用到的数据回头提前计算一下修正现在的复杂逻辑 // 记录一下bqsr运算过程中用到的数据回头提前计算一下修正现在的复杂逻辑
static constexpr int READ_INDEX_NOT_FOUND = -1; static constexpr int READ_INDEX_NOT_FOUND = -1;
string bases; // 处理之后的read的碱基
vector<uint8_t> quals; // 对应的质量分数 BamWrap* bw;
int64_t start_pos; // 因为soft clip都被切掉了这里的softstart应该就是切掉之后的匹配位点闭区间 int64_t start_pos; // 因为soft clip都被切掉了这里的softstart应该就是切掉之后的匹配位点闭区间
int64_t end_pos; // 同上,闭区间 int64_t end_pos; // 同上,闭区间
string bases; // 处理之后的read的碱基
FastArray<uint8_t> base_quals; // 对应的质量分数
FastArray<uint8_t> ins_quals; // insert质量分数, BI (大部分应该都没有)
FastArray<uint8_t> del_quals; // delete质量分数, BD (大部分应该都没有)
FastArray<Cigar> cigars; FastArray<Cigar> cigars;
int64_t& softStart() { return start_pos; } int64_t& softStart() { return start_pos; }
int64_t& softEnd() { return end_pos; } int64_t& softEnd() { return end_pos; }

View File

@ -0,0 +1,39 @@
/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/26
*/
#pragma once
#include <math.h>
#include <float.h>
#include <limits>
#include "saddle_pe.h"
/**
 * Binomial distribution helper: holds the number of trials and the per-trial success
 * probability, and evaluates the log PMF via a saddle-point expansion.
 */
struct BinomialDistribution {
    /** The number of trials. */
    int numberOfTrials;
    /** The probability of success. */
    double probabilityOfSuccess;
    BinomialDistribution(int trials, double p) : numberOfTrials(trials), probabilityOfSuccess(p) {}
    /** @return log P(X = x) for X ~ Binom(numberOfTrials, probabilityOfSuccess). */
    double logProbability(int x) {
        // degenerate distribution: with zero trials, all mass sits on x == 0
        if (numberOfTrials == 0) {
            return (x == 0) ? 0. : -std::numeric_limits<double>::infinity();
        }
        // outside the support of the distribution
        if (x < 0 || x > numberOfTrials) {
            return -std::numeric_limits<double>::infinity();
        }
        return SaddlePointExpansion::logBinomialProbability(x, numberOfTrials, probabilityOfSuccess, 1.0 - probabilityOfSuccess);
    }
};

View File

@ -0,0 +1,162 @@
/*
Description: ContinuedFraction
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/25
*/
#pragma once
#include <math.h>
#include <functional>
#include <iostream>
#include "math_const.h"
#include "precision.h"
//using ContinuedFractionGetFunc = std::function<double(int, double)>();
// using ContinuedFractionGetFunc = double (*)(int n, double x);
/**
 * Abstract continued-fraction evaluator using the modified Lentz algorithm; subclasses
 * supply the a_n and b_n coefficient functions.
 *
 * Fix: added a virtual destructor — this is a polymorphic base class (pure virtual
 * getA/getB), and destroying a subclass through a base pointer without one is undefined
 * behavior.
 */
struct ContinuedFraction {
    /** Maximum allowed numerical error. */
    static constexpr double DEFAULT_EPSILON = 10e-9;
    /** Polymorphic base: virtual destructor is required for safe delete-through-base. */
    virtual ~ContinuedFraction() = default;
    /**
     * Access the n-th a coefficient of the continued fraction. Since a can be
     * a function of the evaluation point, x, that is passed in as well.
     * @param n the coefficient index to retrieve.
     * @param x the evaluation point.
     * @return the n-th a coefficient.
     */
    virtual double getA(int n, double x) = 0;
    /**
     * Access the n-th b coefficient of the continued fraction. Since b can be
     * a function of the evaluation point, x, that is passed in as well.
     * @param n the coefficient index to retrieve.
     * @param x the evaluation point.
     * @return the n-th b coefficient.
     */
    virtual double getB(int n, double x) = 0;
    /**
     * Evaluates the continued fraction at the value x with default accuracy and an
     * unbounded number of convergents.
     */
    double evaluate(double x) { return evaluate(x, DEFAULT_EPSILON, MathConst::INT_MAX_VALUE); }
    /**
     * Evaluates the continued fraction at the value x.
     * @param x the evaluation point.
     * @param epsilon maximum error allowed.
     */
    double evaluate(double x, double epsilon) { return evaluate(x, epsilon, MathConst::INT_MAX_VALUE); }
    /**
     * Evaluates the continued fraction at the value x.
     * @param x the evaluation point.
     * @param maxIterations maximum number of convergents.
     */
    double evaluate(double x, int maxIterations) {
        return evaluate(x, DEFAULT_EPSILON, maxIterations);
    }
    /**
     * Evaluates the continued fraction at the value x.
     *
     * The implementation is based on the modified Lentz algorithm as described on
     * page 18 ff. in:
     *   I. J. Thompson, A. R. Barnett. "Coulomb and Bessel Functions of Complex
     *   Arguments and Order." http://www.fresco.org.uk/papers/Thompson-JCP64p490.pdf
     * Note: the implementation uses the terms a_i and b_i as defined in
     * http://mathworld.wolfram.com/ContinuedFraction.html
     *
     * @param x the evaluation point.
     * @param epsilon maximum error allowed.
     * @param maxIterations maximum number of convergents.
     * @return the value of the continued fraction evaluated at x.
     */
    double evaluate(double x, double epsilon, int maxIterations) {
        // tiny value substituted for zero denominators, per the modified Lentz algorithm
        const double small = 1e-50;
        double hPrev = getA(0, x);
        // use the value of small as epsilon criteria for zero checks
        if (Precision::equals(hPrev, 0.0, small)) {
            hPrev = small;
        }
        int n = 1;
        double dPrev = 0.0;
        double cPrev = hPrev;
        double hN = hPrev;
        while (n < maxIterations) {
            const double a = getA(n, x);
            const double b = getB(n, x);
            double dN = a + b * dPrev;
            if (Precision::equals(dN, 0.0, small)) {
                dN = small;
            }
            double cN = a + b / cPrev;
            if (Precision::equals(cN, 0.0, small)) {
                cN = small;
            }
            dN = 1 / dN;
            const double deltaN = cN * dN;
            hN = hPrev * deltaN;
            if (std::isinf(hN)) {
                std::cerr << "ConvergenceException INFINITY" << std::endl;
                exit(1);
                // throw new ConvergenceException(LocalizedFormats.CONTINUED_FRACTION_INFINITY_DIVERGENCE, x);
            }
            if (std::isnan(hN)) {
                std::cerr << "ConvergenceException NAN" << std::endl;
                exit(1);
                // throw new ConvergenceException(LocalizedFormats.CONTINUED_FRACTION_NAN_DIVERGENCE, x);
            }
            // converged once the correction factor is within epsilon of 1
            if (std::abs(deltaN - 1.0) < epsilon) {
                break;
            }
            dPrev = dN;
            cPrev = cN;
            hPrev = hN;
            n++;
        }
        if (n >= maxIterations) {
            std::cerr << "MaxCountExceededException max iterations" << std::endl;
            exit(1);
            // throw new MaxCountExceededException(LocalizedFormats.NON_CONVERGENT_CONTINUED_FRACTION, maxIterations, x);
        }
        return hN;
    }
};

View File

@ -0,0 +1,11 @@
#include "gamma.h"
// Lanczos approximation coefficients (g = 607/128; see Gamma::LANCZOS_G).
const double Gamma::LANCZOS[] = {
    0.99999999999999709182, 57.156235665862923517, -59.597960355475491248, 14.136097974741747174, -0.49191381609762019978,
    .33994649984811888699e-4, .46523628927048575665e-4, -.98374475304879564677e-4, .15808870322491248884e-3, -.21026444172410488319e-3,
    .21743961811521264320e-3, -.16431810653676389022e-3, .84418223983852743293e-4, -.26190838401581408670e-4, .36899182659531622704e-5,
};
// number of Lanczos coefficients above
const int Gamma::LANCZOS_LEN = sizeof(LANCZOS) / sizeof(LANCZOS[0]);
// precomputed 0.5 * log(2*pi), used in the logGamma tail formula
const double Gamma::HALF_LOG_2_PI = 0.5 * std::log(TWO_PI);

View File

@ -0,0 +1,644 @@
/*
Description: Gamma
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/25
*/
#pragma once
#include <assert.h>
#include <math.h>
#include <limits>
#include <stdint.h>
#include <limits.h>
#include <iostream>
#include "math_const.h"
#include "continued_frac.h"
struct Gamma {
/**
* <a href="http://en.wikipedia.org/wiki/Euler-Mascheroni_constant">Euler-Mascheroni constant</a>
* @since 2.0
*/
static constexpr double GAMMA = 0.577215664901532860606512090082;
static constexpr double TWO_PI = 2 * (105414357.0 / 33554432.0 + 1.984187159361080883e-9);
/**
* The value of the {@code g} constant in the Lanczos approximation, see
* {@link #lanczos(double)}.
* @since 3.1
*/
static constexpr double LANCZOS_G = 607.0 / 128.0;
/** Maximum allowed numerical error. */
static constexpr double DEFAULT_EPSILON = 10e-15;
/** Lanczos coefficients */
static const double LANCZOS[];
static const int LANCZOS_LEN;
/** Avoid repeated computation of log of 2 PI in logGamma */
static const double HALF_LOG_2_PI;
/** The constant value of &radic;(2&pi;). */
static constexpr double SQRT_TWO_PI = 2.506628274631000502;
// limits for switching algorithm in digamma
/** C limit. */
static constexpr double C_LIMIT = 49;
/** S limit. */
static constexpr double S_LIMIT = 1e-5;
/*
* Constants for the computation of double invGamma1pm1(double).
* Copied from DGAM1 in the NSWC library.
*/
/** The constant {@code A0} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_A0 = .611609510448141581788E-08;
/** The constant {@code A1} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_A1 = .624730830116465516210E-08;
/** The constant {@code B1} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_B1 = .203610414066806987300E+00;
/** The constant {@code B2} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_B2 = .266205348428949217746E-01;
/** The constant {@code B3} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_B3 = .493944979382446875238E-03;
/** The constant {@code B4} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_B4 = -.851419432440314906588E-05;
/** The constant {@code B5} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_B5 = -.643045481779353022248E-05;
/** The constant {@code B6} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_B6 = .992641840672773722196E-06;
/** The constant {@code B7} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_B7 = -.607761895722825260739E-07;
/** The constant {@code B8} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_B8 = .195755836614639731882E-09;
/** The constant {@code P0} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_P0 = .6116095104481415817861E-08;
/** The constant {@code P1} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_P1 = .6871674113067198736152E-08;
/** The constant {@code P2} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_P2 = .6820161668496170657918E-09;
/** The constant {@code P3} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_P3 = .4686843322948848031080E-10;
/** The constant {@code P4} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_P4 = .1572833027710446286995E-11;
/** The constant {@code P5} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_P5 = -.1249441572276366213222E-12;
/** The constant {@code P6} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_P6 = .4343529937408594255178E-14;
/** The constant {@code Q1} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_Q1 = .3056961078365221025009E+00;
/** The constant {@code Q2} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_Q2 = .5464213086042296536016E-01;
/** The constant {@code Q3} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_Q3 = .4956830093825887312020E-02;
/** The constant {@code Q4} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_Q4 = .2692369466186361192876E-03;
/** The constant {@code C} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C = -.422784335098467139393487909917598E+00;
/** The constant {@code C0} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C0 = .577215664901532860606512090082402E+00;
/** The constant {@code C1} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C1 = -.655878071520253881077019515145390E+00;
/** The constant {@code C2} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C2 = -.420026350340952355290039348754298E-01;
/** The constant {@code C3} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C3 = .166538611382291489501700795102105E+00;
/** The constant {@code C4} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C4 = -.421977345555443367482083012891874E-01;
/** The constant {@code C5} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C5 = -.962197152787697356211492167234820E-02;
/** The constant {@code C6} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C6 = .721894324666309954239501034044657E-02;
/** The constant {@code C7} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C7 = -.116516759185906511211397108401839E-02;
/** The constant {@code C8} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C8 = -.215241674114950972815729963053648E-03;
/** The constant {@code C9} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C9 = .128050282388116186153198626328164E-03;
/** The constant {@code C10} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C10 = -.201348547807882386556893914210218E-04;
/** The constant {@code C11} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C11 = -.125049348214267065734535947383309E-05;
/** The constant {@code C12} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C12 = .113302723198169588237412962033074E-05;
/** The constant {@code C13} defined in {@code DGAM1}. */
static constexpr double INV_GAMMA1P_M1_C13 = -.205633841697760710345015413002057E-06;
/**
* Default constructor. Prohibit instantiation.
*/
private:
Gamma() {}
public:
/**
 * Returns the value of log Gamma(x) for x > 0.
 *
 * For x <= 8, the implementation is based on the double precision implementation in the
 * NSWC Library of Mathematics Subroutines, DGAMLN. For x > 8, the implementation is
 * based on:
 *  - http://mathworld.wolfram.com/GammaFunction.html, equation (28)
 *  - http://mathworld.wolfram.com/LanczosApproximation.html, equations (1) through (5)
 *  - Paul Godfrey, "A note on the computation of the convergent Lanczos complex Gamma
 *    approximation" (http://my.fit.edu/~gabdo/gamma.txt)
 *
 * @param x Argument.
 * @return the value of log(Gamma(x)), or NaN if x <= 0.0.
 */
static double logGamma(double x) {
    // invalid domain: Gamma is only evaluated for positive finite x
    if (std::isnan(x) || (x <= 0.0)) {
        return MathConst::DOUBLE_NAN;
    }
    if (x < 0.5) {
        // use the recurrence Gamma(x) = Gamma(x + 1) / x
        return logGamma1p(x) - std::log(x);
    }
    if (x <= 2.5) {
        return logGamma1p((x - 0.5) - 0.5);
    }
    if (x <= 8.0) {
        // reduce into [1.5, 2.5] via the recurrence, accumulating the shifted product
        const int n = (int)std::floor(x - 1.5);
        double prod = 1.0;
        for (int i = 1; i <= n; i++) {
            prod *= x - i;
        }
        return logGamma1p(x - (n + 1)) + std::log(prod);
    }
    // large-x branch: Lanczos approximation
    const double sum = lanczos(x);
    const double tmp = x + LANCZOS_G + .5;
    return ((x + .5) * std::log(tmp)) - tmp + HALF_LOG_2_PI + std::log(sum / x);
}
/**
 * Returns the regularized gamma function P(a, x) using the default accuracy and an
 * unbounded iteration count.
 *
 * @param a Parameter.
 * @param x Value.
 * @return the regularized gamma function P(a, x).
 * @throws MaxCountExceededException if the algorithm fails to converge.
 */
static double regularizedGammaP(double a, double x) { return regularizedGammaP(a, x, DEFAULT_EPSILON, INT_MAX); }
/**
 * Returns the regularized gamma function P(a, x).
 *
 * The implementation of this method is based on:
 *  - http://mathworld.wolfram.com/RegularizedGammaFunction.html, equation (1)
 *  - http://mathworld.wolfram.com/IncompleteGammaFunction.html, equation (4)
 *  - http://mathworld.wolfram.com/ConfluentHypergeometricFunctionoftheFirstKind.html,
 *    equation (1)
 *
 * @param a the a parameter.
 * @param x the value.
 * @param epsilon When the absolute value of the nth item in the series is less than
 *        epsilon the approximation ceases to calculate further elements in the series.
 * @param maxIterations Maximum number of "iterations" to complete.
 * @return the regularized gamma function P(a, x)
 * @throws MaxCountExceededException if the algorithm fails to converge.
 */
static double regularizedGammaP(double a, double x, double epsilon, int maxIterations) {
    double ret;
    if (std::isnan(a) || std::isnan(x) || (a <= 0.0) || (x < 0.0)) {
        // out of domain
        ret = std::numeric_limits<double>::quiet_NaN();
    } else if (x == 0.0) {
        ret = 0.0;
    } else if (x >= a + 1) {
        // use regularizedGammaQ because it should converge faster in this
        // case.
        ret = 1.0 - regularizedGammaQ(a, x, epsilon, maxIterations);
    } else {
        // calculate series
        double n = 0.0; // current element index
        double an = 1.0 / a; // n-th element in the series
        double sum = an; // partial sum
        while (std::abs(an / sum) > epsilon && n < maxIterations && sum < std::numeric_limits<double>::infinity()) {
            // compute next element in the series
            n += 1.0;
            an *= x / (a + n);
            // update partial sum
            sum += an;
        }
        if (n >= maxIterations) {
            std::cerr << "Reach max iterations!" << std::endl;
            exit(1);
        } else if (std::isinf(sum)) {
            // the series diverged numerically; P saturates at 1
            ret = 1.0;
        } else {
            // P(a, x) = x^a * e^-x / Gamma(a) * series
            ret = std::exp(-x + (a * std::log(x)) - logGamma(a)) * sum;
        }
    }
    return ret;
}
/**
 * Returns the regularized gamma function Q(a, x) = 1 - P(a, x), using the
 * default convergence tolerance and an effectively unbounded iteration
 * budget.
 *
 * @param a the a parameter.
 * @param x the value.
 * @return the regularized gamma function Q(a, x)
 */
static double regularizedGammaQ(double a, double x) {
    return regularizedGammaQ(a, x, DEFAULT_EPSILON, INT_MAX);
}
// Continued-fraction coefficients for the upper regularized gamma function
// Q(a, x) (formula 06.08.10.0003 at functions.wolfram.com). Evaluated through
// ContinuedFraction::evaluate by regularizedGammaQ.
// NOTE(review): getA/getB presumably override virtual hooks declared in
// ContinuedFraction -- confirm and mark them `override` if so.
struct ContinuedFractionInGamma : ContinuedFraction {
    double a;  // the "a" parameter of Q(a, x)
    ContinuedFractionInGamma(double _a) : a(_a) {}
    double getA(int n, double x) { return ((2.0 * n) + 1.0) - a + x; };
    double getB(int n, double x) { return n * (a - n); };
};
/**
* Returns the regularized gamma function Q(a, x) = 1 - P(a, x).
*
* The implementation of this method is based on:
* <ul>
* <li>
* <a href="http://mathworld.wolfram.com/RegularizedGammaFunction.html">
* Regularized Gamma Function</a>, equation (1).
* </li>
* <li>
* <a href="http://functions.wolfram.com/GammaBetaErf/GammaRegularized/10/0003/">
* Regularized incomplete gamma function: Continued fraction representations
* (formula 06.08.10.0003)</a>
* </li>
* </ul>
*
* @param a the a parameter.
* @param x the value.
* @param epsilon When the absolute value of the nth item in the
* series is less than epsilon the approximation ceases to calculate
* further elements in the series.
* @param maxIterations Maximum number of "iterations" to complete.
* @return the regularized gamma function P(a, x)
* @throws MaxCountExceededException if the algorithm fails to converge.
*/
static double regularizedGammaQ(const double a, double x, double epsilon, int maxIterations) {
double ret;
if (std::isnan(a) || std::isnan(x) || (a <= 0.0) || (x < 0.0)) {
ret = MathConst::DOUBLE_NAN;
} else if (x == 0.0) {
ret = 1.0;
} else if (x < a + 1.0) {
// use regularizedGammaP because it should converge faster in this
// case.
ret = 1.0 - regularizedGammaP(a, x, epsilon, maxIterations);
} else {
// create continued fraction
ContinuedFractionInGamma cf(a);
ret = 1.0 / cf.evaluate(x, epsilon, maxIterations);
ret = std::exp(-x + (a * std::log(x)) - logGamma(a)) * ret;
}
return ret;
}
/**
 * <p>Computes the digamma (psi) function of x.</p>
 *
 * <p>Independent implementation of Jose Bernardo, Algorithm AS 103:
 * Psi (Digamma) Function, Applied Statistics, 1976. Constants tuned for
 * roughly 10^-8 absolute accuracy for x >= 10^-5 and 10^-8 relative
 * accuracy for x > 0; cost for large negative x is proportional to |x|.</p>
 *
 * @param x Argument.
 * @return digamma(x) to within 10^-8 relative or absolute error,
 *         whichever is smaller.
 * @see <a href="http://en.wikipedia.org/wiki/Digamma_function">Digamma</a>
 */
static double digamma(double x) {
    if (x > 0 && x <= S_LIMIT) {
        // Tiny positive arguments: psi(x) ~ -gamma - 1/x (AS103 method 5).
        return -GAMMA - 1 / x;
    }
    if (x >= C_LIMIT) {
        // Asymptotic expansion, accurate to O(1/x^8) (AS103 method 4):
        //            1      1        1        1
        // log(x) -  --- - ------ + ------- - -------
        //           2x    12 x^2   120 x^4   252 x^6
        double invSq = 1 / (x * x);
        return std::log(x) - 0.5 / x - invSq * ((1.0 / 12) + invSq * (1.0 / 120 - invSq / 252));
    }
    // Otherwise shift up with the recurrence psi(x) = psi(x + 1) - 1/x.
    return digamma(x + 1) - 1 / x;
}
/**
 * Computes the trigamma function of x, derived by differentiating the
 * digamma implementation above.
 *
 * @param x Argument.
 * @return trigamma(x) to within 10^-8 relative or absolute error,
 *         whichever is smaller.
 * @see <a href="http://en.wikipedia.org/wiki/Trigamma_function">Trigamma</a>
 */
static double trigamma(double x) {
    if (x > 0 && x <= S_LIMIT) {
        // Tiny positive arguments: psi'(x) ~ 1/x^2.
        return 1 / (x * x);
    }
    if (x >= C_LIMIT) {
        // Asymptotic expansion:
        // 1     1      1       1       1
        // - + ----- + ---- - ----- + -----
        // x   2 x^2   6x^3   30x^5   42x^7
        double invSq = 1 / (x * x);
        return 1 / x + invSq / 2 + invSq / x * (1.0 / 6 - invSq * (1.0 / 30 + invSq / 42));
    }
    // Otherwise shift up with the recurrence psi'(x) = psi'(x + 1) + 1/x^2.
    return trigamma(x + 1) + 1 / (x * x);
}
/**
 * <p>
 * Returns the Lanczos approximation used to compute the gamma function,
 * related to Gamma by
 * {@code gamma(x) = sqrt(2 * pi) / x * (x + g + 0.5) ^ (x + 0.5)
 * * exp(-x - g - 0.5) * lanczos(x)},
 * where {@code g} is the Lanczos constant.
 * </p>
 *
 * @param x Argument.
 * @return The Lanczos approximation.
 * @see <a href="http://mathworld.wolfram.com/LanczosApproximation.html">Lanczos Approximation</a>
 */
static double lanczos(const double x) {
    // Accumulate from the smallest coefficients first (same order as the
    // reference implementation) for numerical stability.
    double acc = 0.0;
    for (int idx = LANCZOS_LEN - 1; idx > 0; --idx) {
        acc += LANCZOS[idx] / (x + idx);
    }
    return acc + LANCZOS[0];
}
/**
 * Returns the value of 1 / &Gamma;(1 + x) - 1 for -0&#46;5 &le; x &le;
 * 1&#46;5. This implementation is based on the double precision
 * implementation in the <em>NSWC Library of Mathematics Subroutines</em>,
 * {@code DGAM1}.
 *
 * @param x Argument; must satisfy -0.5 <= x <= 1.5 (asserted).
 * @return The value of {@code 1.0 / Gamma(1.0 + x) - 1.0}.
 * @throws NumberIsTooSmallException if {@code x < -0.5}
 * @throws NumberIsTooLargeException if {@code x > 1.5}
 * @since 3.1
 */
static double invGamma1pm1(const double x) {
    // The Java original threw range exceptions; this port asserts instead.
    // if (x < -0.5) {
    //     throw new NumberIsTooSmallException(x, -0.5, true);
    // }
    // if (x > 1.5) {
    //     throw new NumberIsTooLargeException(x, 1.5, true);
    // }
    assert(x >= -0.5 && x <= 1.5);
    double ret;
    // Shift the argument into [-0.5, 0.5] around the expansion point.
    double t = x <= 0.5 ? x : (x - 0.5) - 0.5;
    if (t < 0.0) {
        // Rational correction a/b, then Horner evaluation of the C-series.
        // The evaluation order matters for the documented accuracy; do not
        // reorder these statements.
        double a = INV_GAMMA1P_M1_A0 + t * INV_GAMMA1P_M1_A1;
        double b = INV_GAMMA1P_M1_B8;
        b = INV_GAMMA1P_M1_B7 + t * b;
        b = INV_GAMMA1P_M1_B6 + t * b;
        b = INV_GAMMA1P_M1_B5 + t * b;
        b = INV_GAMMA1P_M1_B4 + t * b;
        b = INV_GAMMA1P_M1_B3 + t * b;
        b = INV_GAMMA1P_M1_B2 + t * b;
        b = INV_GAMMA1P_M1_B1 + t * b;
        b = 1.0 + t * b;
        double c = INV_GAMMA1P_M1_C13 + t * (a / b);
        c = INV_GAMMA1P_M1_C12 + t * c;
        c = INV_GAMMA1P_M1_C11 + t * c;
        c = INV_GAMMA1P_M1_C10 + t * c;
        c = INV_GAMMA1P_M1_C9 + t * c;
        c = INV_GAMMA1P_M1_C8 + t * c;
        c = INV_GAMMA1P_M1_C7 + t * c;
        c = INV_GAMMA1P_M1_C6 + t * c;
        c = INV_GAMMA1P_M1_C5 + t * c;
        c = INV_GAMMA1P_M1_C4 + t * c;
        c = INV_GAMMA1P_M1_C3 + t * c;
        c = INV_GAMMA1P_M1_C2 + t * c;
        c = INV_GAMMA1P_M1_C1 + t * c;
        c = INV_GAMMA1P_M1_C + t * c;
        if (x > 0.5) {
            // Undo the half-unit shift applied when computing t.
            ret = t * c / x;
        } else {
            ret = x * ((c + 0.5) + 0.5);
        }
    } else {
        // Rational correction p/q, then Horner evaluation of the C-series.
        double p = INV_GAMMA1P_M1_P6;
        p = INV_GAMMA1P_M1_P5 + t * p;
        p = INV_GAMMA1P_M1_P4 + t * p;
        p = INV_GAMMA1P_M1_P3 + t * p;
        p = INV_GAMMA1P_M1_P2 + t * p;
        p = INV_GAMMA1P_M1_P1 + t * p;
        p = INV_GAMMA1P_M1_P0 + t * p;
        double q = INV_GAMMA1P_M1_Q4;
        q = INV_GAMMA1P_M1_Q3 + t * q;
        q = INV_GAMMA1P_M1_Q2 + t * q;
        q = INV_GAMMA1P_M1_Q1 + t * q;
        q = 1.0 + t * q;
        double c = INV_GAMMA1P_M1_C13 + (p / q) * t;
        c = INV_GAMMA1P_M1_C12 + t * c;
        c = INV_GAMMA1P_M1_C11 + t * c;
        c = INV_GAMMA1P_M1_C10 + t * c;
        c = INV_GAMMA1P_M1_C9 + t * c;
        c = INV_GAMMA1P_M1_C8 + t * c;
        c = INV_GAMMA1P_M1_C7 + t * c;
        c = INV_GAMMA1P_M1_C6 + t * c;
        c = INV_GAMMA1P_M1_C5 + t * c;
        c = INV_GAMMA1P_M1_C4 + t * c;
        c = INV_GAMMA1P_M1_C3 + t * c;
        c = INV_GAMMA1P_M1_C2 + t * c;
        c = INV_GAMMA1P_M1_C1 + t * c;
        c = INV_GAMMA1P_M1_C0 + t * c;
        if (x > 0.5) {
            ret = (t / x) * ((c - 0.5) - 0.5);
        } else {
            ret = x * c;
        }
    }
    return ret;
}
/**
 * Returns the value of log &Gamma;(1 + x) for -0&#46;5 &le; x &le; 1&#46;5.
 * Based on the double precision implementation in the <em>NSWC Library of
 * Mathematics Subroutines</em>, {@code DGMLN1}.
 *
 * @param x Argument; must satisfy -0.5 <= x <= 1.5 (asserted; the Java
 *          original threw range exceptions instead).
 * @return The value of {@code log(Gamma(1 + x))}.
 * @since 3.1
 */
static double logGamma1p(const double x) {
    assert(x >= -0.5 && x <= 1.5);
    // log Gamma(1 + x) = -log(1 / Gamma(1 + x)) = -log1p(invGamma1pm1(x)).
    return -std::log1p(invGamma1pm1(x));
}
/**
 * Returns the value of Γ(x). Based on the <em>NSWC Library of
 * Mathematics Subroutines</em> double precision implementation,
 * {@code DGAMMA}.
 *
 * @param x Argument.
 * @return the value of {@code Gamma(x)}; {@code NaN} when x is a
 *         non-positive integer (a pole of the Gamma function).
 * @since 3.1
 */
static double gamma(const double x) {
    // Gamma has poles at 0, -1, -2, ...
    if ((x == std::rint(x)) && (x <= 0.0)) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    double ret;
    const double absX = std::abs(x);
    if (absX <= 20.0) {
        if (x >= 1.0) {
            /*
             * From the recurrence relation
             * Gamma(x) = (x - 1) * ... * (x - n) * Gamma(x - n),
             * then
             * Gamma(t) = 1 / [1 + invGamma1pm1(t - 1)],
             * where t = x - n. This means that t must satisfy
             * -0.5 <= t - 1 <= 1.5.
             */
            double prod = 1.0;
            double t = x;
            while (t > 2.5) {
                t -= 1.0;
                prod *= t;
            }
            ret = prod / (1.0 + invGamma1pm1(t - 1.0));
        } else {
            /*
             * From the recurrence relation
             * Gamma(x) = Gamma(x + n + 1) / [x * (x + 1) * ... * (x + n)]
             * then
             * Gamma(x + n + 1) = 1 / [1 + invGamma1pm1(x + n)],
             * which requires -0.5 <= x + n <= 1.5.
             */
            double prod = x;
            double t = x;
            while (t < -0.5) {
                t += 1.0;
                prod *= t;
            }
            ret = 1.0 / (prod * (1.0 + invGamma1pm1(t)));
        }
    } else {
        const double y = absX + LANCZOS_G + 0.5;
        // BUGFIX: divide by absX, not x. gammaAbs must equal Gamma(|x|)
        // per the Lanczos formula; dividing by a negative x flipped its
        // sign, so the reflection branch below returned -Gamma(x) for
        // every x < -20. Matches Apache Commons Math Gamma.gamma.
        const double gammaAbs = SQRT_TWO_PI / absX * std::pow(y, absX + 0.5) * std::exp(-y) * lanczos(absX);
        if (x > 0.0) {
            ret = gammaAbs;
        } else {
            /*
             * From the reflection formula
             * Gamma(x) * Gamma(1 - x) * sin(pi * x) = pi,
             * and the recurrence relation
             * Gamma(1 - x) = -x * Gamma(-x),
             * it is found
             * Gamma(x) = -pi / [x * sin(pi * x) * Gamma(-x)].
             */
            ret = -MathConst::PI / (x * std::sin(MathConst::PI * x) * gammaAbs);
        }
    }
    return ret;
}
};

View File

@ -0,0 +1,26 @@
/*
 Description: math constants shared by the math utilities
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/26
*/
#pragma once
#include <limits>
#include <limits.h>
// Frequently used floating-point and integer constants, mirroring the Java
// constants this port relies on (Math.PI, Double.NaN, Integer.MAX_VALUE, ...).
struct MathConst {
    // High-precision PI expressed as an exact split, as in Commons FastMath.
    static constexpr double PI = 105414357.0 / 33554432.0 + 1.984187159361080883e-9;
    // 2 * PI, derived from the same split so both constants stay consistent.
    static constexpr double TWO_PI = 2 * PI;
    static constexpr double DOUBLE_NAN = std::numeric_limits<double>::quiet_NaN();
    static constexpr double DOUBLE_NEGATIVE_INFINITY = -std::numeric_limits<double>::infinity();
    static constexpr double DOUBLE_POSITIVE_INFINITY = std::numeric_limits<double>::infinity();
    static constexpr int INT_MAX_VALUE = INT_MAX;
};

View File

@ -0,0 +1,41 @@
/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/26
*/
#pragma once
#include <limits>
#include <cstdint>
#include <cstdlib>
#include <cstring>
// Bit-level reinterpretation helpers matching Java's Double/Float raw-bits
// conversions. memcpy is the UB-free way to type-pun in C++.
struct MathFunc {
    // Java Double.doubleToRawLongBits.
    static int64_t doubleToRawLongBits(double value) {
        int64_t bits;
        std::memcpy(&bits, &value, sizeof(bits));
        return bits;
    }
    // Java Double.longBitsToDouble.
    static double longBitsToDouble(int64_t bits) {
        double value;
        std::memcpy(&value, &bits, sizeof(value));
        return value;
    }
    // Java Float.floatToRawIntBits.
    static int floatToRawIntBits(float value) {
        int bits;
        std::memcpy(&bits, &value, sizeof(bits));
        return bits;
    }
    // Java Float.intBitsToFloat.
    static float intBitsToFloat(int bits) {
        float value;
        std::memcpy(&value, &bits, sizeof(value));
        return value;
    }
};

View File

@ -0,0 +1,195 @@
#pragma once
#include <bit>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <type_traits>
// Generic computation of the ULP (unit in the last place) of a floating-point
// value, modeled after Java's Math.ulp.
template <typename T>
class MathUlp {
    static_assert(std::is_floating_point_v<T>, "MathUlp only works with floating-point types");
private:
    // Unsigned integer type wide enough to hold the bit pattern of T.
    // NOTE(review): this is `void` for long double, so instantiating
    // MathUlp<long double> (alias Ulp128) will not compile because the
    // UIntType(1) constants below are ill-formed -- confirm whether
    // long double support is actually needed.
    using UIntType = std::conditional_t<std::is_same_v<T, float>, uint32_t, std::conditional_t<std::is_same_v<T, double>, uint64_t, void> >;
    static constexpr int BITS = sizeof(T) * 8;
    static constexpr int MANTISSA_BITS = std::numeric_limits<T>::digits - 1;
    static constexpr int EXPONENT_BITS = BITS - MANTISSA_BITS - 1;
    static constexpr UIntType SIGN_BIT = UIntType(1) << (BITS - 1);
    static constexpr UIntType EXPONENT_MASK = ((UIntType(1) << EXPONENT_BITS) - 1) << MANTISSA_BITS;
    static constexpr UIntType MANTISSA_MASK = (UIntType(1) << MANTISSA_BITS) - 1;
    static constexpr int EXPONENT_BIAS = (1 << (EXPONENT_BITS - 1)) - 1;
    static constexpr int MAX_EXPONENT = (1 << EXPONENT_BITS) - 1;
    // Reinterpret the value as its raw IEEE-754 bit pattern (UB-free).
    static UIntType get_bits(T value) {
        UIntType bits;
        std::memcpy(&bits, &value, sizeof(T));
        return bits;
    }
    // Reinterpret a raw bit pattern as a value of T.
    static T from_bits(UIntType bits) {
        T value;
        std::memcpy(&value, &bits, sizeof(T));
        return value;
    }
    // Unbiased exponent of the value (bias already subtracted).
    static int get_exponent(T value) {
        UIntType bits = get_bits(value);
        UIntType exp_bits = (bits >> MANTISSA_BITS) & ((UIntType(1) << EXPONENT_BITS) - 1);
        return static_cast<int>(exp_bits) - EXPONENT_BIAS;
    }
    // Mantissa (fraction) bits of the value.
    static UIntType get_mantissa(T value) {
        UIntType bits = get_bits(value);
        return bits & MANTISSA_MASK;
    }
    // True when value is a normalized, finite, non-zero number.
    static bool is_normal(T value) {
        if (value == 0 || std::isinf(value) || std::isnan(value))
            return false;
        int exp = get_exponent(value);
        return exp > -EXPONENT_BIAS && exp <= EXPONENT_BIAS;
    }
    // True when value is subnormal (denormalized).
    static bool is_subnormal(T value) {
        if (value == 0 || std::isinf(value) || std::isnan(value))
            return false;
        int exp = get_exponent(value);
        return exp == -EXPONENT_BIAS && get_mantissa(value) != 0;
    }
public:
    // Main entry: compute the ULP of x.
    static T ulp(T x) {
        // Special values propagate.
        if (std::isnan(x)) {
            return std::numeric_limits<T>::quiet_NaN();
        }
        if (std::isinf(x)) {
            return std::numeric_limits<T>::infinity();
        }
        // Zero.
        // NOTE(review): Java's Math.ulp(0.0) returns Double.MIN_VALUE (the
        // smallest subnormal, denorm_min), not the smallest normalized
        // number returned here -- confirm which semantics callers such as
        // MathUtils::RoundToNDecimalPlaces expect.
        if (x == 0.0 || x == -0.0) {
            return std::numeric_limits<T>::min(); // smallest positive normalized value
        }
        // Work with the magnitude of x.
        T abs_x = std::fabs(x);
        // For subnormals the ULP is constant:
        if (is_subnormal(abs_x)) {
            // the smallest positive subnormal value.
            return std::numeric_limits<T>::denorm_min();
        }
        // Normalized values:
        int exp = get_exponent(abs_x);
        // ulp = 2^(exp - mantissa_bits)
        // float: ulp = 2^(exp - 23)
        // double: ulp = 2^(exp - 52)
        if constexpr (std::is_same_v<T, float>) {
            // ldexp computes the exact power of two.
            return std::ldexp(1.0f, exp - 23);
        } else if constexpr (std::is_same_v<T, double>) {
            return std::ldexp(1.0, exp - 52);
        } else if constexpr (std::is_same_v<T, long double>) {
            // long double layout is platform dependent.
            constexpr int LD_MANTISSA_BITS = std::numeric_limits<long double>::digits - 1;
            return std::ldexp(1.0L, exp - LD_MANTISSA_BITS);
        }
    }
    // nextafter-based implementation: simpler, but potentially slower.
    static T ulp_simple(T x) {
        // Special values propagate.
        if (std::isnan(x)) {
            return std::numeric_limits<T>::quiet_NaN();
        }
        if (std::isinf(x)) {
            return std::numeric_limits<T>::infinity();
        }
        // Zero (see the Java-semantics note in ulp()).
        if (x == 0.0 || x == -0.0) {
            return std::numeric_limits<T>::min();
        }
        // Measure the gap to the neighbours of |x|.
        T next = std::nextafter(std::fabs(x), std::numeric_limits<T>::infinity());
        T prev = std::nextafter(std::fabs(x), -std::numeric_limits<T>::infinity());
        // Distances to the next-larger and next-smaller representable values.
        T diff1 = next - std::fabs(x);
        T diff2 = std::fabs(x) - prev;
        // Return the smaller positive difference.
        if (diff1 > 0 && diff2 > 0) {
            return std::min(diff1, diff2);
        } else if (diff1 > 0) {
            return diff1;
        } else {
            return diff2;
        }
    }
    // Debug helper: dump the internal representation of a value to stdout.
    static void print_float_info(T value) {
        std::cout << "Value: " << std::setprecision(15) << value << "\n";
        std::cout << "Hex bits: ";
        if constexpr (std::is_same_v<T, float>) {
            uint32_t bits;
            std::memcpy(&bits, &value, sizeof(float));
            std::cout << std::hex << bits << std::dec;
            std::cout << "\nSign: " << ((bits >> 31) & 1);
            std::cout << "\nExponent: " << ((bits >> 23) & 0xFF);
            std::cout << " (bias-adjusted: " << ((bits >> 23) & 0xFF) - 127 << ")";
            std::cout << "\nMantissa: " << (bits & 0x7FFFFF);
        } else {
            uint64_t bits;
            std::memcpy(&bits, &value, sizeof(double));
            std::cout << std::hex << bits << std::dec;
            std::cout << "\nSign: " << ((bits >> 63) & 1);
            std::cout << "\nExponent: " << ((bits >> 52) & 0x7FF);
            std::cout << " (bias-adjusted: " << ((bits >> 52) & 0x7FF) - 1023 << ")";
            std::cout << "\nMantissa: " << (bits & 0xFFFFFFFFFFFFF);
        }
        std::cout << "\nIs normal: " << std::boolalpha << is_normal(value);
        std::cout << "\nIs subnormal: " << is_subnormal(value);
        std::cout << "\nULP (our): " << ulp(value);
        std::cout << "\nULP (simple): " << ulp_simple(value);
        std::cout << "\n---\n";
    }
};
// Convenience free-function wrappers around MathUlp<T>.
template <typename T>
inline T math_ulp(T x) {
    return MathUlp<T>::ulp(x);
}
template <typename T>
inline T math_ulp_simple(T x) {
    return MathUlp<T>::ulp_simple(x);
}
// Aliases for the common floating-point widths.
using Ulp32 = MathUlp<float>;
using Ulp64 = MathUlp<double>;
using Ulp128 = MathUlp<long double>;

View File

@ -0,0 +1,138 @@
/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/25
*/
#pragma once
#include <cassert>
#include <climits>
#include <cmath>
#include <vector>
#include "binomial_dist.h"
#include "math_const.h"
#include "math_ulp.h"
using std::vector;
// Assorted numeric helpers ported from GATK's MathUtils.
struct MathUtils {
    // Runs any static initialization the math library needs; currently a no-op.
    static void StaticInit() {
    }
    /**
     * Rounds the double to the given number of decimal places.
     * For example, rounding 3.1415926 to 3 places would give 3.142.
     * The requirement is that it works exactly as writing a number down with string.format and reading back in.
     */
    static double RoundToNDecimalPlaces(const double in, const int n) {
        assert(n > 0);
        const double mult = std::pow(10, n);
        // NOTE(review): Java's Math.round is floor(x + 0.5) while std::round
        // rounds half away from zero; results differ for negative halfway
        // values -- confirm which semantics the Java original relies on.
        return std::round((in + Ulp64::ulp(in)) * mult) / mult;
    }
    /**
     * binomial Probability(int, int, double) with log applied to result
     */
    static double logBinomialProbability(const int n, const int k, const double p) {
        BinomialDistribution binomial(n, p);
        return binomial.logProbability(k);
    }
    // Index of the largest element in array[start, endIndex); ties keep the
    // lowest index.
    template <typename ArrayType>
    static int maxElementIndex(const ArrayType& array, const int start, const int endIndex) {
        int maxI = start;
        for (int i = (start + 1); i < endIndex; i++) {
            if (array[i] > array[maxI])
                maxI = i;
        }
        return maxI;
    }
    // Round half away from zero to the nearest int via truncation.
    static int fastRound(double d) { return (d > 0.0) ? (int)(d + 0.5) : (int)(d - 0.5); }
    /**
     * Compute in a numerically correct way the quantity log10(1-x)
     *
     * Uses the approximation log10(1-x) = log10(1/x - 1) + log10(x) to avoid very quick underflow
     * in 1-x when x is very small
     *
     * log10(1-x) = log10( x * (1-x) / x )
     *            = log10(x) + log10( (1-x) / x)
     *            = log10(x) + log10(1/x - 1)
     *
     * @param x a positive double value between 0.0 and 1.0
     * @return an estimate of log10(1-x)
     */
    static double log10OneMinusX(const double x) {
        if (x == 1.0)
            return MathConst::DOUBLE_NEGATIVE_INFINITY;
        else if (x == 0.0)
            return 0.0;
        else {
            const double d = std::log10(1 / x - 1) + std::log10(x);
            // log10(1-x) <= 0 for x in [0, 1]; clamp spurious positives/overflow to 0.
            return std::isinf(d) || d > 0.0 ? 0.0 : d;
        }
    }
    // log10(sum of 10^v) over log10Values[start, size()).
    static double log10SumLog10(const vector<double>& log10Values, const int start) {
        return log10SumLog10(log10Values, start, log10Values.size());
    }
    // log10(sum of 10^v) over the whole vector.
    static double log10SumLog10(const vector<double>& log10Values) { return log10SumLog10(log10Values, 0); }
    // log-sum-exp (base 10) over log10Values[start, finish): factors out the
    // maximum element to avoid overflow/underflow.
    static double log10SumLog10(const vector<double>& log10Values, const int start, const int finish) {
        // Utils.nonNull(log10Values);
        if (start >= finish) {
            return MathConst::DOUBLE_NEGATIVE_INFINITY;
        }
        const int maxElementIndex = MathUtils::maxElementIndex(log10Values, start, finish);
        const double maxValue = log10Values[maxElementIndex];
        if (maxValue == MathConst::DOUBLE_NEGATIVE_INFINITY) {
            return maxValue;
        }
        double sum = 1.0;
        for (int i = start; i < finish; i++) {
            const double curVal = log10Values[i];
            if (i == maxElementIndex || curVal == MathConst::DOUBLE_NEGATIVE_INFINITY) {
                continue;
            } else {
                const double scaled_val = curVal - maxValue;
                sum += std::pow(10.0, scaled_val);
            }
        }
        if (std::isinf(sum) || sum == MathConst::DOUBLE_POSITIVE_INFINITY) {
            // NOTE(review): the Java original throws here; this port silently
            // continues -- consider reporting the error instead of ignoring it.
            // throw new IllegalArgumentException("log10 p: Values must be non-infinite and non-NAN");
        }
        return maxValue + (sum != 1.0 ? std::log10(sum) : 0.0);
    }
    // Two-value log-sum-exp (base 10), anchored on the larger value.
    static double log10SumLog10(const double a, const double b) {
        return a > b ? a + std::log10(1 + std::pow(10.0, b - a)) : b + std::log10(1 + std::pow(10.0, a - b));
    }
    /**
     * Do the log-sum trick for three double values.
     * @param a
     * @param b
     * @param c
     * @return the sum... perhaps NaN or infinity if it applies.
     */
    static double log10SumLog10(const double a, const double b, const double c) {
        if (a >= b && a >= c) {
            return a + std::log10(1 + std::pow(10.0, b - a) + std::pow(10.0, c - a));
        } else if (b >= c) {
            return b + std::log10(1 + std::pow(10.0, a - b) + std::pow(10.0, c - b));
        } else {
            return c + std::log10(1 + std::pow(10.0, a - c) + std::pow(10.0, b - c));
        }
    }
};

View File

@ -0,0 +1,48 @@
/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/26
*/
#pragma once
#include <math.h>
// Univariate normal (Gaussian) distribution, ported from Commons Math.
struct NormalDistribution {
    /** Default inverse cumulative probability accuracy (since 2.1). */
    static constexpr double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1e-9;
    /** &radic;(2) */
    static constexpr double SQRT2 = 1.414213562373095;
    /** High-precision PI, split exactly as in Commons FastMath. */
    static constexpr double PI = 105414357.0 / 33554432.0 + 1.984187159361080883e-9;
    /** Mean of this distribution. */
    double mean;
    /** Standard deviation of this distribution. */
    double standardDeviation;
    /** Cached {@code log(sd) + 0.5*log(2*pi)}, used by logDensity. */
    double logStandardDeviationPlusHalfLog2Pi;
    /** Inverse cumulative probability accuracy. */
    double solverAbsoluteAccuracy;
    /** Builds an N(mean, sd^2) distribution. */
    NormalDistribution(double mean, double sd)
        : mean(mean),
          standardDeviation(sd),
          logStandardDeviationPlusHalfLog2Pi(std::log(sd) + 0.5 * std::log(2 * PI)),
          solverAbsoluteAccuracy(DEFAULT_INVERSE_ABSOLUTE_ACCURACY) {}
    /** Natural logarithm of the probability density at x. */
    double logDensity(double x) const {
        const double z = (x - mean) / standardDeviation;
        return -0.5 * z * z - logStandardDeviationPlusHalfLog2Pi;
    }
};

View File

@ -0,0 +1,13 @@
#include "precision.h"
// 2^-53: relative rounding error bound for IEEE-754 doubles.
const double Precision::EPSILON = MathFunc::longBitsToDouble((EXPONENT_OFFSET - 53l) << 52);
// 2^-1022: smallest normalized double, so 1 / SAFE_MIN cannot overflow.
const double Precision::SAFE_MIN = MathFunc::longBitsToDouble((EXPONENT_OFFSET - 1022l) << 52);
// Raw IEEE-754 bit patterns of the signed zeros, used by the ULP comparison.
const int64_t Precision::POSITIVE_ZERO_DOUBLE_BITS = MathFunc::doubleToRawLongBits(+0.0);
const int64_t Precision::NEGATIVE_ZERO_DOUBLE_BITS = MathFunc::doubleToRawLongBits(-0.0);
const int Precision::POSITIVE_ZERO_FLOAT_BITS = MathFunc::floatToRawIntBits(+0.0f);
const int Precision::NEGATIVE_ZERO_FLOAT_BITS = MathFunc::floatToRawIntBits(-0.0f);

View File

@ -0,0 +1,122 @@
/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/26
*/
#pragma once
#include <cmath>
#include <cstdint>
#include "math_func.h"
// Floating-point comparison utilities ported from Commons Math Precision.
struct Precision {
    /**
     * <p>
     * Largest double-precision floating-point number such that
     * {@code 1 + EPSILON} is numerically equal to 1. This value is an upper
     * bound on the relative error due to rounding real numbers to double
     * precision floating-point numbers.
     * </p>
     * <p>
     * In IEEE 754 arithmetic, this is 2<sup>-53</sup>.
     * </p>
     *
     * @see <a href="http://en.wikipedia.org/wiki/Machine_epsilon">Machine epsilon</a>
     */
    static const double EPSILON;
    /**
     * Safe minimum, such that {@code 1 / SAFE_MIN} does not overflow.
     * <br/>
     * In IEEE 754 arithmetic, this is also the smallest normalized
     * number 2<sup>-1022</sup>.
     */
    static const double SAFE_MIN;
    /** Exponent offset in IEEE754 representation. */
    static constexpr int64_t EXPONENT_OFFSET = 1023l;
    /** Offset to order signed double numbers lexicographically. */
    static constexpr int64_t SGN_MASK = 0x8000000000000000L;
    /** Offset to order signed double numbers lexicographically. */
    // NOTE(review): 0x80000000 is an unsigned literal; converting it to int is
    // implementation-defined before C++20 -- consider INT_MIN or uint32_t.
    static constexpr int SGN_MASK_FLOAT = 0x80000000;
    /** Positive zero. */
    static constexpr double POSITIVE_ZERO = 0.0;
    /** Positive zero bits. */
    static const int64_t POSITIVE_ZERO_DOUBLE_BITS;
    /** Negative zero bits. */
    static const int64_t NEGATIVE_ZERO_DOUBLE_BITS;
    /** Positive zero bits. */
    static const int POSITIVE_ZERO_FLOAT_BITS;
    /** Negative zero bits. */
    static const int NEGATIVE_ZERO_FLOAT_BITS;
    /**
     * Returns {@code true} if there is no double value strictly between the
     * arguments or the difference between them is within the range of allowed
     * error (inclusive).
     *
     * @param x First value.
     * @param y Second value.
     * @param eps Amount of allowed absolute error.
     * @return {@code true} if the values are two adjacent floating point
     * numbers or they are within range of each other.
     */
    static bool equals(double x, double y, double eps) { return equals(x, y, 1) || std::abs(y - x) <= eps; }
    /**
     * Returns true if both arguments are equal or within the range of allowed
     * error (inclusive).
     * <p>
     * Two float numbers are considered equal if there are {@code (maxUlps - 1)}
     * (or fewer) floating point numbers between them, i.e. two adjacent
     * floating point numbers are considered equal.
     * </p>
     * <p>
     * Adapted from <a
     * href="http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/">
     * Bruce Dawson</a>
     * </p>
     *
     * @param x first value
     * @param y second value
     * @param maxUlps {@code (maxUlps - 1)} is the number of floating point
     * values between {@code x} and {@code y}.
     * @return {@code true} if there are fewer than {@code maxUlps} floating
     * point values between {@code x} and {@code y}.
     */
    static bool equals(const double x, const double y, const int maxUlps) {
        const int64_t xInt = MathFunc::doubleToRawLongBits(x);
        const int64_t yInt = MathFunc::doubleToRawLongBits(y);
        bool isEqual;
        if (((xInt ^ yInt) & SGN_MASK) == 0l) {
            // number have same sign, there is no risk of overflow
            isEqual = std::abs(xInt - yInt) <= maxUlps;
        } else {
            // number have opposite signs, take care of overflow
            // Measure each value's ULP distance from its own signed zero.
            int64_t deltaPlus;
            int64_t deltaMinus;
            if (xInt < yInt) {
                deltaPlus = yInt - POSITIVE_ZERO_DOUBLE_BITS;
                deltaMinus = xInt - NEGATIVE_ZERO_DOUBLE_BITS;
            } else {
                deltaPlus = xInt - POSITIVE_ZERO_DOUBLE_BITS;
                deltaMinus = yInt - NEGATIVE_ZERO_DOUBLE_BITS;
            }
            if (deltaPlus > maxUlps) {
                isEqual = false;
            } else {
                isEqual = deltaMinus <= (maxUlps - deltaPlus);
            }
        }
        // NaN is never equal to anything, even itself.
        return isEqual && !std::isnan(x) && !std::isnan(y);
    }
};

View File

@ -0,0 +1,37 @@
#include "saddle_pe.h"
// 0.5 * log(2 * pi), cached for the saddle point expansion.
const double SaddlePointExpansion::HALF_LOG_2_PI = 0.5 * std::log(TWO_PI);
// Exact Stirling-series errors for z = 0.0, 0.5, 1.0, ..., 15.0 (indexed by
// 2z); values from Catherine Loader, "Fast and Accurate Computation of
// Binomial Probabilities".
const double SaddlePointExpansion::EXACT_STIRLING_ERRORS[31] = {
    0.0,                          /* 0.0 */
    0.1534264097200273452913848,  /* 0.5 */
    0.0810614667953272582196702,  /* 1.0 */
    0.0548141210519176538961390,  /* 1.5 */
    0.0413406959554092940938221,  /* 2.0 */
    0.03316287351993628748511048, /* 2.5 */
    0.02767792568499833914878929, /* 3.0 */
    0.02374616365629749597132920, /* 3.5 */
    0.02079067210376509311152277, /* 4.0 */
    0.01848845053267318523077934, /* 4.5 */
    0.01664469118982119216319487, /* 5.0 */
    0.01513497322191737887351255, /* 5.5 */
    0.01387612882307074799874573, /* 6.0 */
    0.01281046524292022692424986, /* 6.5 */
    0.01189670994589177009505572, /* 7.0 */
    0.01110455975820691732662991, /* 7.5 */
    0.010411265261972096497478567, /* 8.0 */
    0.009799416126158803298389475, /* 8.5 */
    0.009255462182712732917728637, /* 9.0 */
    0.008768700134139385462952823, /* 9.5 */
    0.008330563433362871256469318, /* 10.0 */
    0.007934114564314020547248100, /* 10.5 */
    0.007573675487951840794972024, /* 11.0 */
    0.007244554301320383179543912, /* 11.5 */
    0.006942840107209529865664152, /* 12.0 */
    0.006665247032707682442354394, /* 12.5 */
    0.006408994188004207068439631, /* 13.0 */
    0.006171712263039457647532867, /* 13.5 */
    0.005951370112758847735624416, /* 14.0 */
    0.005746216513010115682023589, /* 14.5 */
    0.005554733551962801371038690  /* 15.0 */
};

View File

@ -0,0 +1,128 @@
/*
Description: saddle point expansion
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/26
*/
#pragma once
#include <math.h>
#include <limits>
#include "gamma.h"
// Saddle point expansion for the log-PMF of the binomial distribution,
// ported from Commons Math (after Catherine Loader's dbinom).
struct SaddlePointExpansion {
    static constexpr double TWO_PI = 2 * (105414357.0 / 33554432.0 + 1.984187159361080883e-9);
    /** 1/2 * log(2 &#960;). */
    static const double HALF_LOG_2_PI;
    /** exact Stirling expansion error for certain values. */
    static const double EXACT_STIRLING_ERRORS[31];
    /**
     * A part of the deviance portion of the saddle point approximation.
     * <p>
     * References:
     * <ol>
     * <li>Catherine Loader (2000). "Fast and Accurate Computation of Binomial
     * Probabilities.". <a target="_blank"
     * href="http://www.herine.net/stat/papers/dbinom.pdf">
     * http://www.herine.net/stat/papers/dbinom.pdf</a></li>
     * </ol>
     * </p>
     *
     * @param x the x value.
     * @param mu the average.
     * @return a part of the deviance.
     */
    static double getDeviancePart(double x, double mu) {
        double ret;
        if (std::abs(x - mu) < 0.1 * (x + mu)) {
            // x close to mu: evaluate the series expansion of
            // x*log(x/mu) + mu - x until the partial sum stops changing.
            double d = x - mu;
            double v = d / (x + mu);
            double s1 = v * d;
            double s = std::numeric_limits<double>::quiet_NaN();
            double ej = 2.0 * x * v;
            v *= v;
            int j = 1;
            // NaN seed guarantees at least one iteration (NaN != s1).
            while (s1 != s) {
                s = s1;
                ej *= v;
                s1 = s + ej / ((j * 2) + 1);
                ++j;
            }
            ret = s1;
        } else {
            // Otherwise the direct formula is numerically safe.
            ret = x * std::log(x / mu) + mu - x;
        }
        return ret;
    }
    /**
     * Compute the error of Stirling's series at the given value.
     * <p>
     * References:
     * <ol>
     * <li>Eric W. Weisstein. "Stirling's Series." From MathWorld--A Wolfram Web
     * Resource. <a target="_blank"
     * href="http://mathworld.wolfram.com/StirlingsSeries.html">
     * http://mathworld.wolfram.com/StirlingsSeries.html</a></li>
     * </ol>
     * </p>
     *
     * @param z the value.
     * @return the Striling's series error.
     */
    static double getStirlingError(double z) {
        double ret;
        if (z < 15.0) {
            double z2 = 2.0 * z;
            if (std::floor(z2) == z2) {
                // Half-integer z: exact tabulated value (table indexed by 2z).
                ret = EXACT_STIRLING_ERRORS[(int)z2];
            } else {
                ret = Gamma::logGamma(z + 1.0) - (z + 0.5) * std::log(z) + z - HALF_LOG_2_PI;
            }
        } else {
            // Large z: asymptotic series in 1/z^2.
            double z2 = z * z;
            ret = (0.083333333333333333333 -
                   (0.00277777777777777777778 -
                    (0.00079365079365079365079365 - (0.000595238095238095238095238 - 0.0008417508417508417508417508 / z2) / z2) / z2) /
                       z2) /
                  z;
        }
        return ret;
    }
    /**
     * Compute the logarithm of the PMF for a binomial distribution
     * using the saddle point expansion.
     *
     * @param x the value at which the probability is evaluated.
     * @param n the number of trials.
     * @param p the probability of success.
     * @param q the probability of failure (1 - p).
     * @return log(p(x)).
     */
    static double logBinomialProbability(int x, int n, double p, double q) {
        double ret;
        if (x == 0) {
            // Boundary x = 0: deviance form avoids underflow when p is tiny.
            if (p < 0.1) {
                ret = -getDeviancePart(n, n * q) - n * p;
            } else {
                ret = n * std::log(q);
            }
        } else if (x == n) {
            // Boundary x = n, symmetric to the x = 0 case.
            if (q < 0.1) {
                ret = -getDeviancePart(n, n * p) - n * q;
            } else {
                ret = n * std::log(p);
            }
        } else {
            // Interior: full saddle point expansion.
            ret = getStirlingError(n) - getStirlingError(x) - getStirlingError(n - x) - getDeviancePart(x, n * p) - getDeviancePart(n - x, n * q);
            double f = (TWO_PI * x * (n - x)) / n;
            ret = -0.5 * std::log(f) + ret;
        }
        return ret;
    }
};

View File

@ -57,6 +57,7 @@ int DisplayProfiling(int nthread) {
PRINT_GP(read_ref); PRINT_GP(read_ref);
PRINT_GP(read_vcf); PRINT_GP(read_vcf);
PRINT_GP(covariate); PRINT_GP(covariate);
PRINT_GP(update_info);
//PRINT_GP(markdup); //PRINT_GP(markdup);
//PRINT_GP(intersect); //PRINT_GP(intersect);
// PRINT_GP(merge_result); // PRINT_GP(merge_result);

View File

@ -43,6 +43,7 @@ enum {
GP_covariate, GP_covariate,
GP_read_ref, GP_read_ref,
GP_read_vcf, GP_read_vcf,
GP_update_info,
GP_gen_wait, GP_gen_wait,
GP_sort_wait, GP_sort_wait,
GP_markdup_wait, GP_markdup_wait,

View File

@ -0,0 +1,8 @@
#include "report_table.h"
// Error-message fragments and the table marker used when parsing GATK report tables.
const string ReportTable::COULD_NOT_READ_HEADER = "Could not read the header of this file -- ";
const string ReportTable::COULD_NOT_READ_COLUMN_NAMES = "Could not read the column names of this file -- ";
const string ReportTable::COULD_NOT_READ_DATA_LINE = "Could not read a data line of this table -- ";
const string ReportTable::COULD_NOT_READ_EMPTY_LINE = "Could not read the last empty line of this table -- ";
const string ReportTable::OLD_GATK_TABLE_VERSION = "We no longer support older versions of the GATK Tables";
// Every table section starts with this prefix in the report file.
const char* ReportTable::TABLE_PREFIX = "GATKTable";

View File

@ -0,0 +1,253 @@
/*
Description:
Copyright : All rights reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/27
*/
#pragma once
#include <string>
#include <vector>
#include <cstdio>
#include <cassert>
#include <algorithm>
#include <sstream>
#include <iomanip>
#include <spdlog/spdlog.h>
using std::string;
using std::stringstream;
using std::vector;
#define REPORT_HEADER_VERSION "#:GATKReport.v1.1:5"
struct ReportUtil {
    // Render a boolean as the literal report tokens "true"/"false".
    static string ToString(const bool val) {
        if (val) return "true";
        return "false";
    }
    // Render a single character as a one-character string.
    static string ToString(const char val) {
        return string(1, val);
    }
    // Empty strings are emitted as "null" to match the report output format.
    static string ToString(const string& val) {
        if (val.empty()) return "null";
        return val;
    }
    // Fixed-point rendering of a double with the requested number of decimals.
    static string ToString(const double val, int precise) {
        stringstream buf;
        buf << std::fixed << std::setprecision(precise) << val;
        return buf.str();
    }
    // Fallback for the remaining arithmetic types: delegate to std::to_string.
    template<typename T>
    static string ToString(const T val) {
        return std::to_string(val);
    }
};
struct ReportTable {
    // Shared diagnostic strings; definitions live in report_table.cpp.
    static const string COULD_NOT_READ_HEADER;
    static const string COULD_NOT_READ_COLUMN_NAMES;
    static const string COULD_NOT_READ_DATA_LINE;
    static const string COULD_NOT_READ_EMPTY_LINE;
    static const string OLD_GATK_TABLE_VERSION;
    static const char* TABLE_PREFIX;
    // Row ordering policy.
    // In practice tables appear to always use SORT_BY_COLUMN: rows are ordered
    // by comparing their values column by column, left to right.
    // SORT_BY_ROW would order rows by row id instead.
    enum struct Sorting { SORT_BY_ROW, SORT_BY_COLUMN, DO_NOT_SORT };
    // Data type carried by a column.
    enum struct DataType {
        /**
         * The null type should not be used.
         */
        NullType,
        /**
         * The default value when a format string is not present
         */
        Unknown,
        /**
         * Used for boolean values. Will display as true or false in the table.
         */
        Boolean,
        /**
         * Used for char values. Will display as a char so use printable values!
         */
        Character,
        /**
         * Used for float and double values. Will output a decimal with format %.8f unless otherwise specified.
         */
        Decimal,
        /**
         * Used for int, byte, short, and long values. Will display the full number by default.
         */
        Integer,
        /**
         * Used for string values. Displays the string itself.
         */
        String
    };
    // Metadata for a single table column.
    struct Column {
        // Only columns whose values are all numeric (or nan/inf/null) can be
        // right-aligned; character and string columns must be left-aligned.
        enum struct Alignment {LEFT, RIGHT};
        string columnName;
        string format;                          // printf-style format of the source values
        DataType dataType = DataType::Unknown;  // in-class init: was indeterminate after Column()
        Alignment alignment = Alignment::RIGHT;
        int maxWidth = 0;                       // widest rendered value seen so far
        Column() {}
        Column(const string& name) : Column(name, "%s") {}
        Column(const string& name, const string& _format) : Column(name, _format, DataType::String) {}
        Column(const string& name, const string& _format, const DataType _dataType) : Column(name, _format, _dataType, Alignment::RIGHT) {}
        Column(const string& name, const string& _format, const DataType _dataType, const Alignment align) { init(name, _format, _dataType, align); }
        void init(const string& name, const string &_format, const DataType _dataType, const Alignment align) {
            columnName = name;
            format = _format;
            dataType = _dataType;
            alignment = align;
            maxWidth = name.size();
            updateFormat(*this); // keep format and dataType consistent
        }
        // Grow maxWidth to cover one more rendered value.
        void updateMaxWidth(const string &value) {
            if (!value.empty()) {
                maxWidth = std::max(maxWidth, (int)value.size());
            }
        }
        // Derive dataType and alignment from the column's format string.
        static void updateFormat(Column &col) {
            if (col.format == "") {
                col.dataType = DataType::Unknown;
                col.format = "%s";
            } else if (col.format.find('s') != string::npos) {
                col.dataType = DataType::String;
            } else if (col.format.find('d') != string::npos) {
                col.dataType = DataType::Integer;
            } else if (col.format.find('f') != string::npos) {
                col.dataType = DataType::Decimal;
            }
            // Only numeric columns may be right-aligned.
            if (col.dataType != DataType::Decimal && col.dataType != DataType::Integer)
                col.alignment = Column::Alignment::LEFT;
        }
        // Format used for the column-name header cell (always left-aligned).
        string getNameFormat() { return "%-" + std::to_string(maxWidth) + "s";}
        // Format used for the value cells; every value is pre-rendered as a string.
        string getValueFormat() {
            string align = alignment == Alignment::LEFT ? "-" : "";
            return "%" + align + std::to_string(maxWidth) + "s";
        }
    };
    using RowData = vector<string>; // one row of pre-rendered cell values
    string tableName;               // table name
    string tableDescription;        // table description
    Sorting sortingWay;             // row ordering policy
    vector<Column> columnInfo;      // metadata for every column
    vector<RowData> underlyingData; // all row data, in insertion order
    vector<size_t> idxArr;          // row indices, permuted by sort()
    ReportTable() {}
    ReportTable(const string& _tableName) : ReportTable(_tableName, "") {}
    ReportTable(const string& _tableName, const string& _tableDescription) : ReportTable(_tableName, _tableDescription, Sorting::SORT_BY_COLUMN) {}
    ReportTable(const string& _tableName, const string& _tableDescription, const Sorting sorting) { init(_tableName, _tableDescription, sorting); }
    // Initialize name, description and sorting policy.
    void init(const string& _tableName, const string& _tableDescription, const Sorting sorting) {
        tableName = _tableName;
        tableDescription = _tableDescription;
        sortingWay = sorting;
    }
    // Append a column; its dataType/alignment are normalized from its format.
    void addColumn(Column col) {
        Column::updateFormat(col);
        columnInfo.push_back(col);
    }
    // Append one row; every cell must already be rendered as a string.
    void addRowData(const RowData &dat) {
        assert(dat.size() == columnInfo.size());
        idxArr.push_back(underlyingData.size());
        underlyingData.push_back(dat);
        for (size_t i = 0; i < dat.size(); ++i) {
            columnInfo[i].updateMaxWidth(dat[i]);
        }
    }
    // Order rows by comparing values column by column, left to right; numeric
    // columns (formats containing 'd'/'f') compare numerically. Ties fall back
    // to insertion order, which keeps the ordering stable.
    void sort() {
        std::sort(idxArr.begin(), idxArr.end(), [&](size_t a, size_t b) {
            auto& r1 = underlyingData[a];
            auto& r2 = underlyingData[b];
            for (size_t i = 0; i < r1.size(); ++i) {
                if (r1[i] != r2[i]) {
                    if (columnInfo[i].format.find('d') != string::npos) // could also dispatch on dataType
                        return std::stoll(r1[i]) < std::stoll(r2[i]);
                    else if (columnInfo[i].format.find('f') != string::npos) {
                        return std::stod(r1[i]) < std::stod(r2[i]);
                    } else {
                        return r1[i] < r2[i];
                    }
                }
            }
            return a < b;
        });
    }
    // Write the table in GATK report text format.
    void write(FILE *fpout) {
        // Header line: column count, row count, then each column's format.
        fprintf(fpout, "#:%s:%d:%d:", TABLE_PREFIX, (int)columnInfo.size(), (int)underlyingData.size());
        for (auto& col : columnInfo) fprintf(fpout, "%s:", col.format.c_str());
        fprintf(fpout, ";\n");
        // Table name and description.
        fprintf(fpout, "#:%s:%s:%s\n", TABLE_PREFIX, tableName.c_str(), tableDescription.c_str());
        // Column names.
        bool needPadding = false;
        for (auto& col : columnInfo) {
            if (needPadding)
                fprintf(fpout, "  ");
            needPadding = true;
            fprintf(fpout, col.getNameFormat().c_str(), col.columnName.c_str());
        }
        fprintf(fpout, "\n");
        // All row data, emitted in sorted order.
        sort();
        for (auto idx : idxArr) {
            auto &row = underlyingData[idx];
            bool needCellPadding = false; // renamed: no longer shadows the outer flag
            for (size_t i = 0; i < row.size(); ++i) {
                if (needCellPadding)
                    fprintf(fpout, "  ");
                needCellPadding = true;
                fprintf(fpout, columnInfo[i].getValueFormat().c_str(), row[i].c_str());
            }
            fprintf(fpout, "\n");
        }
        fprintf(fpout, "\n");
    }
};
/*
// Table for the recalibration-arguments section of the report
struct ArgReportTable : ReportTable {
    using ReportTable::ReportTable;
};
// Table for the quantized-quality section of the report
struct QuantReportTable : ReportTable {
};
// Table for the covariate sections of the report
struct CovariateReportTable : ReportTable {
};
*/

View File

@ -1,4 +1,11 @@
#pragma once #pragma once
#include <string>
using std::string;
// Report an error and exit -1 // Report an error and exit -1
void error(const char* format, ...); void error(const char* format, ...);
struct Utils {
};

File diff suppressed because it is too large Load Diff