bqsr第一阶段完成了,结果还有点错误,得调试一下
This commit is contained in:
parent
25f079b936
commit
146055fc01
|
|
@ -3,7 +3,7 @@
|
||||||
*.d
|
*.d
|
||||||
/.vscode
|
/.vscode
|
||||||
/build
|
/build
|
||||||
/text
|
/test
|
||||||
build.sh
|
build.sh
|
||||||
run.sh
|
run.sh
|
||||||
test.sh
|
test.sh
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ set(EXECUTABLE_OUTPUT_PATH "${PROJECT_BINARY_DIR}/bin")
|
||||||
# source codes path
|
# source codes path
|
||||||
aux_source_directory(${PROJECT_SOURCE_DIR}/src MAIN_SRC)
|
aux_source_directory(${PROJECT_SOURCE_DIR}/src MAIN_SRC)
|
||||||
aux_source_directory(${PROJECT_SOURCE_DIR}/src/util UTIL_SRC)
|
aux_source_directory(${PROJECT_SOURCE_DIR}/src/util UTIL_SRC)
|
||||||
|
aux_source_directory(${PROJECT_SOURCE_DIR}/src/util/math UTIL_MATH_SRC)
|
||||||
aux_source_directory(${PROJECT_SOURCE_DIR}/src/bqsr BQSR_SRC)
|
aux_source_directory(${PROJECT_SOURCE_DIR}/src/bqsr BQSR_SRC)
|
||||||
|
|
||||||
set(KTHREAD_FILE ${PROJECT_SOURCE_DIR}/ext/klib/kthread.c)
|
set(KTHREAD_FILE ${PROJECT_SOURCE_DIR}/ext/klib/kthread.c)
|
||||||
|
|
@ -20,7 +21,12 @@ link_directories("${PROJECT_SOURCE_DIR}/ext/htslib")
|
||||||
set(PG_NAME "fastbqsr")
|
set(PG_NAME "fastbqsr")
|
||||||
|
|
||||||
# dependency files
|
# dependency files
|
||||||
add_executable(${PG_NAME} ${MAIN_SRC} ${UTIL_SRC} ${BQSR_SRC} ${KTHREAD_FILE})
|
add_executable(${PG_NAME}
|
||||||
|
${MAIN_SRC}
|
||||||
|
${UTIL_SRC}
|
||||||
|
${UTIL_MATH_SRC}
|
||||||
|
${BQSR_SRC}
|
||||||
|
${KTHREAD_FILE})
|
||||||
|
|
||||||
# link htslib
|
# link htslib
|
||||||
target_link_libraries(${PG_NAME} libhts.a)
|
target_link_libraries(${PG_NAME} libhts.a)
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,26 @@ struct BQSRArg {
|
||||||
|
|
||||||
// end of common parameters
|
// end of common parameters
|
||||||
|
|
||||||
|
// We always use the same covariates. The field is retained for compatibility with GATK3 reports.
|
||||||
|
bool DO_NOT_USE_STANDARD_COVARIATES = false;
|
||||||
|
|
||||||
|
//It makes no sense to run BQSR without sites. so we remove this option.
|
||||||
|
bool RUN_WITHOUT_DBSNP = false;
|
||||||
|
|
||||||
|
// We don't support SOLID. The field is retained for compatibility with GATK3 reports.
|
||||||
|
string SOLID_RECAL_MODE = "SET_Q_ZERO";
|
||||||
|
string SOLID_NOCALL_STRATEGY = "THROW_EXCEPTION";
|
||||||
|
|
||||||
|
//@Hidden @Argument(fullName = "default-platform", optional = true,
|
||||||
|
// doc = "If a read has no platform then default to the provided String. Valid options are illumina, 454, and solid.") public String
|
||||||
|
string DEFAULT_PLATFORM = "";
|
||||||
|
|
||||||
|
// @Hidden @Argument(fullName = "force-platform", optional = true,
|
||||||
|
// doc = "If provided, the platform of EVERY read will be forced to be the provided String. Valid options are illumina, 454, and "solid.")
|
||||||
|
string FORCE_PLATFORM = "";
|
||||||
|
|
||||||
|
string existingRecalibrationReport = "";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The context covariate will use a context of this size to calculate its covariate value for base mismatches. Must be
|
* The context covariate will use a context of this size to calculate its covariate value for base mismatches. Must be
|
||||||
* between 1 and 13 (inclusive). Note that higher values will increase runtime and required java heap size.
|
* between 1 and 13 (inclusive). Note that higher values will increase runtime and required java heap size.
|
||||||
|
|
|
||||||
|
|
@ -7,18 +7,18 @@ Copyright : All right reserved by ICT
|
||||||
Author : Zhang Zhonghai
|
Author : Zhang Zhonghai
|
||||||
Date : 2023/10/23
|
Date : 2023/10/23
|
||||||
*/
|
*/
|
||||||
|
#include <header.h>
|
||||||
#include <htslib/faidx.h>
|
#include <htslib/faidx.h>
|
||||||
#include <htslib/kstring.h>
|
#include <htslib/kstring.h>
|
||||||
#include <htslib/sam.h>
|
#include <htslib/sam.h>
|
||||||
#include <htslib/synced_bcf_reader.h>
|
#include <htslib/synced_bcf_reader.h>
|
||||||
#include <htslib/thread_pool.h>
|
#include <htslib/thread_pool.h>
|
||||||
#include <header.h>
|
|
||||||
#include <spdlog/spdlog.h>
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <vector>
|
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "baq.h"
|
#include "baq.h"
|
||||||
#include "bqsr_args.h"
|
#include "bqsr_args.h"
|
||||||
|
|
@ -28,10 +28,16 @@ Date : 2023/10/23
|
||||||
#include "dup_metrics.h"
|
#include "dup_metrics.h"
|
||||||
#include "fastbqsr_version.h"
|
#include "fastbqsr_version.h"
|
||||||
#include "read_name_parser.h"
|
#include "read_name_parser.h"
|
||||||
|
#include "read_recal_info.h"
|
||||||
|
#include "recal_datum.h"
|
||||||
|
#include "recal_tables.h"
|
||||||
|
#include "recal_utils.h"
|
||||||
#include "util/interval.h"
|
#include "util/interval.h"
|
||||||
|
#include "util/linear_index.h"
|
||||||
#include "util/profiling.h"
|
#include "util/profiling.h"
|
||||||
#include "util/utils.h"
|
#include "util/utils.h"
|
||||||
#include "util/linear_index.h"
|
#include "util/math/math_utils.h"
|
||||||
|
#include "quant_info.h"
|
||||||
|
|
||||||
using std::deque;
|
using std::deque;
|
||||||
|
|
||||||
|
|
@ -89,7 +95,8 @@ BQSRArg gBqsrArg; // bqsr arguments
|
||||||
samFile* gInBamFp; // input BAM file pointer
|
samFile* gInBamFp; // input BAM file pointer
|
||||||
sam_hdr_t* gInBamHeader; // input BAM header
|
sam_hdr_t* gInBamHeader; // input BAM header
|
||||||
vector<AuxVar> gAuxVars; // auxiliary variables,保存一些文件,数据等,每个线程对应一个
|
vector<AuxVar> gAuxVars; // auxiliary variables,保存一些文件,数据等,每个线程对应一个
|
||||||
|
RecalTables gRecalTables; // 记录bqsr所有的数据,输出table结果
|
||||||
|
vector<EventTypeValue> gEventTypes; // 需要真正计算的eventtype
|
||||||
|
|
||||||
// 下面是需要删除或修改的变量
|
// 下面是需要删除或修改的变量
|
||||||
std::vector<ReadNameParser> gNameParsers; // read name parser
|
std::vector<ReadNameParser> gNameParsers; // read name parser
|
||||||
|
|
@ -258,12 +265,12 @@ void calculateRefOffset(BamWrap *bw, SamData &ad) {
|
||||||
// 计算clip处理之后,剩余的碱基
|
// 计算clip处理之后,剩余的碱基
|
||||||
void calculateReadBases(BamWrap* bw, SamData& ad) {
|
void calculateReadBases(BamWrap* bw, SamData& ad) {
|
||||||
ad.bases.resize(ad.read_len);
|
ad.bases.resize(ad.read_len);
|
||||||
ad.quals.resize(ad.read_len);
|
ad.base_quals.resize(ad.read_len);
|
||||||
uint8_t* seq = bam_get_seq(bw->b);
|
uint8_t* seq = bam_get_seq(bw->b);
|
||||||
uint8_t* quals = bam_get_qual(bw->b);
|
uint8_t* quals = bam_get_qual(bw->b);
|
||||||
for (int i = 0; i < ad.read_len; ++i) {
|
for (int i = 0; i < ad.read_len; ++i) {
|
||||||
ad.bases[i] = cBaseToChar[bam_seqi(seq, i + ad.left_clip)];
|
ad.bases[i] = cBaseToChar[bam_seqi(seq, i + ad.left_clip)];
|
||||||
ad.quals[i] = quals[i];
|
ad.base_quals[i] = quals[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -402,6 +409,7 @@ int calculateIsSNPOrIndel(AuxVar& aux, BamWrap *bw, SamData &ad, vector<int> &is
|
||||||
for (int j = 0; j < len; ++j) {
|
for (int j = 0; j < len; ++j) {
|
||||||
// 按位置将read和ref碱基进行比较,不同则是snp,注意read起始位置要加上left_clip
|
// 按位置将read和ref碱基进行比较,不同则是snp,注意read起始位置要加上left_clip
|
||||||
int snpInt = cBaseToChar[bam_seqi(seq, readPos + ad.left_clip)] == refBases[refPos] ? 0 : 1;
|
int snpInt = cBaseToChar[bam_seqi(seq, readPos + ad.left_clip)] == refBases[refPos] ? 0 : 1;
|
||||||
|
// if (snpInt > 0) { spdlog::info("snp {}, readpos: {}", snpInt, readPos); }
|
||||||
isSNP[readPos] = snpInt;
|
isSNP[readPos] = snpInt;
|
||||||
nEvents += snpInt;
|
nEvents += snpInt;
|
||||||
readPos++;
|
readPos++;
|
||||||
|
|
@ -429,6 +437,7 @@ int calculateIsSNPOrIndel(AuxVar& aux, BamWrap *bw, SamData &ad, vector<int> &is
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
nEvents += std::accumulate(isIns.begin(), isIns.end(), 0) + std::accumulate(isDel.begin(), isDel.end(), 0);
|
nEvents += std::accumulate(isIns.begin(), isIns.end(), 0) + std::accumulate(isDel.begin(), isDel.end(), 0);
|
||||||
|
// spdlog::info("nEvents: {}", nEvents);
|
||||||
|
|
||||||
//spdlog::info("SNPs: {}, Ins: {}, Del: {}, total events: {}", std::accumulate(isSNP.begin(), isSNP.end(), 0),
|
//spdlog::info("SNPs: {}, Ins: {}, Del: {}, total events: {}", std::accumulate(isSNP.begin(), isSNP.end(), 0),
|
||||||
// std::accumulate(isIns.begin(), isIns.end(), 0), std::accumulate(isDel.begin(), isDel.end(), 0), nEvents);
|
// std::accumulate(isIns.begin(), isIns.end(), 0), std::accumulate(isDel.begin(), isDel.end(), 0), nEvents);
|
||||||
|
|
@ -438,13 +447,13 @@ int calculateIsSNPOrIndel(AuxVar& aux, BamWrap *bw, SamData &ad, vector<int> &is
|
||||||
}
|
}
|
||||||
|
|
||||||
// 简单计算baq数组,就是全部赋值为'@' (64)
|
// 简单计算baq数组,就是全部赋值为'@' (64)
|
||||||
bool flatBAQArray(BamWrap* bw, SamData& ad, vector<int>& baqArray) {
|
bool flatBAQArray(BamWrap* bw, SamData& ad, vector<uint8_t>& baqArray) {
|
||||||
baqArray.resize(ad.read_len, (int)'@');
|
baqArray.resize(ad.read_len, (uint8_t)'@');
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算真实的baq数组,耗时更多,好像enable-baq参数默认是关闭的,那就先不实现这个了
|
// 计算真实的baq数组,耗时更多,好像enable-baq参数默认是关闭的,那就先不实现这个了
|
||||||
bool calculateBAQArray(AuxVar& aux, BAQ& baq, BamWrap* bw, SamData& ad, vector<int>& baqArray) {
|
bool calculateBAQArray(AuxVar& aux, BAQ& baq, BamWrap* bw, SamData& ad, vector<uint8_t>& baqArray) {
|
||||||
baqArray.resize(ad.read_len, 0);
|
baqArray.resize(ad.read_len, 0);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
@ -563,7 +572,8 @@ static void calculateAndStoreErrorsInBlock(int i, int blockStartIndex, vector<in
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应该是用来处理BAQ的,把不等于特定BAQ分数的碱基作为一段数据统一处理
|
// 应该是用来处理BAQ的,把不等于特定BAQ分数的碱基作为一段数据统一处理
|
||||||
void calculateFractionalErrorArray(vector<int>& errorArr, vector<int>& baqArr, vector<double>& fracErrs) {
|
void calculateFractionalErrorArray(vector<int>& errorArr, vector<uint8_t>& baqArr, vector<double>& fracErrs) {
|
||||||
|
// for (auto val : errorArr) { if (val > 0) spdlog::info("snp err val: {}", val); }
|
||||||
fracErrs.resize(baqArr.size());
|
fracErrs.resize(baqArr.size());
|
||||||
// errorArray和baqArray必须长度相同
|
// errorArray和baqArray必须长度相同
|
||||||
const int BLOCK_START_UNSET = -1;
|
const int BLOCK_START_UNSET = -1;
|
||||||
|
|
@ -592,6 +602,95 @@ void calculateFractionalErrorArray(vector<int>& errorArr, vector<int>& baqArr, v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the recalibration statistics using the information in recalInfo.
|
||||||
|
*
|
||||||
|
* Implementation detail: we only populate the quality score table and the optional tables.
|
||||||
|
* The read group table will be populated later by collapsing the quality score table.
|
||||||
|
*
|
||||||
|
* @param recalInfo data structure holding information about the recalibration values for a single read
|
||||||
|
*/
|
||||||
|
void updateRecalTablesForRead(ReadRecalInfo &info) {
|
||||||
|
SamData &read = info.read;
|
||||||
|
PerReadCovariateMatrix& readCovars = info.covariates;
|
||||||
|
Array3D<RecalDatum>& qualityScoreTable = nsgv::gRecalTables.qualityScoreTable;
|
||||||
|
Array4D<RecalDatum>& contextTable = nsgv::gRecalTables.contextTable;
|
||||||
|
Array4D<RecalDatum>& cycleTable = nsgv::gRecalTables.cycleTable;
|
||||||
|
|
||||||
|
int readLength = read.read_len;
|
||||||
|
for (int offset = 0; offset < readLength; ++offset) {
|
||||||
|
if (!info.skips[offset]) { // 不跳过当前位置
|
||||||
|
for (int idx = 0; idx < nsgv::gEventTypes.size(); ++idx) {
|
||||||
|
// 获取四个值,readgroup / qualityscore / context / cycle
|
||||||
|
EventTypeValue& event = nsgv::gEventTypes[idx];
|
||||||
|
vector<int>& covariatesAtOffset = readCovars[event.index][offset];
|
||||||
|
uint8_t qual = info.getQual(event, offset);
|
||||||
|
double isError = info.getErrorFraction(event, offset);
|
||||||
|
|
||||||
|
int readGroup = covariatesAtOffset[ReadGroupCovariate::index];
|
||||||
|
int baseQuality = covariatesAtOffset[BaseQualityCovariate::index];
|
||||||
|
|
||||||
|
// 处理base quality score协变量
|
||||||
|
// RecalUtils::IncrementDatum3keys(qualityScoreTable, qual, isError, readGroup, baseQuality, event.index);
|
||||||
|
qualityScoreTable[readGroup][baseQuality][event.index].increment(1, isError, baseQuality);
|
||||||
|
|
||||||
|
auto &d = qualityScoreTable[readGroup][baseQuality][event.index];
|
||||||
|
// spdlog::info("isError {} : {}, mis {}, obs {}", isError, info.snp_errs[offset], d.numMismatches, d.numObservations);
|
||||||
|
|
||||||
|
// 处理context covariate
|
||||||
|
int contextCovariate = covariatesAtOffset[ContextCovariate::index];
|
||||||
|
if (contextCovariate >= 0)
|
||||||
|
contextTable[readGroup][baseQuality][contextCovariate][event.index].increment(1, isError, baseQuality);
|
||||||
|
// RecalUtils::IncrementDatum4keys(nsgv::gRecalTables.contextTable, qual, isError, readGroup, baseQuality, contextCovariate,
|
||||||
|
// event.index);
|
||||||
|
// 处理cycle covariate
|
||||||
|
int cycleCovariate = covariatesAtOffset[CycleCovariate::index];
|
||||||
|
if (cycleCovariate >= 0)
|
||||||
|
cycleTable[readGroup][baseQuality][cycleCovariate][event.index].increment(1, isError, baseQuality);
|
||||||
|
// RecalUtils::IncrementDatum4keys(nsgv::gRecalTables.cycleTable, qual, isError, readGroup, baseQuality, cycleCovariate, event.index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 数据总结
|
||||||
|
void collapseQualityScoreTableToReadGroupTable(Array2D<RecalDatum> &byReadGroupTable, Array3D<RecalDatum> &byQualTable) {
|
||||||
|
// 遍历quality table
|
||||||
|
for (int k1 = 0; k1 < byQualTable.data.size(); ++k1) {
|
||||||
|
for (int k2 = 0; k2 < byQualTable[k1].size(); ++k2) {
|
||||||
|
for (int k3 = 0; k3 < byQualTable[k1][k2].size(); ++k3) {
|
||||||
|
auto& qualDatum = byQualTable[k1][k2][k3];
|
||||||
|
if (qualDatum.numObservations > 0) {
|
||||||
|
int rgKey = k1;
|
||||||
|
int eventIndex = k3;
|
||||||
|
// spdlog::info("k1 {}, k2 {}, k3 {}, numMis {}", k1, k2, k3, qualDatum.numMismatches);
|
||||||
|
byReadGroupTable[rgKey][eventIndex].combine(qualDatum);
|
||||||
|
// spdlog::info("rg {} {}, k3 {}", byReadGroupTable[rgKey][eventIndex].numMismatches, byReadGroupTable[rgKey][eventIndex].reportedQuality, k3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* To replicate the results of BQSR whether or not we save tables to disk (which we need in Spark),
|
||||||
|
* we need to trim the numbers to a few decimal placed (that's what writing and reading does).
|
||||||
|
*/
|
||||||
|
void roundTableValues(RecalTables& rt) {
|
||||||
|
#define _round_val(val) \
|
||||||
|
do { \
|
||||||
|
if (val.numObservations > 0) { \
|
||||||
|
val.numMismatches = MathUtils::RoundToNDecimalPlaces(val.numMismatches, RecalUtils::NUMBER_ERRORS_DECIMAL_PLACES); \
|
||||||
|
val.reportedQuality = MathUtils::RoundToNDecimalPlaces(val.reportedQuality, RecalUtils::REPORTED_QUALITY_DECIMAL_PLACES); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
_Foreach2D(rt.readGroupTable, val, { _round_val(val); });
|
||||||
|
_Foreach3D(rt.qualityScoreTable, val, { _round_val(val); });
|
||||||
|
_Foreach4D(rt.contextTable, val, { _round_val(val); });
|
||||||
|
_Foreach4D(rt.cycleTable, val, { _round_val(val); });
|
||||||
|
}
|
||||||
|
|
||||||
// 串行bqsr
|
// 串行bqsr
|
||||||
int SerialBQSR() {
|
int SerialBQSR() {
|
||||||
int round = 0;
|
int round = 0;
|
||||||
|
|
@ -601,6 +700,9 @@ int SerialBQSR() {
|
||||||
int64_t readNumSum = 0;
|
int64_t readNumSum = 0;
|
||||||
// 0. 初始化一些全局数据
|
// 0. 初始化一些全局数据
|
||||||
// BAQ baq{BAQ::DEFAULT_GOP};
|
// BAQ baq{BAQ::DEFAULT_GOP};
|
||||||
|
RecalDatum::StaticInit();
|
||||||
|
QualityUtils::StaticInit();
|
||||||
|
MathUtils::StaticInit();
|
||||||
|
|
||||||
// 1. 协变量数据相关初始化
|
// 1. 协变量数据相关初始化
|
||||||
PerReadCovariateMatrix readCovariates;
|
PerReadCovariateMatrix readCovariates;
|
||||||
|
|
@ -608,6 +710,14 @@ int SerialBQSR() {
|
||||||
ContextCovariate::InitContextCovariate(nsgv::gBqsrArg);
|
ContextCovariate::InitContextCovariate(nsgv::gBqsrArg);
|
||||||
CycleCovariate::InitCycleCovariate(nsgv::gBqsrArg);
|
CycleCovariate::InitCycleCovariate(nsgv::gBqsrArg);
|
||||||
|
|
||||||
|
// 注意初始化顺序,这个必须在协变量初始化之后
|
||||||
|
nsgv::gRecalTables.init(nsgv::gInBamHeader->hrecs->nrg);
|
||||||
|
nsgv::gEventTypes.push_back(EventType::BASE_SUBSTITUTION);
|
||||||
|
if (nsgv::gBqsrArg.computeIndelBQSRTables) {
|
||||||
|
nsgv::gEventTypes.push_back(EventType::BASE_INSERTION);
|
||||||
|
nsgv::gEventTypes.push_back(EventType::BASE_DELETION);
|
||||||
|
}
|
||||||
|
|
||||||
// 2. 读取bam的read group
|
// 2. 读取bam的read group
|
||||||
if (nsgv::gInBamHeader->hrecs->nrg == 0) {
|
if (nsgv::gInBamHeader->hrecs->nrg == 0) {
|
||||||
spdlog::error("No RG tag found in the header!");
|
spdlog::error("No RG tag found in the header!");
|
||||||
|
|
@ -641,8 +751,10 @@ int SerialBQSR() {
|
||||||
// 2. 对质量分数长度跟碱基长度不匹配的read,缺少的质量分数用默认值补齐,先忽略,后边有需要再处理
|
// 2. 对质量分数长度跟碱基长度不匹配的read,缺少的质量分数用默认值补齐,先忽略,后边有需要再处理
|
||||||
// 3. 如果bam文件之前做过bqsr,tag中包含OQ(originnal quality,原始质量分数),检查用户参数里是否指定用原始质量分数进行bqsr,如果是则将质量分数替换为OQ,否则忽略OQ,先忽略
|
// 3. 如果bam文件之前做过bqsr,tag中包含OQ(originnal quality,原始质量分数),检查用户参数里是否指定用原始质量分数进行bqsr,如果是则将质量分数替换为OQ,否则忽略OQ,先忽略
|
||||||
// 4. 对read的两端进行检测,去除(hardclip)adapter
|
// 4. 对read的两端进行检测,去除(hardclip)adapter
|
||||||
BamWrap *bw = bams[i];
|
// spdlog::info("bam idx: {}", i);
|
||||||
|
BamWrap* bw = bams[i];
|
||||||
SamData ad;
|
SamData ad;
|
||||||
|
ad.bw = bw;
|
||||||
ad.read_len = BamWrap::BamEffectiveLength(bw->b);
|
ad.read_len = BamWrap::BamEffectiveLength(bw->b);
|
||||||
ad.cigar_end = bw->b->core.n_cigar;
|
ad.cigar_end = bw->b->core.n_cigar;
|
||||||
if (ad.read_len <= 0) continue;
|
if (ad.read_len <= 0) continue;
|
||||||
|
|
@ -676,12 +788,16 @@ int SerialBQSR() {
|
||||||
vector<int> isIns(ad.read_len, 0); // 该位置是否是插入位置,0不是,1是
|
vector<int> isIns(ad.read_len, 0); // 该位置是否是插入位置,0不是,1是
|
||||||
vector<int> isDel(ad.read_len, 0); // 该位置是否是删除位置,0不是,1是
|
vector<int> isDel(ad.read_len, 0); // 该位置是否是删除位置,0不是,1是
|
||||||
const int nErrors = calculateIsSNPOrIndel(nsgv::gAuxVars[0], bw, ad, isSNP, isIns, isDel);
|
const int nErrors = calculateIsSNPOrIndel(nsgv::gAuxVars[0], bw, ad, isSNP, isIns, isDel);
|
||||||
|
// spdlog::info("nErrors: {}", nErrors);
|
||||||
|
// for (auto val : isSNP) { if (val > 0) spdlog::info("snp val: {}", val); }
|
||||||
|
|
||||||
|
//exit(0);
|
||||||
|
|
||||||
// 7. 计算baqArray
|
// 7. 计算baqArray
|
||||||
// BAQ = base alignment quality
|
// BAQ = base alignment quality
|
||||||
// note for efficiency reasons we don't compute the BAQ array unless we actually have
|
// note for efficiency reasons we don't compute the BAQ array unless we actually have
|
||||||
// some error to marginalize over. For ILMN data ~85% of reads have no error
|
// some error to marginalize over. For ILMN data ~85% of reads have no error
|
||||||
vector<int> baqArray;
|
vector<uint8_t> baqArray;
|
||||||
bool baqCalculated = false;
|
bool baqCalculated = false;
|
||||||
if (nErrors == 0 || !nsgv::gBqsrArg.enableBAQ) {
|
if (nErrors == 0 || !nsgv::gBqsrArg.enableBAQ) {
|
||||||
baqCalculated = flatBAQArray(bw, ad, baqArray);
|
baqCalculated = flatBAQArray(bw, ad, baqArray);
|
||||||
|
|
@ -699,28 +815,49 @@ int SerialBQSR() {
|
||||||
int end_pos = bw->contig_end_pos();
|
int end_pos = bw->contig_end_pos();
|
||||||
//spdlog::info("adapter: {}, read: {}, {}, strand: {}", adapter_boundary, bw->contig_pos(), end_pos,
|
//spdlog::info("adapter: {}, read: {}, {}, strand: {}", adapter_boundary, bw->contig_pos(), end_pos,
|
||||||
// bw->GetReadNegativeStrandFlag() ? "reverse" : "forward");
|
// bw->GetReadNegativeStrandFlag() ? "reverse" : "forward");
|
||||||
|
// for (auto val : isSNP) { if (val > 0) spdlog::info("snp err val-1: {}", val); }
|
||||||
// 9. 计算这条read需要跳过的位置
|
// 9. 计算这条read需要跳过的位置
|
||||||
vector<bool> skip(ad.read_len, 0);
|
vector<bool> skips(ad.read_len, 0);
|
||||||
PROF_START(known_sites);
|
PROF_START(known_sites);
|
||||||
calculateKnownSites(bw, ad, nsgv::gAuxVars[0].vcfArr, skip);
|
calculateKnownSites(bw, ad, nsgv::gAuxVars[0].vcfArr, skips);
|
||||||
for (int ii = 0; ii < ad.read_len; ++ii) {
|
for (int ii = 0; ii < ad.read_len; ++ii) {
|
||||||
skip[ii] =
|
skips[ii] = skips[ii] || (ContextCovariate::baseIndexMap[ad.bases[ii]] == -1) ||
|
||||||
skip[ii] || (ContextCovariate::baseIndexMap[ad.bases[ii]] == -1) || ad.quals[ii] < nsgv::gBqsrArg.PRESERVE_QSCORES_LESS_THAN;
|
ad.base_quals[ii] < nsgv::gBqsrArg.PRESERVE_QSCORES_LESS_THAN;
|
||||||
}
|
}
|
||||||
PROF_END(gprof[GP_read_vcf], known_sites);
|
PROF_END(gprof[GP_read_vcf], known_sites);
|
||||||
|
|
||||||
// 10. 根据BAQ进一步处理snp,indel,得到处理后的数据
|
// 10. 根据BAQ进一步处理snp,indel,得到处理后的数据
|
||||||
vector<double> snpErrors, insErrors, delErrors;
|
vector<double> snpErrors, insErrors, delErrors;
|
||||||
|
// for (auto val : isSNP) { if (val > 0) spdlog::info("snp err val-2: {}", val); }
|
||||||
calculateFractionalErrorArray(isSNP, baqArray, snpErrors);
|
calculateFractionalErrorArray(isSNP, baqArray, snpErrors);
|
||||||
calculateFractionalErrorArray(isIns, baqArray, insErrors);
|
calculateFractionalErrorArray(isIns, baqArray, insErrors);
|
||||||
calculateFractionalErrorArray(isDel, baqArray, delErrors);
|
calculateFractionalErrorArray(isDel, baqArray, delErrors);
|
||||||
|
|
||||||
|
// for (auto val : isSNP) { if (val > 0) spdlog::info("snp val: {}", val); }
|
||||||
|
//spdlog::info("snp errors size: {}, read len: {}", snpErrors.size(), ad.read_len);
|
||||||
|
//for (auto val : snpErrors) { if (val > 0) spdlog::info("snp err val: {}", val); }
|
||||||
|
|
||||||
// aggregate all of the info into our info object, and update the data
|
// aggregate all of the info into our info object, and update the data
|
||||||
// 11. 合并之前计算的数据,得到info,并更新bqsr table数据
|
// 11. 合并之前计算的数据,得到info,并更新bqsr table数据
|
||||||
|
ReadRecalInfo info(ad, readCovariates, skips, snpErrors, insErrors, delErrors);
|
||||||
|
int m = 0;
|
||||||
|
// for (auto err : snpErrors) { if (isSNP[m] > 0 || err > 0) spdlog::info("snp err: {} : {}", isSNP[m++], err); }
|
||||||
|
//exit(0);
|
||||||
|
PROF_START(update_info);
|
||||||
|
updateRecalTablesForRead(info);
|
||||||
|
PROF_END(gprof[GP_update_info], update_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// exit(0);
|
||||||
|
// 12. 创建总结数据
|
||||||
|
collapseQualityScoreTableToReadGroupTable(nsgv::gRecalTables.readGroupTable, nsgv::gRecalTables.qualityScoreTable);
|
||||||
|
roundTableValues(nsgv::gRecalTables);
|
||||||
|
|
||||||
|
// 13. 量化质量分数
|
||||||
|
QuantizationInfo quantInfo(nsgv::gRecalTables, nsgv::gBqsrArg.QUANTIZING_LEVELS);
|
||||||
|
|
||||||
|
// 14. 输出结果
|
||||||
|
RecalUtils::outputRecalibrationReport(nsgv::gBqsrArg, quantInfo, nsgv::gRecalTables);
|
||||||
#if 0
|
#if 0
|
||||||
// spdlog::info("region: {} - {}", bams[0]->global_softclip_start(), bams.back()->global_softclip_end());
|
// spdlog::info("region: {} - {}", bams[0]->global_softclip_start(), bams.back()->global_softclip_end());
|
||||||
// 1. 获取bams数组覆盖的region范围
|
// 1. 获取bams数组覆盖的region范围
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
EventTypeValue EventType::BASE_SUBSTITUTION = {0, 'M', "Base Substitution"};
|
EventTypeValue EventType::BASE_SUBSTITUTION = {0, 'M', "Base Substitution"};
|
||||||
EventTypeValue EventType::BASE_INSERTION = {1, 'I', "Base Insertion"};
|
EventTypeValue EventType::BASE_INSERTION = {1, 'I', "Base Insertion"};
|
||||||
EventTypeValue EventType::BASE_DELETION = {2, 'D', "Base Deletion"};
|
EventTypeValue EventType::BASE_DELETION = {2, 'D', "Base Deletion"};
|
||||||
|
vector<EventTypeValue> EventType::EVENTS = {BASE_SUBSTITUTION, BASE_INSERTION, BASE_DELETION};
|
||||||
|
|
||||||
// static变量 for ContextCovariate
|
// static变量 for ContextCovariate
|
||||||
int ContextCovariate::mismatchesContextSize;
|
int ContextCovariate::mismatchesContextSize;
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ using std::string;
|
||||||
using std::vector;
|
using std::vector;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is where we store the pre-read covariates, also indexed by (event type) and (read position).
|
* This is where we store the per-read covariates, also indexed by (event type) and (read position).
|
||||||
* Thus the array has shape { event type } x { read position (aka cycle) } x { covariate }.
|
* Thus the array has shape { event type } x { read position (aka cycle) } x { covariate }.
|
||||||
* For instance, { covariate } is by default 4-dimensional (read group, base quality, context, cycle).
|
* For instance, { covariate } is by default 4-dimensional (read group, base quality, context, cycle).
|
||||||
*/
|
*/
|
||||||
|
|
@ -36,6 +36,7 @@ struct EventTypeValue {
|
||||||
int index; // 在协变量数组中对应的索引
|
int index; // 在协变量数组中对应的索引
|
||||||
char representation;
|
char representation;
|
||||||
string longRepresentation;
|
string longRepresentation;
|
||||||
|
bool operator==(const EventTypeValue& a) const { return a.index == index; }
|
||||||
};
|
};
|
||||||
|
|
||||||
struct EventType {
|
struct EventType {
|
||||||
|
|
@ -43,6 +44,7 @@ struct EventType {
|
||||||
static EventTypeValue BASE_SUBSTITUTION;
|
static EventTypeValue BASE_SUBSTITUTION;
|
||||||
static EventTypeValue BASE_INSERTION;
|
static EventTypeValue BASE_INSERTION;
|
||||||
static EventTypeValue BASE_DELETION;
|
static EventTypeValue BASE_DELETION;
|
||||||
|
static vector<EventTypeValue> EVENTS;
|
||||||
};
|
};
|
||||||
|
|
||||||
// 协变量相关的工具类
|
// 协变量相关的工具类
|
||||||
|
|
@ -140,6 +142,17 @@ struct ContextCovariate {
|
||||||
baseIndexMap['t'] = 3;
|
baseIndexMap['t'] = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int MaximumKeyValue() {
|
||||||
|
int length = max(mismatchesContextSize, indelsContextSize);
|
||||||
|
int key = length;
|
||||||
|
int bitOffset = LENGTH_BITS;
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
key |= (3 << bitOffset);
|
||||||
|
bitOffset += 2;
|
||||||
|
}
|
||||||
|
return key;
|
||||||
|
}
|
||||||
|
|
||||||
static int CreateMask(int contextSize) {
|
static int CreateMask(int contextSize) {
|
||||||
int mask = 0;
|
int mask = 0;
|
||||||
// create 2*contextSize worth of bits
|
// create 2*contextSize worth of bits
|
||||||
|
|
@ -161,6 +174,41 @@ struct ContextCovariate {
|
||||||
return isNegativeStrand ? (readLength - offset - 1) : offset;
|
return isNegativeStrand ? (readLength - offset - 1) : offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static char baseIndexToSimpleBase(const int baseIndex) {
|
||||||
|
switch (baseIndex) {
|
||||||
|
case 0:
|
||||||
|
return 'A';
|
||||||
|
case 1:
|
||||||
|
return 'C';
|
||||||
|
case 2:
|
||||||
|
return 'G';
|
||||||
|
case 3:
|
||||||
|
return 'T';
|
||||||
|
default:
|
||||||
|
return '.';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Converts a key into the dna string representation.
|
||||||
|
*
|
||||||
|
* @param key the key representing the dna sequence
|
||||||
|
* @return the dna sequence represented by the key
|
||||||
|
*/
|
||||||
|
static string ContextFromKey(const int key) {
|
||||||
|
int length = key & LENGTH_MASK; // the first bits represent the length (in bp) of the context
|
||||||
|
int mask = 48; // use the mask to pull out bases
|
||||||
|
int offset = LENGTH_BITS;
|
||||||
|
|
||||||
|
string dna;
|
||||||
|
for (int i = 0; i < length; i++) {
|
||||||
|
int baseIndex = (key & mask) >> offset;
|
||||||
|
dna.push_back(baseIndexToSimpleBase(baseIndex));
|
||||||
|
mask <<= 2; // move the mask over to the next 2 bits
|
||||||
|
offset += 2;
|
||||||
|
}
|
||||||
|
return dna;
|
||||||
|
}
|
||||||
|
|
||||||
// 获取去除低质量分数碱基之后的read碱基序列(将低质量分数的碱基变成N)
|
// 获取去除低质量分数碱基之后的read碱基序列(将低质量分数的碱基变成N)
|
||||||
static void GetStrandedClippedBytes(BamWrap* bw, SamData& ad, string& clippedBases, uint8_t lowQTail);
|
static void GetStrandedClippedBytes(BamWrap* bw, SamData& ad, string& clippedBases, uint8_t lowQTail);
|
||||||
// Creates a int representation of a given dna string.
|
// Creates a int representation of a given dna string.
|
||||||
|
|
@ -180,6 +228,8 @@ struct CycleCovariate {
|
||||||
|
|
||||||
static void InitCycleCovariate(BQSRArg& p) { MAXIMUM_CYCLE_VALUE = p.MAXIMUM_CYCLE_VALUE; }
|
static void InitCycleCovariate(BQSRArg& p) { MAXIMUM_CYCLE_VALUE = p.MAXIMUM_CYCLE_VALUE; }
|
||||||
|
|
||||||
|
static int MaximumKeyValue() { return (MAXIMUM_CYCLE_VALUE << 1) + 1; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Encodes the cycle number as a key.
|
* Encodes the cycle number as a key.
|
||||||
*/
|
*/
|
||||||
|
|
@ -200,10 +250,27 @@ struct CycleCovariate {
|
||||||
result++; // negative cycles get the lower-most bit set
|
result++; // negative cycles get the lower-most bit set
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Decodes the cycle number from the key.
|
||||||
|
*/
|
||||||
|
static int CycleFromKey(const int key) {
|
||||||
|
int cycle = key >> 1; // shift so we can remove the "sign" bit
|
||||||
|
if ((key & 1) != 0) { // is the last bit set?
|
||||||
|
cycle *= -1; // then the cycle is negative
|
||||||
|
}
|
||||||
|
return cycle;
|
||||||
|
}
|
||||||
|
|
||||||
// Computes the encoded value of CycleCovariate's key for the given position at the read.
|
// Computes the encoded value of CycleCovariate's key for the given position at the read.
|
||||||
static int CycleKey(BamWrap* bw, SamData& ad, const int baseNumber, const bool indel, const int maxCycle);
|
static int CycleKey(BamWrap* bw, SamData& ad, const int baseNumber, const bool indel, const int maxCycle);
|
||||||
|
|
||||||
static void RecordValues(BamWrap* bw, SamData& ad, sam_hdr_t* header, PerReadCovariateMatrix& values, bool recordIndelValues);
|
static void RecordValues(BamWrap* bw, SamData& ad, sam_hdr_t* header, PerReadCovariateMatrix& values, bool recordIndelValues);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 好像不需要
|
||||||
|
struct StandardCovariateList {
|
||||||
|
ReadGroupCovariate readGroupCovariate;
|
||||||
|
BaseQualityCovariate qualityScoreCovariate;
|
||||||
};
|
};
|
||||||
|
|
@ -0,0 +1,194 @@
|
||||||
|
/*
|
||||||
|
Description: 多维度嵌套的数组
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/24
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <assert.h>
|
||||||
|
#include "spdlog/spdlog.h"
|
||||||
|
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
struct Array2D {
|
||||||
|
vector<vector<T>> data;
|
||||||
|
Array2D() { }
|
||||||
|
Array2D(int dim1, int dim2) { init(dim1, dim2); }
|
||||||
|
void init(int dim1, int dim2) { data.resize(dim1); for (auto& v : data) v.resize(dim2); }
|
||||||
|
inline T& get(int k1, int k2) { return data[k1][k2]; }
|
||||||
|
// 根据关键字,在对应位置插入数据
|
||||||
|
inline void put(const T& value, int k1, int k2) { data[k1][k2] = value; }
|
||||||
|
inline vector<T>& operator[](size_t idx) { return data[idx]; }
|
||||||
|
inline const vector<T>& operator[](size_t idx) const { return data[idx]; }
|
||||||
|
#define _Foreach2D(array, valName, codes) \
|
||||||
|
for (auto& arr1 : array.data) { \
|
||||||
|
for (auto& valName : arr1) { \
|
||||||
|
codes; \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
#define _Foreach2DK(array, valName, codes) \
|
||||||
|
do { \
|
||||||
|
int k1 = 0; \
|
||||||
|
for (auto& arr1 : array.data) { \
|
||||||
|
int k2 = 0; \
|
||||||
|
for (auto& valName : arr1) { \
|
||||||
|
codes; \
|
||||||
|
++k2; \
|
||||||
|
} \
|
||||||
|
++k1; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Dense 3-D table backed by nested std::vector.
template <class T>
struct Array3D {
    std::vector<std::vector<std::vector<T>>> data;

    Array3D() {}
    Array3D(int dim1, int dim2, int dim3) { init(dim1, dim2, dim3); }

    // Resize to dim1 x dim2 x dim3; elements inside the new bounds are kept,
    // new elements are value-initialized.
    void init(int dim1, int dim2, int dim3) {
        data.resize(dim1);
        for (auto& plane : data) {
            plane.resize(dim2);
            for (auto& row : plane) {
                row.resize(dim3);
            }
        }
    }

    // Reference to the element at (k1, k2, k3).
    inline T& get(int k1, int k2, int k3) { return data[k1][k2][k3]; }

    // Store value at position (k1, k2, k3).
    inline void put(const T& value, int k1, int k2, int k3) { data[k1][k2][k3] = value; }

    inline std::vector<std::vector<T>>& operator[](size_t idx) { return data[idx]; }

// Visit every element; valName is bound to each element in turn.
#define _Foreach3D(array, valName, codes)  \
    for (auto& arr1 : array.data) {        \
        for (auto& arr2 : arr1) {          \
            for (auto& valName : arr2) {   \
                codes;                     \
            }                              \
        }                                  \
    }

// As _Foreach3D, but k1/k2/k3 hold the indices of the current element.
#define _Foreach3DK(array, valName, codes)   \
    do {                                     \
        int k1 = 0;                          \
        for (auto& arr1 : array.data) {      \
            int k2 = 0;                      \
            for (auto& arr2 : arr1) {        \
                int k3 = 0;                  \
                for (auto& valName : arr2) { \
                    codes;                   \
                    ++k3;                    \
                }                            \
                ++k2;                        \
            }                                \
            ++k1;                            \
        }                                    \
    } while (0)
};
|
||||||
|
|
||||||
|
// Dense 4-D table backed by nested std::vector.
template <class T>
struct Array4D {
    std::vector<std::vector<std::vector<std::vector<T>>>> data;

    Array4D() {}
    Array4D(int dim1, int dim2, int dim3, int dim4) { init(dim1, dim2, dim3, dim4); }

    // Resize to dim1 x dim2 x dim3 x dim4; elements inside the new bounds are
    // kept, new elements are value-initialized.
    void init(int dim1, int dim2, int dim3, int dim4) {
        data.resize(dim1);
        for (auto& cube : data) {
            cube.resize(dim2);
            for (auto& plane : cube) {
                plane.resize(dim3);
                for (auto& row : plane) {
                    row.resize(dim4);
                }
            }
        }
    }

    // Reference to the element at (k1, k2, k3, k4).
    inline T& get(int k1, int k2, int k3, int k4) { return data[k1][k2][k3][k4]; }

    // Store value at position (k1, k2, k3, k4).
    inline void put(const T& value, int k1, int k2, int k3, int k4) { data[k1][k2][k3][k4] = value; }

    inline std::vector<std::vector<std::vector<T>>>& operator[](size_t idx) { return data[idx]; }

// Visit every element; valName is bound to each element in turn.
#define _Foreach4D(array, valName, codes)    \
    for (auto& arr1 : array.data) {          \
        for (auto& arr2 : arr1) {            \
            for (auto& arr3 : arr2) {        \
                for (auto& valName : arr3) { \
                    codes;                   \
                }                            \
            }                                \
        }                                    \
    }

// As _Foreach4D, but k1/k2/k3/k4 hold the indices of the current element.
#define _Foreach4DK(array, valName, codes)       \
    do {                                         \
        int k1 = 0;                              \
        for (auto& arr1 : array.data) {          \
            int k2 = 0;                          \
            for (auto& arr2 : arr1) {            \
                int k3 = 0;                      \
                for (auto& arr3 : arr2) {        \
                    int k4 = 0;                  \
                    for (auto& valName : arr3) { \
                        codes;                   \
                        ++k4;                    \
                    }                            \
                    ++k3;                        \
                }                                \
                ++k2;                            \
            }                                    \
            ++k1;                                \
        }                                        \
    } while (0)
};
|
||||||
|
|
||||||
|
// Tensor-like flat N-dimensional array: all elements live in one contiguous
// vector and (k1, ..., kN) coordinates map to a flat index via precomputed
// row-major strides.
template <class T>
struct NestedArray {
    std::vector<T> data;          // flat row-major storage
    std::vector<int> dimensions;  // extent of each axis
    std::vector<int> dim_offset;  // stride of each axis (last axis has stride 1)

    NestedArray() {}

    template <typename... Args>
    NestedArray(Args... dims) {
        init(dims...);
    }

    // (Re)shape the array to the given extents; all elements are value-initialized.
    // FIX: the original appended to `dimensions` on every call, so calling init()
    // more than once (e.g. reusing an instance) corrupted the shape and strides.
    // State is now reset before the new shape is recorded.
    template <typename... Args>
    void init(Args... dims) {
        dimensions.clear();
        dim_offset.clear();
        (dimensions.emplace_back(dims), ...);
        assert(!dimensions.empty());
        dim_offset.resize(dimensions.size(), 1);
        // Row-major strides: stride of axis i is the product of all later extents.
        for (int i = (int)dimensions.size() - 2; i >= 0; --i) {
            dim_offset[i] = dim_offset[i + 1] * dimensions[i + 1];
        }
        data.assign((size_t)dimensions[0] * dim_offset[0], T());
    }

    // Map an N-dimensional coordinate to its flat row-major index.
    // The number of keys must match the number of dimensions.
    template <typename... Args>
    size_t flatIndex(Args... keys) const {
        assert(sizeof...(keys) == dimensions.size());
        size_t idx = 0;
        size_t axis = 0;
        ((idx += (size_t)keys * dim_offset[axis++]), ...);
        return idx;
    }

    // Reference to the element at the given coordinates.
    template <typename... Args>
    T& get(Args... keys) {
        return data[flatIndex(keys...)];
    }

    // Store value at the given coordinates.
    template <typename... Args>
    void put(T value, Args... keys) {
        data[flatIndex(keys...)] = value;
    }
};
|
||||||
|
|
@ -0,0 +1,315 @@
|
||||||
|
/*
|
||||||
|
Description: 量化质量分数
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/26
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <set>
#include <string>
#include <vector>

#include "qual_utils.h"

using std::set;
using std::string;
using std::vector;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A general algorithm for quantizing quality score distributions to use a specific number of levels
|
||||||
|
*
|
||||||
|
* Takes a histogram of quality scores and a desired number of levels and produces a
|
||||||
|
* map from original quality scores -> quantized quality scores.
|
||||||
|
*
|
||||||
|
* Note that this data structure is fairly heavy-weight, holding lots of debugging and
|
||||||
|
* calculation information. If you want to use it efficiently at scale with lots of
|
||||||
|
* read groups the right way to do this:
|
||||||
|
*
|
||||||
|
* Map<ReadGroup, List<Byte>> map
|
||||||
|
* for each read group rg:
|
||||||
|
* hist = getQualHist(rg)
|
||||||
|
* QualQuantizer qq = new QualQuantizer(hist, nLevels, minInterestingQual)
|
||||||
|
* map.set(rg, qq.getOriginalToQuantizedMap())
|
||||||
|
*
|
||||||
|
* This map would then be used to look up the appropriate original -> quantized
|
||||||
|
* quals for each read as it comes in.
|
||||||
|
*/
|
||||||
|
|
||||||
|
struct QualQuantizer {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a contiguous interval of quality scores.
|
||||||
|
*
|
||||||
|
* qStart and qEnd are inclusive, so qStart = qEnd = 2 is the quality score bin of 2
|
||||||
|
*/
|
||||||
|
struct QualInterval {
|
||||||
|
int qStart, qEnd, fixedQual, level;
|
||||||
|
int64_t nObservations, nErrors;
|
||||||
|
set<QualInterval> subIntervals;
|
||||||
|
|
||||||
|
/** for debugging / visualization. When was this interval created? */
|
||||||
|
int mergeOrder;
|
||||||
|
|
||||||
|
void init(const int _qStart, const int _qEnd, const int64_t _nObservations, const int64_t _nErrors, const int _level, const int _fixedQual) {
|
||||||
|
qStart = _qStart;
|
||||||
|
qEnd = _qEnd;
|
||||||
|
nObservations = _nObservations;
|
||||||
|
nErrors = _nErrors;
|
||||||
|
level = _level;
|
||||||
|
fixedQual = _fixedQual;
|
||||||
|
}
|
||||||
|
|
||||||
|
QualInterval() {
|
||||||
|
qStart = -1;
|
||||||
|
qEnd = -1;
|
||||||
|
nObservations = -1;
|
||||||
|
nErrors = -1;
|
||||||
|
fixedQual = -1;
|
||||||
|
level = -1;
|
||||||
|
mergeOrder = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level) {
|
||||||
|
init(qStart, qEnd, nObservations, nErrors, level, -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level,
|
||||||
|
const set<QualInterval>& _subIntervals) {
|
||||||
|
init(qStart, qEnd, nObservations, nErrors, level, -1);
|
||||||
|
subIntervals = _subIntervals;
|
||||||
|
}
|
||||||
|
|
||||||
|
QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level, const int fixedQual) {
|
||||||
|
init(qStart, qEnd, nObservations, nErrors, level, fixedQual);
|
||||||
|
}
|
||||||
|
|
||||||
|
QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level, const int fixedQual,
|
||||||
|
const set<QualInterval>& _subIntervals) {
|
||||||
|
init(qStart, qEnd, nObservations, nErrors, level, fixedQual);
|
||||||
|
subIntervals = _subIntervals;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Human readable name of this interval: e.g., 10-12
|
||||||
|
*/
|
||||||
|
string getName() const { return std::to_string(qStart) + "-" + std::to_string(qEnd); }
|
||||||
|
|
||||||
|
string toString() const {
|
||||||
|
return "QQ:" + getName();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return true if this bin is using a fixed qual
|
||||||
|
*/
|
||||||
|
bool hasFixedQual() const { return fixedQual != -1; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return the error rate (in real space) of this interval, or 0 if there are no observations
|
||||||
|
*/
|
||||||
|
double getErrorRate() const {
|
||||||
|
if (hasFixedQual())
|
||||||
|
return QualityUtils::qualToErrorProb((uint8_t)fixedQual);
|
||||||
|
else if (nObservations == 0)
|
||||||
|
return 0.0;
|
||||||
|
else
|
||||||
|
return (nErrors + 1) / (1.0 * (nObservations + 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return the QUAL of the error rate of this interval, or the fixed qual if this interval was created with a fixed qual.
|
||||||
|
*/
|
||||||
|
uint8_t getQual() const {
|
||||||
|
if (!hasFixedQual())
|
||||||
|
return QualityUtils::errorProbToQual(getErrorRate());
|
||||||
|
else
|
||||||
|
return (uint8_t)fixedQual;
|
||||||
|
}
|
||||||
|
|
||||||
|
int compareTo(const QualInterval& qi) const { return qStart < qi.qStart ? -1 : (qStart == qi.qStart ? 0 : 1); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a interval representing the merge of this interval and toMerge
|
||||||
|
*
|
||||||
|
* Errors and observations are combined
|
||||||
|
* Subintervals updated in order of left to right (determined by qStart)
|
||||||
|
* Level is 1 + highest level of this and toMerge
|
||||||
|
* Order must be updated elsewhere
|
||||||
|
*
|
||||||
|
* @param toMerge
|
||||||
|
* @return newly created merged QualInterval
|
||||||
|
*/
|
||||||
|
QualInterval merge(const QualInterval& toMerge) const {
|
||||||
|
const QualInterval &left = this->compareTo(toMerge) < 0 ? *this : toMerge;
|
||||||
|
const QualInterval &right = this->compareTo(toMerge) < 0 ? toMerge : *this;
|
||||||
|
|
||||||
|
if (left.qEnd + 1 != right.qStart) {
|
||||||
|
// throw new GATKException("Attempting to merge non-contiguous intervals: left = " + left + " right = " + right);
|
||||||
|
std::cerr << "Attempting to merge non-contiguous intervals: left = " + left.toString() + " right = " + right.toString() << std::endl;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
const int64_t nCombinedObs = left.nObservations + right.nObservations;
|
||||||
|
const int64_t nCombinedErr = left.nErrors + right.nErrors;
|
||||||
|
|
||||||
|
const int level = std::max(left.level, right.level) + 1;
|
||||||
|
set<QualInterval> subIntervals;
|
||||||
|
subIntervals.insert(left);
|
||||||
|
subIntervals.insert(right);
|
||||||
|
QualInterval merged(left.qStart, right.qEnd, nCombinedObs, nCombinedErr, level, subIntervals);
|
||||||
|
|
||||||
|
return merged;
|
||||||
|
}
|
||||||
|
|
||||||
|
double getPenalty(const int minInterestingQual) const { return calcPenalty(getErrorRate(), minInterestingQual); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the penalty of this interval, given the overall error rate for the interval
|
||||||
|
*
|
||||||
|
* If the globalErrorRate is e, this value is:
|
||||||
|
*
|
||||||
|
* sum_i |log10(e_i) - log10(e)| * nObservations_i
|
||||||
|
*
|
||||||
|
* each the index i applies to all leaves of the tree accessible from this interval
|
||||||
|
* (found recursively from subIntervals as necessary)
|
||||||
|
*
|
||||||
|
* @param globalErrorRate overall error rate in real space against which we calculate the penalty
|
||||||
|
* @return the cost of approximating the bins in this interval with the globalErrorRate
|
||||||
|
*/
|
||||||
|
double calcPenalty(const double globalErrorRate, const int minInterestingQual) const {
|
||||||
|
if (globalErrorRate == 0.0) // there were no observations, so there's no penalty
|
||||||
|
return 0.0;
|
||||||
|
|
||||||
|
if (subIntervals.empty()) {
|
||||||
|
// this is leave node
|
||||||
|
if (this->qEnd <= minInterestingQual)
|
||||||
|
// It's free to merge up quality scores below the smallest interesting one
|
||||||
|
return 0;
|
||||||
|
else {
|
||||||
|
return (std::abs(std::log10(getErrorRate()) - std::log10(globalErrorRate))) * nObservations;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
double sum = 0;
|
||||||
|
for (const QualInterval interval : subIntervals) sum += interval.calcPenalty(globalErrorRate, minInterestingQual);
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool operator<(const QualInterval& o) const {
|
||||||
|
return qStart < o.qStart;
|
||||||
|
}
|
||||||
|
|
||||||
|
QualInterval& operator=(const QualInterval& o) {
|
||||||
|
if (this == &o) return *this;
|
||||||
|
init(o.qStart, o.qEnd, o.nObservations, o.nErrors, o.level, o.fixedQual);
|
||||||
|
mergeOrder = o.mergeOrder;
|
||||||
|
subIntervals.clear();
|
||||||
|
for (auto& val : o.subIntervals) subIntervals.insert(val);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inputs to the QualQuantizer
|
||||||
|
*/
|
||||||
|
const int nLevels, minInterestingQual;
|
||||||
|
vector<int64_t>& nObservationsPerQual;
|
||||||
|
|
||||||
|
QualQuantizer(vector<int64_t>& _nObservationsPerQual, const int _nLevels, const int _minInterestingQual)
|
||||||
|
: nObservationsPerQual(_nObservationsPerQual), nLevels(_nLevels), minInterestingQual(_minInterestingQual) {
|
||||||
|
quantize();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Sorted set of qual intervals.
|
||||||
|
*
|
||||||
|
* After quantize() this data structure contains only the top-level qual intervals
|
||||||
|
*/
|
||||||
|
set<QualInterval> quantizedIntervals;
|
||||||
|
/**
|
||||||
|
* Represents a contiguous interval of quality scores.
|
||||||
|
*
|
||||||
|
* qStart and qEnd are inclusive, so qStart = qEnd = 2 is the quality score bin of 2
|
||||||
|
*/
|
||||||
|
|
||||||
|
void getOriginalToQuantizedMap(vector<uint8_t>& quantMap) {
|
||||||
|
quantMap.resize(getNQualsInHistogram(), UINT8_MAX);
|
||||||
|
|
||||||
|
for (auto& interval : quantizedIntervals) {
|
||||||
|
for (int q = interval.qStart; q <= interval.qEnd; q++) {
|
||||||
|
quantMap[q] = interval.getQual();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if (Collections.min(map) == Byte.MIN_VALUE) throw new GATKException("quantized quality score map contains an un-initialized value");
|
||||||
|
}
|
||||||
|
|
||||||
|
int getNQualsInHistogram() { return nObservationsPerQual.size(); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Main method for computing the quantization intervals.
|
||||||
|
*
|
||||||
|
* Invoked in the constructor after all input variables are initialized. Walks
|
||||||
|
* over the inputs and builds the min. penalty forest of intervals with exactly nLevel
|
||||||
|
* root nodes. Finds this min. penalty forest via greedy search, so is not guarenteed
|
||||||
|
* to find the optimal combination.
|
||||||
|
*
|
||||||
|
* TODO: develop a smarter algorithm
|
||||||
|
*
|
||||||
|
* @return the forest of intervals with size == nLevels
|
||||||
|
*/
|
||||||
|
void quantize() {
|
||||||
|
// create intervals for each qual individually
|
||||||
|
auto& intervals = quantizedIntervals;
|
||||||
|
for (int qStart = 0; qStart < getNQualsInHistogram(); qStart++) {
|
||||||
|
const int64_t nObs = nObservationsPerQual.at(qStart);
|
||||||
|
const double errorRate = QualityUtils::qualToErrorProb((uint8_t)qStart);
|
||||||
|
const double nErrors = nObs * errorRate;
|
||||||
|
const QualInterval qi(qStart, qStart, nObs, (int)std::floor(nErrors), 0, (uint8_t)qStart);
|
||||||
|
intervals.insert(qi);
|
||||||
|
}
|
||||||
|
|
||||||
|
// greedy algorithm:
|
||||||
|
// while ( n intervals >= nLevels ):
|
||||||
|
// find intervals to merge with least penalty
|
||||||
|
// merge it
|
||||||
|
while (intervals.size() > nLevels) {
|
||||||
|
mergeLowestPenaltyIntervals(intervals);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper function that finds and merges together the lowest penalty pair of intervals
|
||||||
|
* @param intervals
|
||||||
|
*/
|
||||||
|
void mergeLowestPenaltyIntervals(set<QualInterval>& intervals) {
|
||||||
|
// setup the iterators
|
||||||
|
auto it1 = intervals.begin();
|
||||||
|
auto it1p = intervals.begin();
|
||||||
|
++it1p; // skip one
|
||||||
|
|
||||||
|
// walk over the pairs of left and right, keeping track of the pair with the lowest merge penalty
|
||||||
|
QualInterval minMerge;
|
||||||
|
// if (logger.isDebugEnabled()) logger.debug("mergeLowestPenaltyIntervals: " + intervals.size());
|
||||||
|
int lastMergeOrder = 0;
|
||||||
|
while (it1p != intervals.end()) {
|
||||||
|
const QualInterval& left = *it1;
|
||||||
|
const QualInterval& right = *it1p;
|
||||||
|
const QualInterval merged = left.merge(right);
|
||||||
|
lastMergeOrder = std::max(std::max(lastMergeOrder, left.mergeOrder), right.mergeOrder);
|
||||||
|
if (minMerge.qStart == -1 || (merged.getPenalty(minInterestingQual) < minMerge.getPenalty(minInterestingQual))) {
|
||||||
|
// if (logger.isDebugEnabled()) logger.debug(" Updating merge " + minMerge);
|
||||||
|
minMerge = merged; // merge two bins that when merged incur the lowest "penalty"
|
||||||
|
}
|
||||||
|
++it1;
|
||||||
|
++it1p;
|
||||||
|
}
|
||||||
|
// now actually go ahead and merge the minMerge pair
|
||||||
|
// if (logger.isDebugEnabled()) logger.debug(" => const min merge " + minMerge);
|
||||||
|
minMerge.mergeOrder = lastMergeOrder + 1;
|
||||||
|
// intervals.removeAll(minMerge.subIntervals);
|
||||||
|
for (auto &itr : minMerge.subIntervals) {
|
||||||
|
intervals.erase(itr);
|
||||||
|
}
|
||||||
|
intervals.insert(minMerge);
|
||||||
|
// if (logger.isDebugEnabled()) logger.debug("updated intervals: " + intervals);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
#include "qual_utils.h"
|
||||||
|
|
||||||
|
// Conversion factor from a phred-scaled quality to a natural-log error probability.
const double QualityUtils::PHRED_TO_LOG_PROB_MULTIPLIER = -std::log(10) / 10.0;

// Inverse of the above: natural-log probability -> phred scale.
const double QualityUtils::LOG_PROB_TO_PHRED_MULTIPLIER = 1 / PHRED_TO_LOG_PROB_MULTIPLIER;

// Smallest log10-scaled qual representable by a double.
// NOTE(review): DBL_MIN lives in <cfloat>, which is not included directly here
// (qual_utils.h pulls in <cmath>/<climits>) -- this relies on a transitive
// include; confirm and add #include <cfloat> if builds break.
const double QualityUtils::MIN_LOG10_SCALED_QUAL = std::log10(DBL_MIN);

// Phred-scaled counterpart of MIN_LOG10_SCALED_QUAL.
const double QualityUtils::MIN_PHRED_SCALED_QUAL = -10.0 * MIN_LOG10_SCALED_QUAL;

// Storage for the lookup tables declared in QualityUtils; filled by StaticInit().
double QualityUtils::qualToErrorProbCache[MAX_QUAL + 1];
double QualityUtils::qualToProbLog10Cache[MAX_QUAL + 1];
|
||||||
|
|
@ -0,0 +1,406 @@
|
||||||
|
/*
|
||||||
|
Description: 质量分数相关的工具函数
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/25
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <climits>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "util/math/math_utils.h"
|
||||||
|
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
struct QualityUtils {
|
||||||
|
|
||||||
|
static void StaticInit() {
|
||||||
|
for (int i = 0; i <= MAX_QUAL; i++) {
|
||||||
|
qualToErrorProbCache[i] = qualToErrorProb((double)i);
|
||||||
|
qualToProbLog10Cache[i] = std::log10(1.0 - qualToErrorProbCache[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
    /**
     * Maximum quality score that can be encoded in a SAM/BAM file
     */
    static constexpr uint8_t MAX_SAM_QUAL_SCORE = 93;

    /**
     * bams containing quals above this value are extremely suspicious and we should warn the user
     */
    static constexpr uint8_t MAX_REASONABLE_Q_SCORE = 60;

    /**
     * conversion factor from phred scaled quality to log error probability and vice versa
     * (defined in qual_utils.cpp)
     */
    static const double PHRED_TO_LOG_PROB_MULTIPLIER;

    static const double LOG_PROB_TO_PHRED_MULTIPLIER;

    // Smallest log10-scaled / phred-scaled qual representable by a double (defined in qual_utils.cpp).
    static const double MIN_LOG10_SCALED_QUAL;

    static const double MIN_PHRED_SCALED_QUAL;

    /**
     * The lowest quality score for a base that is considered reasonable for statistical analysis. This is
     * because Q 6 => you stand a 25% of being right, which means all bases are equally likely
     */
    static constexpr uint8_t MIN_USABLE_Q_SCORE = 6;

    // Sentinel MAPQ value meaning "mapping quality unavailable" (255 per the SAM spec).
    static constexpr int MAPPING_QUALITY_UNAVAILABLE = 255;

    /**
     * Maximum sense quality value.
     */
    static constexpr int MAX_QUAL = 254;

    /**
     * Cached values for qual as uint8_t calculations so they are very fast.
     * Indexed by qual (0..MAX_QUAL); filled by StaticInit().
     */
    static double qualToErrorProbCache[MAX_QUAL + 1];
    static double qualToProbLog10Cache[MAX_QUAL + 1];
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
//
|
||||||
|
// These are all functions to convert a phred-scaled quality score to a probability
|
||||||
|
//
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a phred-scaled quality score to its probability of being true (Q30 => 0.999)
|
||||||
|
*
|
||||||
|
* This is the Phred-style conversion, *not* the Illumina-style conversion.
|
||||||
|
*
|
||||||
|
* Because the input is a discretized byte value, this function uses a cache so is very efficient
|
||||||
|
*
|
||||||
|
* WARNING -- because this function takes a byte for maxQual, you must be careful in converting
|
||||||
|
* integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
|
||||||
|
*
|
||||||
|
* @param qual a quality score (0-255)
|
||||||
|
* @return a probability (0.0-1.0)
|
||||||
|
*/
|
||||||
|
static double qualToProb(const uint8_t qual) { return 1.0 - qualToErrorProb(qual); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a phred-scaled quality score to its probability of being true (Q30 => 0.999)
|
||||||
|
*
|
||||||
|
* This is the Phred-style conversion, *not* the Illumina-style conversion.
|
||||||
|
*
|
||||||
|
* Because the input is a double value, this function must call Math.pow so can be quite expensive
|
||||||
|
*
|
||||||
|
* @param qual a phred-scaled quality score encoded as a double. Can be non-integer values (30.5)
|
||||||
|
* @return a probability (0.0-1.0)
|
||||||
|
*/
|
||||||
|
static double qualToProb(const double qual) {
|
||||||
|
// Utils.validateArg(qual >= 0.0, ()->"qual must be >= 0.0 but got " + qual);
|
||||||
|
return 1.0 - qualToErrorProb(qual);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a phred-scaled quality score to its log10 probability of being true (Q30 => log10(0.999))
|
||||||
|
*
|
||||||
|
* This is the Phred-style conversion, *not* the Illumina-style conversion.
|
||||||
|
*
|
||||||
|
* WARNING -- because this function takes a byte for maxQual, you must be careful in converting
|
||||||
|
* integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
|
||||||
|
*
|
||||||
|
* @param qual a phred-scaled quality score encoded as a double. Can be non-integer values (30.5)
|
||||||
|
* @return a probability (0.0-1.0)
|
||||||
|
*/
|
||||||
|
static double qualToProbLog10(const uint8_t qual) {
|
||||||
|
return qualToProbLog10Cache[(int)qual & 0xff]; // Map: 127 -> 127; -128 -> 128; -1 -> 255; etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a log-probability to a phred-scaled value ( log(0.001) => 30 )
|
||||||
|
*
|
||||||
|
* @param prob a log-probability
|
||||||
|
* @return a phred-scaled value, not necessarily integral or bounded by {@code MAX_QUAL}
|
||||||
|
*/
|
||||||
|
static double logProbToPhred(const double prob) { return prob * LOG_PROB_TO_PHRED_MULTIPLIER; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a phred-scaled quality score to its probability of being wrong (Q30 => 0.001)
|
||||||
|
*
|
||||||
|
* This is the Phred-style conversion, *not* the Illumina-style conversion.
|
||||||
|
*
|
||||||
|
* Because the input is a double value, this function must call Math.pow so can be quite expensive
|
||||||
|
*
|
||||||
|
* @param qual a phred-scaled quality score encoded as a double. Can be non-integer values (30.5)
|
||||||
|
* @return a probability (0.0-1.0)
|
||||||
|
*/
|
||||||
|
|
||||||
|
static double qualToErrorProb(const double qual) {
|
||||||
|
assert(qual >= 0.0);
|
||||||
|
return std::pow(10.0, qual / -10.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**v
|
||||||
|
* Convert a phred-scaled quality score to its probability of being wrong (Q30 => 0.001)
|
||||||
|
*
|
||||||
|
* This is the Phred-style conversion, *not* the Illumina-style conversion.
|
||||||
|
*
|
||||||
|
* Because the input is a byte value, this function uses a cache so is very efficient
|
||||||
|
*
|
||||||
|
* WARNING -- because this function takes a byte for maxQual, you must be careful in converting
|
||||||
|
* integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
|
||||||
|
*
|
||||||
|
* @param qual a phred-scaled quality score encoded as a byte
|
||||||
|
* @return a probability (0.0-1.0)
|
||||||
|
*/
|
||||||
|
static double qualToErrorProb(const uint8_t qual) {
|
||||||
|
return qualToErrorProbCache[(int)qual & 0xff]; // Map: 127 -> 127; -128 -> 128; -1 -> 255; etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a phred-scaled quality score to its log10 probability of being wrong (Q30 => log10(0.001))
|
||||||
|
*
|
||||||
|
* This is the Phred-style conversion, *not* the Illumina-style conversion.
|
||||||
|
*
|
||||||
|
* The calculation is extremely efficient
|
||||||
|
*
|
||||||
|
* WARNING -- because this function takes a byte for maxQual, you must be careful in converting
|
||||||
|
* integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
|
||||||
|
*
|
||||||
|
* @param qual a phred-scaled quality score encoded as a byte
|
||||||
|
* @return a probability (0.0-1.0)
|
||||||
|
*/
|
||||||
|
static double qualToErrorProbLog10(const uint8_t qual) { return qualToErrorProbLog10((double)(qual & 0xFF)); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a phred-scaled quality score to its log10 probability of being wrong (Q30 => log10(0.001))
|
||||||
|
*
|
||||||
|
* This is the Phred-style conversion, *not* the Illumina-style conversion.
|
||||||
|
*
|
||||||
|
* The calculation is extremely efficient
|
||||||
|
*
|
||||||
|
* @param qual a phred-scaled quality score encoded as a double
|
||||||
|
* @return log of probability (0.0-1.0)
|
||||||
|
*/
|
||||||
|
static double qualToErrorProbLog10(const double qual) {
|
||||||
|
// Utils.validateArg(qual >= 0.0, ()->"qual must be >= 0.0 but got " + qual);
|
||||||
|
return qual * -0.1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
//
|
||||||
|
// Functions to convert a probability to a phred-scaled quality score
|
||||||
|
//
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a probability of being wrong to a phred-scaled quality score (0.01 => 20).
|
||||||
|
*
|
||||||
|
* Note, this function caps the resulting quality score by the public static value MAX_SAM_QUAL_SCORE
|
||||||
|
* and by 1 at the low-end.
|
||||||
|
*
|
||||||
|
* @param errorRate a probability (0.0-1.0) of being wrong (i.e., 0.01 is 1% change of being wrong)
|
||||||
|
* @return a quality score (0-MAX_SAM_QUAL_SCORE)
|
||||||
|
*/
|
||||||
|
static uint8_t errorProbToQual(const double errorRate) { return errorProbToQual(errorRate, MAX_SAM_QUAL_SCORE); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a probability of being wrong to a phred-scaled quality score (0.01 => 20).
|
||||||
|
*
|
||||||
|
* Note, this function caps the resulting quality score by the public static value MIN_REASONABLE_ERROR
|
||||||
|
* and by 1 at the low-end.
|
||||||
|
*
|
||||||
|
* WARNING -- because this function takes a byte for maxQual, you must be careful in converting
|
||||||
|
* integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
|
||||||
|
*
|
||||||
|
* @param errorRate a probability (0.0-1.0) of being wrong (i.e., 0.01 is 1% change of being wrong)
|
||||||
|
* @return a quality score (0-maxQual)
|
||||||
|
*/
|
||||||
|
static uint8_t errorProbToQual(const double errorRate, const uint8_t maxQual) {
|
||||||
|
// Utils.validateArg(MathUtils.isValidProbability(errorRate), ()->"errorRate must be good probability but got " + errorRate);
|
||||||
|
const double d = std::round(-10.0 * std::log10(errorRate));
|
||||||
|
return boundQual((int)d, maxQual);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @see #errorProbToQual(double, byte) with proper conversion of maxQual integer to a byte
|
||||||
|
*/
|
||||||
|
static uint8_t errorProbToQual(const double prob, const int maxQual) {
|
||||||
|
// Utils.validateArg(maxQual >= 0 && maxQual <= 255, ()->"maxQual must be between 0-255 but got " + maxQual);
|
||||||
|
return errorProbToQual(prob, (uint8_t)(maxQual & 0xFF));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a probability of being right to a phred-scaled quality score (0.99 => 20).
|
||||||
|
*
|
||||||
|
* Note, this function caps the resulting quality score by the public static value MAX_SAM_QUAL_SCORE
|
||||||
|
* and by 1 at the low-end.
|
||||||
|
*
|
||||||
|
* @param prob a probability (0.0-1.0) of being right
|
||||||
|
* @return a quality score (0-MAX_SAM_QUAL_SCORE)
|
||||||
|
*/
|
||||||
|
static uint8_t trueProbToQual(const double prob) { return trueProbToQual(prob, MAX_SAM_QUAL_SCORE); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a probability of being right to a phred-scaled quality score (0.99 => 20).
|
||||||
|
*
|
||||||
|
* Note, this function caps the resulting quality score by the min probability allowed (EPS).
|
||||||
|
* So for example, if prob is 1e-6, which would imply a Q-score of 60, and EPS is 1e-4,
|
||||||
|
* the result of this function is actually Q40.
|
||||||
|
*
|
||||||
|
* Note that the resulting quality score, regardless of EPS, is capped by MAX_SAM_QUAL_SCORE and
|
||||||
|
* bounded on the low-side by 1.
|
||||||
|
*
|
||||||
|
* WARNING -- because this function takes a byte for maxQual, you must be careful in converting
|
||||||
|
* integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
|
||||||
|
*
|
||||||
|
* @param trueProb a probability (0.0-1.0) of being right
|
||||||
|
* @param maxQual the maximum quality score we are allowed to emit here, regardless of the error rate
|
||||||
|
* @return a phred-scaled quality score (0-maxQualScore) as a byte
|
||||||
|
*/
|
||||||
|
static uint8_t trueProbToQual(const double trueProb, const uint8_t maxQual) {
|
||||||
|
// Utils.validateArg(MathUtils.isValidProbability(trueProb), ()->"trueProb must be good probability but got " + trueProb);
|
||||||
|
const double lp = std::round(-10.0 * MathUtils::log10OneMinusX(trueProb));
|
||||||
|
return boundQual((int)lp, maxQual);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * @see trueProbToQual(double, uint8_t) — narrows maxQual to a byte first.
 */
static uint8_t trueProbToQual(const double prob, const int maxQual) {
    // Mask to the low 8 bits before narrowing, mirroring Java's ((byte)(i & 0xFF)).
    const uint8_t cap = static_cast<uint8_t>(maxQual & 0xFF);
    return trueProbToQual(prob, cap);
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a probability of being right to a phred-scaled quality score of being wrong as a double
|
||||||
|
*
|
||||||
|
* This is a very generic method, that simply computes a phred-scaled double quality
|
||||||
|
* score given an error rate. It has the same precision as a normal double operation
|
||||||
|
*
|
||||||
|
* @param trueRate the probability of being right (0.0-1.0)
|
||||||
|
* @return a phred-scaled version of the error rate implied by trueRate
|
||||||
|
*/
|
||||||
|
static double phredScaleCorrectRate(const double trueRate) { return phredScaleLog10ErrorRate(MathUtils::log10OneMinusX(trueRate)); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a probability of being wrong to a phred-scaled quality score as a double
|
||||||
|
*
|
||||||
|
* This is a very generic method, that simply computes a phred-scaled double quality
|
||||||
|
* score given an error rate. It has the same precision as a normal double operation
|
||||||
|
*
|
||||||
|
* @param errorRate the probability of being wrong (0.0-1.0)
|
||||||
|
* @return a phred-scaled version of the error rate
|
||||||
|
*/
|
||||||
|
static double phredScaleErrorRate(const double errorRate) { return phredScaleLog10ErrorRate(std::log10(errorRate)); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a log10 probability of being wrong to a phred-scaled quality score as a double
|
||||||
|
*
|
||||||
|
* This is a very generic method, that simply computes a phred-scaled double quality
|
||||||
|
* score given an error rate. It has the same precision as a normal double operation
|
||||||
|
*
|
||||||
|
* @param errorRateLog10 the log10 probability of being wrong (0.0-1.0). Can be -Infinity, in which case
|
||||||
|
* the result is MIN_PHRED_SCALED_QUAL
|
||||||
|
* @return a phred-scaled version of the error rate
|
||||||
|
*/
|
||||||
|
static double phredScaleLog10ErrorRate(const double errorRateLog10) {
|
||||||
|
// Utils.validateArg(MathUtils.isValidLog10Probability(errorRateLog10), ()->"errorRateLog10 must be good probability but got " + errorRateLog10);
|
||||||
|
|
||||||
|
// abs is necessary for edge base with errorRateLog10 = 0 producing -0.0 doubles
|
||||||
|
return std::abs(-10.0 * std::max(errorRateLog10, MIN_LOG10_SCALED_QUAL));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a log10 probability of being right to a phred-scaled quality score of being wrong as a double
|
||||||
|
*
|
||||||
|
* This is a very generic method, that simply computes a phred-scaled double quality
|
||||||
|
* score given an error rate. It has the same precision as a normal double operation
|
||||||
|
*
|
||||||
|
* @param trueRateLog10 the log10 probability of being right (0.0-1.0). Can be -Infinity to indicate
|
||||||
|
* that the result is impossible in which MIN_PHRED_SCALED_QUAL is returned
|
||||||
|
* @return a phred-scaled version of the error rate implied by trueRate
|
||||||
|
*/
|
||||||
|
static double phredScaleLog10CorrectRate(const double trueRateLog10) { return phredScaleCorrectRate(std::pow(10.0, trueRateLog10)); }
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
//
|
||||||
|
// Routines to bound a quality score to a reasonable range
|
||||||
|
//
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return a quality score that bounds qual by MAX_SAM_QUAL_SCORE and 1
|
||||||
|
*
|
||||||
|
* @param qual the uncapped quality score as an integer
|
||||||
|
* @return the bounded quality score
|
||||||
|
*/
|
||||||
|
static uint8_t boundQual(int qual) { return boundQual(qual, MAX_SAM_QUAL_SCORE); }
|
||||||
|
|
||||||
|
/**
 * Bound an integer quality score to the closed range [1, maxQual].
 *
 * WARNING -- because this function takes a byte for maxQual, you must be careful in converting
 * integers to byte. The appropriate way to do this is ((byte)(myInt & 0xFF))
 *
 * @param qual the uncapped quality score; may be negative (which may indicate an error in the
 *             client code) and is then raised to 1 — some callers (BaseRecalibrator, for
 *             example) rely on this behavior, so it is not treated as an error
 * @param maxQual the inclusive upper bound, must be < 255
 * @return the bounded quality score
 */
static uint8_t boundQual(const int qual, const uint8_t maxQual) {
    const int ceiling = maxQual & 0xFF;
    // cap at the ceiling first, then raise to the floor of 1
    // (preserves the original min-then-max ordering when ceiling < 1)
    const int capped = (qual > ceiling) ? ceiling : qual;
    const int floored = (capped < 1) ? 1 : capped;
    return (uint8_t)(floored & 0xFF);
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the sum of phred scores.
|
||||||
|
* @param phreds the phred score values.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
|
static double phredSum(const vector<double>& phreds) {
|
||||||
|
switch (phreds.size()) {
|
||||||
|
case 0:
|
||||||
|
return DBL_MAX;
|
||||||
|
case 1:
|
||||||
|
return phreds[0];
|
||||||
|
case 2:
|
||||||
|
return phredSum(phreds[0], phreds[1]);
|
||||||
|
case 3:
|
||||||
|
return phredSum(phreds[0], phreds[1], phreds[2]);
|
||||||
|
default:
|
||||||
|
vector<double> log10Vals(phreds.size()); // todo: 优化加速
|
||||||
|
for (int i = 0; i < log10Vals.size(); i++) {
|
||||||
|
log10Vals[i] = phreds[i] * -0.1;
|
||||||
|
}
|
||||||
|
return -10 * MathUtils::log10SumLog10(log10Vals);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the sum of two phred scores.
|
||||||
|
* <p>
|
||||||
|
* As any sum, this operation is commutative.
|
||||||
|
* </p>
|
||||||
|
* @param a first phred score value.
|
||||||
|
* @param b second phred score value.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
|
static double phredSum(const double a, const double b) { return -10 * MathUtils::log10SumLog10(a * -.1, b * -.1); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the sum of three phred scores.
|
||||||
|
* <p>
|
||||||
|
* As any sum, this operation is commutative.
|
||||||
|
* </p>
|
||||||
|
* @param a first phred score value.
|
||||||
|
* @param b second phred score value.
|
||||||
|
* @param c thrid phred score value.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
|
static double phredSum(const double a, const double b, const double c) { return -10 * MathUtils::log10SumLog10(a * -.1, b * -.1, c * -.1); }
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
/*
|
||||||
|
Description: 做一些量化处理
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/24
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include "recal_tables.h"
|
||||||
|
#include "qual_utils.h"
|
||||||
|
#include "util/math/math_utils.h"
|
||||||
|
#include "qual_quantizer.h"
|
||||||
|
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
// Holds the quality-score quantization state derived from the recalibration tables:
// a histogram of empirical qualities and the original->quantized score mapping.
struct QuantizationInfo {
    vector<uint8_t> quantizedQuals;      // map: original qual -> quantized qual
    vector<int64_t> empiricalQualCounts; // histogram: observation count per empirical qual
    int quantizationLevels;              // number of quantization levels requested

    QuantizationInfo() {}

    // Build the empirical-quality histogram from the quality-score table, then
    // compute the original->quantized mapping with QualQuantizer.
    QuantizationInfo(RecalTables &recalTables, const int levels) {
        quantizationLevels = levels;
        empiricalQualCounts.resize(QualityUtils::MAX_SAM_QUAL_SCORE + 1, 0);
        auto& qualTable = recalTables.qualityScoreTable;
        _Foreach3D(qualTable, val, {
            // convert the empirical quality to an integer ( it is already capped by MAX_QUAL )
            const int empiricalQual = MathUtils::fastRound(val.getEmpiricalQuality());
            empiricalQualCounts[empiricalQual] += val.getNumObservations(); // add the number of observations for every key
        });

        // quantizeQualityScores: collapse the histogram into `levels` buckets
        QualQuantizer quantizer(empiricalQualCounts, levels, QualityUtils::MIN_USABLE_Q_SCORE);
        quantizer.getOriginalToQuantizedMap(quantizedQuals);
    }
};
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
/*
|
||||||
|
Description: 单个sam记录需要保存的数据
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/24
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "util/bam_wrap.h"
|
||||||
|
#include "covariate.h"
|
||||||
|
|
||||||
|
// Per-read scratch data gathered for BQSR's covariate-collection pass.
// All reference members point into caller-owned buffers; this struct owns nothing.
struct ReadRecalInfo {
    SamData& read;                       // the underlying alignment record
    int length;                          // read length, cached from read.read_len
    PerReadCovariateMatrix& covariates;  // covariate keys for this read
    vector<bool>& skips;                 // per-base "skip this position" flags
    FastArray<uint8_t> &base_quals, &ins_quals, &del_quals;  // per-base quals for SNP/ins/del events
    vector<double> &snp_errs, &ins_errs, &del_errs;          // per-base fractional error weights per event

    // Binds all references; the qual arrays come from the read itself, the error
    // arrays from the caller's buffers.
    ReadRecalInfo(SamData& _read, PerReadCovariateMatrix& _covariates, vector<bool>& _skips, vector<double>& _snp_errs, vector<double>& _ins_errs,
                  vector<double>& _del_errs)
        : read(_read),
          covariates(_covariates),
          skips(_skips),
          base_quals(_read.base_quals),
          ins_quals(_read.ins_quals),
          del_quals(_read.del_quals),
          snp_errs(_snp_errs),
          ins_errs(_ins_errs),
          del_errs(_del_errs) {
        length = _read.read_len;
    }

    /**
     * Get the qual score for event type at offset.
     *
     * @param event the type of event we want the qual for
     * @param offset the offset into this read for the qual
     * @return a valid quality score for the event at offset
     */
    uint8_t getQual(const EventTypeValue &event, const int offset) {
        if (event == EventType::BASE_SUBSTITUTION) {
            return base_quals[offset];
        } else if (event == EventType::BASE_INSERTION) {
            return ins_quals[offset];
        } else {
            // falls through for BASE_DELETION (assumed the only remaining event — confirm)
            return del_quals[offset];
        }
    }

    /**
     * Get the error fraction for event type at offset.
     *
     * The error fraction is a value between 0 and 1 that indicates how much certainty we have
     * in the error occurring at offset. A value of 1 means that the error definitely occurs at this
     * site, a value of 0.0 means it definitely doesn't happen here. 0.5 means that half the weight
     * of the error belongs here.
     *
     * @param event the type of event we want the error fraction for
     * @param offset the offset into this read
     * @return a fractional weight for an error at this offset
     */
    double getErrorFraction(const EventTypeValue& event, const int offset) {
        if (event == EventType::BASE_SUBSTITUTION) {
            return snp_errs[offset];
        } else if (event == EventType::BASE_INSERTION) {
            return ins_errs[offset];
        } else {
            // falls through for BASE_DELETION (assumed the only remaining event — confirm)
            return del_errs[offset];
        }
    }
};
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
#include "recal_datum.h"

// Out-of-class storage for the static log-prior cache declared in RecalDatum.
// Filled in by RecalDatum::StaticInit(); indexed by |empirical - reported| qual difference.
double RecalDatum::logPriorCache[MAX_GATK_USABLE_Q_SCORE + 1];
|
||||||
|
|
@ -0,0 +1,246 @@
|
||||||
|
/*
|
||||||
|
Description: bqsr计算过程中需要记录的一些数据
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/24
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <float.h>
|
||||||
|
#include "qual_utils.h"
|
||||||
|
#include "util/math/normal_dist.h"
|
||||||
|
#include "util/math/math_utils.h"
|
||||||
|
|
||||||
|
/**
 * The container for the 4-tuple
 *
 * ( reported quality, empirical quality, num observations, num mismatches/errors )
 *
 * for a given set of covariates.
 */
struct RecalDatum {
    static constexpr uint8_t MAX_RECALIBRATED_Q_SCORE = 93; // SAMUtils.MAX_PHRED_SCORE;
    static constexpr int UNINITIALIZED_EMPIRICAL_QUALITY = -1; // sentinel: cached empirical qual is stale
    // NOTE(review): the numMismatches discussion below says 10000.0 was the value found to
    // give sorting-insensitive results, but the constant here is 100000.0 — confirm intended.
    static constexpr double MULTIPLIER = 100000.0; // See discussion in numMismatches about what the multiplier is.
    /**
     * used when calculating empirical qualities to avoid division by zero
     */
    static constexpr int SMOOTHING_CONSTANT = 1;

    // NOTE(review): INT_MAX requires <climits>, which this header does not include directly —
    // currently relies on a transitive include; confirm.
    static constexpr uint64_t MAX_NUMBER_OF_OBSERVATIONS = INT_MAX - 1;

    /**
     * Quals above this value should be capped down to this value (because they are too high)
     * in the base quality score recalibrator
     */
    static constexpr uint8_t MAX_GATK_USABLE_Q_SCORE = 40;

    // log of the Gaussian prior on (empiricalQuality - reportedQuality), indexed by the
    // absolute qual difference; defined in recal_datum.cpp and filled by StaticInit().
    static double logPriorCache[MAX_GATK_USABLE_Q_SCORE + 1];

    // Populate logPriorCache. Must be called once before any empirical quality is computed.
    static void StaticInit() {
        // normal distribution describing P(empiricalQuality - reportedQuality). Its mean is zero because a priori we expect
        // no systematic bias in the reported quality score
        const double mean = 0.0;
        const double sigma = 0.5; // with these parameters, deltas can shift at most ~20 Q points
        const NormalDistribution gaussian(mean, sigma);
        for (int i = 0; i <= MAX_GATK_USABLE_Q_SCORE; i++) {
            logPriorCache[i] = gaussian.logDensity(i);
        }
    }

    /**
     * Estimated reported quality score based on combined data's individual q-reporteds and number of observations.
     * The estimating occurs when collapsing counts across different reported qualities.
     */
    // the raw quality score reported by the sequencing machine
    double reportedQuality = 0.0;

    /**
     * The empirical quality for datums that have been collapsed together (by read group and reported quality, for example).
     *
     * This variable was historically a double, but {@link #bayesianEstimateOfEmpiricalQuality} has always returned an integer qual score.
     * Thus the type has been changed to integer in February 2025 to highlight this implementation detail. It does not change the output.
     */
    // the computed ("true") empirical quality score
    int empiricalQuality = 0;

    /**
     * Number of bases seen in total
     */
    // Also doubles as this datum's validity flag: the datum is only meaningful when
    // numObservations > 0, since a zero count means this covariate key never occurred.
    uint64_t numObservations = 0;

    /**
     * Number of bases seen that didn't match the reference
     * (actually sum of the error weights - so not necessarily a whole number)
     * Stored with an internal multiplier to keep it closer to the floating-point sweet spot and avoid numerical error
     * (see https://github.com/broadinstitute/gatk/wiki/Numerical-errors ).
     * However, the value of the multiplier influences the results.
     * For example, you get different results for 1000.0 and 10000.0
     * See MathUtilsUnitTest.testAddDoubles for a demonstration.
     * The value of the MULTIPLIER that we found to give consistent results insensitive to sorting is 10000.0;
     */
    double numMismatches = 0.0;

    RecalDatum() {}

    // @param _numObservations total bases seen
    // @param _numMismatches   error weight, unscaled (the MULTIPLIER is applied here)
    // @param _reportedQuality the machine-reported quality score
    RecalDatum(const uint64_t _numObservations, const double _numMismatches, const uint8_t _reportedQuality) {
        numObservations = _numObservations;
        numMismatches = _numMismatches * MULTIPLIER;
        reportedQuality = _reportedQuality;
        empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;
    }

    // Accumulate observations and (fractional) mismatches; invalidates the cached empirical quality.
    void increment(const uint64_t incObservations, const double incMismatches) {
        numObservations += incObservations;
        numMismatches += (incMismatches * MULTIPLIER); // the multiplier used to avoid underflow, or something like that.
        empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;
    }

    // Same as increment(obs, mismatches) but also overwrites the reported quality.
    void increment(const uint64_t incObservations, const double incMismatches, int baseQuality) {
        numObservations += incObservations;
        numMismatches += (incMismatches * MULTIPLIER); // the multiplier used to avoid underflow, or something like that.
        reportedQuality = baseQuality;
        empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;
    }

    /**
     * Add in all of the data from other into this object, updating the reported quality from the expected
     * error rate implied by the two reported qualities.
     *
     * For example (the only example?), this method is called when collapsing the counts across reported quality scores within
     * the same read group.
     *
     * @param other RecalDatum to combine
     */
    void combine(const RecalDatum& other) {
        // this is the *expected* (or theoretical) number of errors given the reported qualities and the number of observations.
        double expectedNumErrors = this->calcExpectedErrors() + other.calcExpectedErrors();

        // increment the counts
        increment(other.getNumObservations(), other.getNumMismatches());

        // we use the theoretical count above to compute the "estimated" reported quality
        // after combining two datums with different reported qualities.
        reportedQuality = -10 * log10(expectedNumErrors / getNumObservations());
        empiricalQuality = UNINITIALIZED_EMPIRICAL_QUALITY;
    }

    /**
     * calculate the expected number of errors given the estimated Q reported and the number of observations
     * in this datum.
     *
     * @return a positive (potentially fractional) estimate of the number of errors
     */
    inline double calcExpectedErrors() const { return numObservations * QualityUtils::qualToErrorProb(reportedQuality); }
    // getNumMismatches undoes the internal MULTIPLIER scaling (external callers see real counts).
    inline double getNumMismatches() const { return numMismatches / MULTIPLIER; }
    inline uint64_t getNumObservations() const { return numObservations; }
    inline double getReportedQuality() const { return reportedQuality; }

    /**
     * Computes the empirical quality of the datum, using the reported quality as the prior.
     * @see #getEmpiricalQuality(double) below.
     */
    double getEmpiricalQuality() { return getEmpiricalQuality(getReportedQuality()); }

    /**
     * Computes the empirical base quality (roughly (num errors)/(num observations)) from the counts stored in this datum.
     * Lazily computed and cached; increment()/combine() reset the cache.
     */
    double getEmpiricalQuality(const double priorQualityScore) {
        if (empiricalQuality == UNINITIALIZED_EMPIRICAL_QUALITY) {
            calcEmpiricalQuality(priorQualityScore);
        }
        return empiricalQuality;
    }

    /**
     * Calculate and cache the empirical quality score from mismatches and observations (expensive operation)
     */
    void calcEmpiricalQuality(const double priorQualityScore) {
        // smoothing is one error and one non-error observation
        // (+0.5 rounds the fractional mismatch weight to the nearest integer before truncation)
        const uint64_t mismatches = (uint64_t)(getNumMismatches() + 0.5) + SMOOTHING_CONSTANT;
        const uint64_t observations = getNumObservations() + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT;

        const int empiricalQual = bayesianEstimateOfEmpiricalQuality(observations, mismatches, priorQualityScore);

        empiricalQuality = std::min(empiricalQual, (int)MAX_RECALIBRATED_Q_SCORE);
    }

    /**
     * Compute the maximum a posteriori (MAP) estimate of the probability of sequencing error under the following model.
     *
     * Let
     *   X = number of sequencing errors,
     *   n = number of observations,
     *   theta = probability of sequencing error as a quality score,
     *   theta_rep = probability of sequencing error reported by the sequencing machine as a quality score.
     *
     * The prior and the likelihood are:
     *
     *   P(theta|theta_rep) ~ Gaussian(theta - theta_rep| 0, 0.5) (Note this is done in log space)
     *   P(X|n, theta) ~ Binom(X|n,theta)
     *
     * Note the prior is equivalent to
     *
     *   P(theta|theta_rep) ~ Gaussian(theta | theta_rep, 0.5)
     *
     * TODO: use beta prior to do away with the search.
     *
     * @param nObservations n in the model above.
     * @param nErrors the observed number of sequencing errors.
     * @param priorMeanQualityScore the prior quality score, often the reported quality score.
     *
     * @return phredScale quality score that maximizes the posterior probability.
     */
    static int bayesianEstimateOfEmpiricalQuality(const uint64_t nObservations, const uint64_t nErrors, const double priorMeanQualityScore) {
        const int numQualityScoreBins = (QualityUtils::MAX_REASONABLE_Q_SCORE + 1);

        // exhaustive search over every candidate integer quality score
        double logPosteriors[numQualityScoreBins];
        for (int i = 0; i < numQualityScoreBins; ++i) {
            logPosteriors[i] = getLogPrior(i, priorMeanQualityScore) + getLogBinomialLikelihood(i, nObservations, nErrors);
        }
        // the MAP estimate is the bin with the largest posterior
        return MathUtils::maxElementIndex(logPosteriors, 0, numQualityScoreBins);
    }

    // Log prior: cached Gaussian on the difference between candidate and prior quals,
    // with the difference capped at MAX_GATK_USABLE_Q_SCORE (the cache size).
    static double getLogPrior(const double qualityScore, const double priorQualityScore) {
        const int difference = std::min(std::abs((int)(qualityScore - priorQualityScore)), (int)MAX_GATK_USABLE_Q_SCORE);
        return logPriorCache[difference];
    }

    /**
     * Given:
     * - n, the number of observations,
     * - k, the number of sequencing errors,
     * - p, the probability of error, encoded as the quality score.
     *
     * Return the binomial probability Bin(k|n,p).
     *
     * The method handles the case when the counts of type long are higher than the maximum allowed integer value,
     * Integer.MAX_VALUE = (2^31)-1 ~= 2*10^9, since the library we use for binomial probability expects integer input.
     */
    static double getLogBinomialLikelihood(const double qualityScore, uint64_t nObservations, uint64_t nErrors) {
        if (nObservations == 0)
            return 0.0;

        // the binomial code requires ints as input (because it does caching). This should theoretically be fine because
        // there is plenty of precision in 2^31 observations, but we need to make sure that we don't have overflow
        // before casting down to an int.
        if (nObservations > MAX_NUMBER_OF_OBSERVATIONS) {
            // we need to decrease nErrors by the same fraction that we are decreasing nObservations
            const double fraction = (double)MAX_NUMBER_OF_OBSERVATIONS / (double)nObservations;
            nErrors = std::round((double)nErrors * fraction);
            nObservations = MAX_NUMBER_OF_OBSERVATIONS;
        }

        // this is just a straight binomial PDF
        const double logLikelihood = MathUtils::logBinomialProbability((int)nObservations, (int)nErrors, QualityUtils::qualToErrorProb(qualityScore));
        // guard against -inf/NaN from degenerate p so the posterior search stays well-defined
        return (std::isinf(logLikelihood) || std::isnan(logLikelihood)) ? -DBL_MAX : logLikelihood;
    }
};
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
/*
|
||||||
|
Description: 保存bqsr计算的各种数据,根据这些数据生成bqsr table
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/24
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "covariate.h"
|
||||||
|
#include "nested_array.h"
|
||||||
|
#include "recal_datum.h"
|
||||||
|
|
||||||
|
// Container for all per-covariate recalibration tables; the layout mirrors GATK's
// RecalibrationTables (read group x reported qual x covariate value x event type).
struct RecalTables {
    int qualDimension = 94; // MAX_SAM_QUAL_SCORE(93) + 1
    int eventDimension = EventType::EVENT_SIZE;
    int numReadGroups; // set in init()

    // These two tables are special
    Array2D<RecalDatum> readGroupTable;
    Array3D<RecalDatum> qualityScoreTable;

    // additional tables
    Array4D<RecalDatum> contextTable;
    Array4D<RecalDatum> cycleTable;

    // NestedArray<int> testArr;

    RecalTables() {}

    RecalTables(int _numReadGroups) { init(_numReadGroups); }

    // Allocate all four tables for the given number of read groups.
    void init(int _numReadGroups) {
        numReadGroups = _numReadGroups;
        // allocate the read-group and quality-score tables
        readGroupTable.init(numReadGroups, eventDimension);
        qualityScoreTable.init(numReadGroups, qualDimension, eventDimension);

        // allocate the context and cycle tables
        contextTable.init(numReadGroups, qualDimension, ContextCovariate::MaximumKeyValue() + 1, eventDimension);
        cycleTable.init(numReadGroups, qualDimension, CycleCovariate::MaximumKeyValue() + 1, eventDimension);
    }
};
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
/*
|
||||||
|
Description: bqsr计算过程的辅助类
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/24
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "recal_utils.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Increments the RecalDatum at the specified position in the specified table, or put a new item there
|
||||||
|
* if there isn't already one.
|
||||||
|
*
|
||||||
|
* Note: we intentionally do not use varargs here to avoid the performance cost of allocating an array on every call. It showed on the profiler.
|
||||||
|
*
|
||||||
|
* @param table the table that holds/will hold our item
|
||||||
|
* @param qual qual for this event
|
||||||
|
* @param isError error value for this event
|
||||||
|
* @param key0, key1, key2 location in table of our item
|
||||||
|
*/
|
||||||
|
void RecalUtils::IncrementDatum3keys(Array3D<RecalDatum> table, uint8_t qual, double isError, int key0, int key1, int key2) {
|
||||||
|
RecalDatum &existingDatum = table.get(key0, key1, key2);
|
||||||
|
|
||||||
|
if (existingDatum.numObservations == 0) {
|
||||||
|
// No existing item, put a new one
|
||||||
|
table.put(RecalDatum(1, qual, isError), key0, key1, key2);
|
||||||
|
} else {
|
||||||
|
// Easy case: already an item here, so increment it
|
||||||
|
existingDatum.increment(1, isError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Increments the RecalDatum at the specified position in the specified table, or put a new item there
|
||||||
|
* if there isn't already one.
|
||||||
|
*
|
||||||
|
* Note: we intentionally do not use varargs here to avoid the performance cost of allocating an array on every call. It showed on the profiler.
|
||||||
|
*
|
||||||
|
* @param table the table that holds/will hold our item
|
||||||
|
* @param qual qual for this event
|
||||||
|
* @param isError error value for this event
|
||||||
|
* @param key0, key1, key2, key3 location in table of our item
|
||||||
|
*/
|
||||||
|
void RecalUtils::IncrementDatum4keys(Array4D<RecalDatum> table, uint8_t qual, double isError, int key0, int key1, int key2, int key3) {
|
||||||
|
RecalDatum& existingDatum = table.get(key0, key1, key2, key3);
|
||||||
|
|
||||||
|
if (existingDatum.numObservations == 0) {
|
||||||
|
// No existing item, put a new one
|
||||||
|
table.put(RecalDatum(1, qual, isError), key0, key1, key2, key3);
|
||||||
|
} else {
|
||||||
|
// Easy case: already an item here, so increment it
|
||||||
|
existingDatum.increment(1, isError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,207 @@
|
||||||
|
/*
|
||||||
|
Description: bqsr计算过程的辅助类
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/24
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
#include "bqsr_args.h"
|
||||||
|
#include "nested_array.h"
|
||||||
|
#include "quant_info.h"
|
||||||
|
#include "recal_datum.h"
|
||||||
|
#include "recal_tables.h"
|
||||||
|
#include "util/report_table.h"
|
||||||
|
#include "util/utils.h"
|
||||||
|
#include "covariate.h"
|
||||||
|
|
||||||
|
struct RecalUtils {
|
||||||
|
|
||||||
|
// Decimal-place precision used when formatting report columns (matches the GATK report layout).
static constexpr int EMPIRICAL_QUAL_DECIMAL_PLACES = 4;
static constexpr int REPORTED_QUALITY_DECIMAL_PLACES = 4;
static constexpr int NUMBER_ERRORS_DECIMAL_PLACES = 2;

// Accumulate one observation (with fractional error weight `isError` at quality `qual`)
// into the table cell addressed by the covariate keys; defined in recal_utils.cpp.
static void IncrementDatum3keys(Array3D<RecalDatum> table, uint8_t qual, double isError, int key0, int key1, int key2);
static void IncrementDatum4keys(Array4D<RecalDatum> table, uint8_t qual, double isError, int key0, int key1, int key2, int key3);
|
||||||
|
|
||||||
|
// Write the full BQSR recalibration report (GATK-report text format) to RAC.OUTPUT_FILE:
// version header, arguments table, quantization table, then the covariate tables.
static void outputRecalibrationReport(const BQSRArg& RAC, const QuantizationInfo& quantInfo, const RecalTables& recalTables) {
    spdlog::info("output tables");
    // open the output file; abort on failure
    FILE* fpout = fopen(RAC.OUTPUT_FILE.c_str(), "w");
    if (!fpout) {
        fprintf(stderr, "Failed to open file %s.\n", RAC.OUTPUT_FILE.c_str());
        exit(1);
    }
    // report format/version header line
    fprintf(fpout, "%s\n", REPORT_HEADER_VERSION);
    // the run arguments table
    outputArgsTable(RAC, fpout);
    // the quantized quality score table
    outputQuantTable(quantInfo, fpout);
    // the covariate tables (read group, quality score, context/cycle)
    outputCovariateTable(recalTables, fpout);

    fclose(fpout);
}
|
||||||
|
|
||||||
|
// Write the "Arguments" table: the recalibration argument values used in this run,
// one (name, value) row per argument, mirroring GATK's report layout.
static void outputArgsTable(const BQSRArg & p, FILE * fpout) {
    ReportTable table("Arguments", "Recalibration argument collection values used in this run");
    table.addColumn({"Argument", "%s"});
    // NOTE(review): empty format string for "Value" — confirm ReportTable treats "" as plain string output.
    table.addColumn({"Value", ""});
    // add one row per argument
    table.addRowData({"covariate", "ReadGroupCovariate,QualityScoreCovariate,ContextCovariate,CycleCovariate"});
    table.addRowData({"no_standard_covs", ReportUtil::ToString(p.DO_NOT_USE_STANDARD_COVARIATES)});
    table.addRowData({"run_without_dbsnp", ReportUtil::ToString(p.RUN_WITHOUT_DBSNP)});
    table.addRowData({"solid_recal_mode", ReportUtil::ToString(p.SOLID_RECAL_MODE)});
    table.addRowData({"solid_nocall_strategy", ReportUtil::ToString(p.SOLID_NOCALL_STRATEGY)});
    table.addRowData({"mismatches_context_size", ReportUtil::ToString(p.MISMATCHES_CONTEXT_SIZE)});
    table.addRowData({"indels_context_size", ReportUtil::ToString(p.INDELS_CONTEXT_SIZE)});
    table.addRowData({"mismatches_default_quality", ReportUtil::ToString(p.MISMATCHES_DEFAULT_QUALITY)});
    table.addRowData({"deletions_default_quality", ReportUtil::ToString(p.DELETIONS_DEFAULT_QUALITY)});
    table.addRowData({"insertions_default_quality", ReportUtil::ToString(p.INSERTIONS_DEFAULT_QUALITY)});
    table.addRowData({"maximum_cycle_value", ReportUtil::ToString(p.MAXIMUM_CYCLE_VALUE)});
    table.addRowData({"low_quality_tail", ReportUtil::ToString(p.LOW_QUAL_TAIL)});
    table.addRowData({"default_platform", ReportUtil::ToString(p.DEFAULT_PLATFORM)});
    table.addRowData({"force_platform", ReportUtil::ToString(p.FORCE_PLATFORM)});
    table.addRowData({"quantizing_levels", ReportUtil::ToString(p.QUANTIZING_LEVELS)});
    table.addRowData({"recalibration_report", ReportUtil::ToString(p.existingRecalibrationReport)});
    table.addRowData({"binary_tag_name", ReportUtil::ToString(p.BINARY_TAG_NAME)});
    // write the table to the file
    table.write(fpout);
}
|
||||||
|
|
||||||
|
// 输出量化质量分数table
|
||||||
|
static void outputQuantTable(const QuantizationInfo& q, FILE* fpout) {
|
||||||
|
ReportTable table("Quantized", "Quality quantization map");
|
||||||
|
table.addColumn({"QualityScore", "%d"});
|
||||||
|
table.addColumn({"Count", "%d"});
|
||||||
|
table.addColumn({"QuantizedScore", "%d"});
|
||||||
|
|
||||||
|
for (int qual = 0; qual <= QualityUtils::MAX_SAM_QUAL_SCORE; ++qual) {
|
||||||
|
table.addRowData({ReportUtil::ToString(qual), ReportUtil::ToString(q.empiricalQualCounts[qual]), ReportUtil::ToString(q.quantizedQuals[qual])});
|
||||||
|
}
|
||||||
|
|
||||||
|
table.write(fpout);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write all covariate tables of the recalibration report, in the fixed
// RecalTable0 / RecalTable1 / RecalTable2 order the report format expects.
static void outputCovariateTable(const RecalTables &r, FILE* fpout) {
    // 1. read group covariates
    outputReadGroupTable(r, fpout);
    // 2. quality score covariates
    outputQualityScoreTable(r, fpout);
    // 3. context and cycle covariates
    outputContextCycleTable(r, fpout);
}
|
||||||
|
|
||||||
|
// read group table
|
||||||
|
static void outputReadGroupTable(const RecalTables& r, FILE* fpout) {
|
||||||
|
ReportTable table("RecalTable0");
|
||||||
|
table.addColumn({"ReadGroup", "%s"});
|
||||||
|
table.addColumn({"EventType", "%s"});
|
||||||
|
table.addColumn({"EmpiricalQuality", "%.4f"});
|
||||||
|
table.addColumn({"EstimatedQReported", "%.4f"});
|
||||||
|
table.addColumn({"Observations", "%d"});
|
||||||
|
table.addColumn({"Errors", "%.2f"});
|
||||||
|
|
||||||
|
spdlog::info("rg0: {}, {}, {}, {}", r.readGroupTable[0][0].numObservations, r.readGroupTable[0][0].numMismatches,
|
||||||
|
r.readGroupTable[0][0].reportedQuality, r.readGroupTable[0][0].empiricalQuality);
|
||||||
|
|
||||||
|
_Foreach2DK(r.readGroupTable, datum, {
|
||||||
|
RecalDatum &dat = const_cast<RecalDatum&>(datum);
|
||||||
|
spdlog::info("errors: {}, {}", datum.numMismatches, dat.getNumMismatches());
|
||||||
|
spdlog::info("obs: {}, {}", datum.numObservations, dat.getNumObservations());
|
||||||
|
if (dat.getNumObservations() > 0) {
|
||||||
|
table.addRowData({
|
||||||
|
ReadGroupCovariate::IdToRg[k1],
|
||||||
|
ReportUtil::ToString(EventType::EVENTS[k2].representation),
|
||||||
|
ReportUtil::ToString(dat.getEmpiricalQuality(), 4),
|
||||||
|
ReportUtil::ToString(dat.getReportedQuality(), 4),
|
||||||
|
ReportUtil::ToString(dat.getNumObservations()),
|
||||||
|
ReportUtil::ToString(dat.getNumMismatches(), 2)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
table.write(fpout);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write "RecalTable1": recalibration statistics keyed by
// (read group, reported quality score, event type).
static void outputQualityScoreTable(const RecalTables& r, FILE* fpout) {
    ReportTable table("RecalTable1");
    table.addColumn({"ReadGroup", "%s"});
    table.addColumn({"QualityScore", "%d"});
    table.addColumn({"EventType", "%s"});
    table.addColumn({"EmpiricalQuality", "%.4f"});
    table.addColumn({"Observations", "%d"});
    table.addColumn({"Errors", "%.2f"});

    // _Foreach3DK iterates the 3-D table and binds the indices k1/k2/k3
    // (read group id, quality score, event type) inside the body.
    _Foreach3DK(r.qualityScoreTable, datum, {
        // const_cast: the RecalDatum getters are non-const (presumably they
        // compute cached values lazily) -- TODO confirm and constify upstream.
        RecalDatum &dat = const_cast<RecalDatum&>(datum);
        // Skip empty cells so the report only contains observed combinations.
        if (dat.getNumObservations() > 0) {
            table.addRowData({
                ReadGroupCovariate::IdToRg[k1],
                ReportUtil::ToString(k2),
                ReportUtil::ToString(EventType::EVENTS[k3].representation),
                ReportUtil::ToString(dat.getEmpiricalQuality(), 4),
                ReportUtil::ToString(dat.getNumObservations()),
                ReportUtil::ToString(dat.getNumMismatches(), 2)
            });
        }
    });
    table.write(fpout);
}
|
||||||
|
|
||||||
|
// Write "RecalTable2": recalibration statistics keyed by
// (read group, quality score, covariate value, event type), for both the
// Context and the Cycle covariates, merged into one table as in GATK.
static void outputContextCycleTable(const RecalTables& r, FILE* fpout) {
    ReportTable table("RecalTable2");
    table.addColumn({"ReadGroup", "%s"});
    table.addColumn({"QualityScore", "%d"});
    table.addColumn({"CovariateValue", "%s"});
    table.addColumn({"CovariateName", "%s"});
    table.addColumn({"EventType", "%s"});
    table.addColumn({"EmpiricalQuality", "%.4f"});
    table.addColumn({"Observations", "%d"});
    table.addColumn({"Errors", "%.2f"});

    // Context covariate rows.  _Foreach4DK iterates the 4-D table and binds
    // the indices k1..k4 (read group, quality, covariate key, event type).
    _Foreach4DK(r.contextTable, datum, {
        // const_cast: the RecalDatum getters are non-const (presumably they
        // compute cached values lazily) -- TODO confirm and constify upstream.
        RecalDatum &dat = const_cast<RecalDatum&>(datum);
        // Skip empty cells so the report only contains observed combinations.
        if (dat.getNumObservations() > 0) {
            table.addRowData({
                ReadGroupCovariate::IdToRg[k1],
                ReportUtil::ToString(k2),
                ReportUtil::ToString(ContextCovariate::ContextFromKey(k3)),
                "Context",
                ReportUtil::ToString(EventType::EVENTS[k4].representation),
                ReportUtil::ToString(dat.getEmpiricalQuality(), 4),
                ReportUtil::ToString(dat.getNumObservations()),
                ReportUtil::ToString(dat.getNumMismatches(), 2)
            });
        }
    });

    // Cycle covariate rows, appended after the context rows.
    _Foreach4DK(r.cycleTable, datum, {
        RecalDatum &dat = const_cast<RecalDatum&>(datum);
        if (dat.getNumObservations() > 0) {
            table.addRowData({
                ReadGroupCovariate::IdToRg[k1],
                ReportUtil::ToString(k2),
                ReportUtil::ToString(CycleCovariate::CycleFromKey(k3)),
                "Cycle",
                ReportUtil::ToString(EventType::EVENTS[k4].representation),
                ReportUtil::ToString(dat.getEmpiricalQuality(), 4),
                ReportUtil::ToString(dat.getNumObservations()),
                ReportUtil::ToString(dat.getNumMismatches(), 2)
            });
        }
    });
    table.write(fpout);
}
|
||||||
|
};
|
||||||
|
|
@ -42,6 +42,12 @@ struct FastArray {
|
||||||
size_t idx;
|
size_t idx;
|
||||||
void clear() { idx = 0; }
|
void clear() { idx = 0; }
|
||||||
size_t size() { return idx; }
|
size_t size() { return idx; }
|
||||||
|
bool empty() { return idx == 0; }
|
||||||
|
void reserve(size_t _size) { arr.reserve(_size); }
|
||||||
|
void resize(size_t _size) {
|
||||||
|
arr.resize(_size);
|
||||||
|
idx = _size;
|
||||||
|
}
|
||||||
void push_back(const T& val) {
|
void push_back(const T& val) {
|
||||||
if (idx < arr.size()) {
|
if (idx < arr.size()) {
|
||||||
arr[idx++] = val;
|
arr[idx++] = val;
|
||||||
|
|
@ -50,6 +56,7 @@ struct FastArray {
|
||||||
idx++;
|
idx++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
inline T& operator[](size_t pos) { return arr[pos]; }
|
||||||
struct iterator {
|
struct iterator {
|
||||||
typename std::vector<T>::iterator it;
|
typename std::vector<T>::iterator it;
|
||||||
iterator(typename std::vector<T>::iterator _it) : it(_it) {}
|
iterator(typename std::vector<T>::iterator _it) : it(_it) {}
|
||||||
|
|
@ -65,6 +72,7 @@ struct FastArray {
|
||||||
};
|
};
|
||||||
|
|
||||||
// 对原始bam数据的补充,比如对两端进行hardclip等
|
// 对原始bam数据的补充,比如对两端进行hardclip等
|
||||||
|
class BamWrap;
|
||||||
struct SamData {
|
struct SamData {
|
||||||
int read_len = 0; // read长度,各种clip之后的长度
|
int read_len = 0; // read长度,各种clip之后的长度
|
||||||
int cigar_start = 0; // cigar起始位置,闭区间
|
int cigar_start = 0; // cigar起始位置,闭区间
|
||||||
|
|
@ -77,38 +85,43 @@ struct SamData {
|
||||||
|
|
||||||
// 记录一下bqsr运算过程中用到的数据,回头提前计算一下,修正现在的复杂逻辑
|
// 记录一下bqsr运算过程中用到的数据,回头提前计算一下,修正现在的复杂逻辑
|
||||||
static constexpr int READ_INDEX_NOT_FOUND = -1;
|
static constexpr int READ_INDEX_NOT_FOUND = -1;
|
||||||
string bases; // 处理之后的read的碱基
|
|
||||||
vector<uint8_t> quals; // 对应的质量分数
|
|
||||||
int64_t start_pos; // 因为soft clip都被切掉了,这里的softstart应该就是切掉之后的匹配位点,闭区间
|
|
||||||
int64_t end_pos; // 同上,闭区间
|
|
||||||
FastArray<Cigar> cigars;
|
|
||||||
int64_t& softStart() { return start_pos; }
|
|
||||||
int64_t& softEnd() { return end_pos; }
|
|
||||||
|
|
||||||
// functions
|
BamWrap* bw;
|
||||||
ReadIdxCigar getReadIndexForReferenceCoordinate(int64_t refPos) {
|
int64_t start_pos; // 因为soft clip都被切掉了,这里的softstart应该就是切掉之后的匹配位点,闭区间
|
||||||
ReadIdxCigar rc;
|
int64_t end_pos; // 同上,闭区间
|
||||||
if (refPos < start_pos)
|
string bases; // 处理之后的read的碱基
|
||||||
return rc;
|
FastArray<uint8_t> base_quals; // 对应的质量分数
|
||||||
int firstReadPosOfElement = 0; // inclusive
|
FastArray<uint8_t> ins_quals; // insert质量分数, BI (大部分应该都没有)
|
||||||
int firstRefPosOfElement = start_pos; // inclusive
|
FastArray<uint8_t> del_quals; // delete质量分数, BD (大部分应该都没有)
|
||||||
int lastReadPosOfElement = 0; // exclusive
|
|
||||||
int lastRefPosOfElement = start_pos; // exclusive
|
FastArray<Cigar> cigars;
|
||||||
// advance forward through all the cigar elements until we bracket the reference coordinate
|
int64_t& softStart() { return start_pos; }
|
||||||
for (auto& cigar : cigars) {
|
int64_t& softEnd() { return end_pos; }
|
||||||
firstReadPosOfElement = lastReadPosOfElement;
|
|
||||||
firstRefPosOfElement = lastRefPosOfElement;
|
// functions
|
||||||
lastReadPosOfElement += Cigar::ConsumeReadBases(cigar.op) ? cigar.len : 0;
|
// Map a reference coordinate to the corresponding index within this read.
// Walks the cigar elements, tracking the [first, last) read/reference span
// covered by each element, and returns the read index plus the cigar op the
// coordinate falls in.  When refPos lies before start_pos or past the last
// element, rc is returned unmodified (presumably its default readIdx is
// READ_INDEX_NOT_FOUND -- TODO confirm ReadIdxCigar's default).
//
// Fix: the reference-position cursors were declared `int`, silently
// narrowing the 64-bit start_pos; they are now int64_t.
ReadIdxCigar getReadIndexForReferenceCoordinate(int64_t refPos) {
    ReadIdxCigar rc;
    if (refPos < start_pos)
        return rc;
    int firstReadPosOfElement = 0;            // inclusive, read-space
    int lastReadPosOfElement = 0;             // exclusive, read-space
    int64_t firstRefPosOfElement = start_pos; // inclusive, reference-space
    int64_t lastRefPosOfElement = start_pos;  // exclusive, reference-space
    // advance forward through all the cigar elements until we bracket the
    // reference coordinate
    for (auto& cigar : cigars) {
        firstReadPosOfElement = lastReadPosOfElement;
        firstRefPosOfElement = lastRefPosOfElement;
        lastReadPosOfElement += Cigar::ConsumeReadBases(cigar.op) ? cigar.len : 0;
        lastRefPosOfElement += Cigar::ConsumeRefBases(cigar.op) || cigar.op == 'S' ? cigar.len : 0;
        if (firstRefPosOfElement <= refPos && refPos < lastRefPosOfElement) { // refCoord falls within this cigar element
            // The offset within the element is at most the element length,
            // so the narrowing back to int is safe here.
            int readPosAtRefCoord = firstReadPosOfElement + (Cigar::ConsumeReadBases(cigar.op) ? (int)(refPos - firstRefPosOfElement) : 0);
            rc.cigarOp = cigar.op;
            rc.readIdx = readPosAtRefCoord;
            return rc;
        }
    }
    return rc;
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
/*
|
||||||
|
Description: 二项式分布相关函数
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/26
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
#include <float.h>
|
||||||
|
#include <limits>
|
||||||
|
#include "saddle_pe.h"
|
||||||
|
|
||||||
|
struct BinomialDistribution {
|
||||||
|
/** The number of trials. */
|
||||||
|
int numberOfTrials;
|
||||||
|
/** The probability of success. */
|
||||||
|
double probabilityOfSuccess;
|
||||||
|
|
||||||
|
BinomialDistribution(int trials, double p) {
|
||||||
|
probabilityOfSuccess = p;
|
||||||
|
numberOfTrials = trials;
|
||||||
|
}
|
||||||
|
|
||||||
|
double logProbability(int x) {
|
||||||
|
if (numberOfTrials == 0) {
|
||||||
|
return (x == 0) ? 0. : -std::numeric_limits<double>::infinity();
|
||||||
|
}
|
||||||
|
double ret;
|
||||||
|
if (x < 0 || x > numberOfTrials) {
|
||||||
|
ret = -std::numeric_limits<double>::infinity();
|
||||||
|
} else {
|
||||||
|
ret = SaddlePointExpansion::logBinomialProbability(x, numberOfTrials, probabilityOfSuccess, 1.0 - probabilityOfSuccess);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,162 @@
|
||||||
|
/*
|
||||||
|
Description: ContinuedFraction
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/25
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "math_const.h"
|
||||||
|
#include "precision.h"
|
||||||
|
|
||||||
|
//using ContinuedFractionGetFunc = std::function<double(int, double)>();
|
||||||
|
// using ContinuedFractionGetFunc = double (*)(int n, double x);
|
||||||
|
|
||||||
|
struct ContinuedFraction {
|
||||||
|
/** Maximum allowed numerical error. */
|
||||||
|
static constexpr double DEFAULT_EPSILON = 10e-9;
|
||||||
|
|
||||||
|
// ContinuedFraction(ContinuedFractionGetFunc _getA, ContinuedFractionGetFunc _getB) {
|
||||||
|
// getA = _getA;
|
||||||
|
// getB = _getB;
|
||||||
|
// }
|
||||||
|
/**
|
||||||
|
* Access the n-th a coefficient of the continued fraction. Since a can be
|
||||||
|
* a function of the evaluation point, x, that is passed in as well.
|
||||||
|
* @param n the coefficient index to retrieve.
|
||||||
|
* @param x the evaluation point.
|
||||||
|
* @return the n-th a coefficient.
|
||||||
|
*/
|
||||||
|
//ContinuedFractionGetFunc getA;
|
||||||
|
virtual double getA(int n, double x) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Access the n-th b coefficient of the continued fraction. Since b can be
|
||||||
|
* a function of the evaluation point, x, that is passed in as well.
|
||||||
|
* @param n the coefficient index to retrieve.
|
||||||
|
* @param x the evaluation point.
|
||||||
|
* @return the n-th b coefficient.
|
||||||
|
*/
|
||||||
|
//ContinuedFractionGetFunc getB;
|
||||||
|
virtual double getB(int n, double x) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluates the continued fraction at the value x.
|
||||||
|
* @param x the evaluation point.
|
||||||
|
* @return the value of the continued fraction evaluated at x.
|
||||||
|
* @throws ConvergenceException if the algorithm fails to converge.
|
||||||
|
*/
|
||||||
|
double evaluate(double x) { return evaluate(x, DEFAULT_EPSILON, MathConst::INT_MAX_VALUE); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluates the continued fraction at the value x.
|
||||||
|
* @param x the evaluation point.
|
||||||
|
* @param epsilon maximum error allowed.
|
||||||
|
* @return the value of the continued fraction evaluated at x.
|
||||||
|
* @throws ConvergenceException if the algorithm fails to converge.
|
||||||
|
*/
|
||||||
|
double evaluate(double x, double epsilon) { return evaluate(x, epsilon, MathConst::INT_MAX_VALUE); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluates the continued fraction at the value x.
|
||||||
|
* @param x the evaluation point.
|
||||||
|
* @param maxIterations maximum number of convergents
|
||||||
|
* @return the value of the continued fraction evaluated at x.
|
||||||
|
* @throws ConvergenceException if the algorithm fails to converge.
|
||||||
|
* @throws MaxCountExceededException if maximal number of iterations is reached
|
||||||
|
*/
|
||||||
|
double evaluate(double x, int maxIterations) {
|
||||||
|
return evaluate(x, DEFAULT_EPSILON, maxIterations);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluates the continued fraction at the value x.
|
||||||
|
* <p>
|
||||||
|
* The implementation of this method is based on the modified Lentz algorithm as described
|
||||||
|
* on page 18 ff. in:
|
||||||
|
* <ul>
|
||||||
|
* <li>
|
||||||
|
* I. J. Thompson, A. R. Barnett. "Coulomb and Bessel Functions of Complex Arguments and Order."
|
||||||
|
* <a target="_blank" href="http://www.fresco.org.uk/papers/Thompson-JCP64p490.pdf">
|
||||||
|
* http://www.fresco.org.uk/papers/Thompson-JCP64p490.pdf</a>
|
||||||
|
* </li>
|
||||||
|
* </ul>
|
||||||
|
* <b>Note:</b> the implementation uses the terms a<sub>i</sub> and b<sub>i</sub> as defined in
|
||||||
|
* <a href="http://mathworld.wolfram.com/ContinuedFraction.html">Continued Fraction @ MathWorld</a>.
|
||||||
|
* </p>
|
||||||
|
*
|
||||||
|
* @param x the evaluation point.
|
||||||
|
* @param epsilon maximum error allowed.
|
||||||
|
* @param maxIterations maximum number of convergents
|
||||||
|
* @return the value of the continued fraction evaluated at x.
|
||||||
|
* @throws ConvergenceException if the algorithm fails to converge.
|
||||||
|
* @throws MaxCountExceededException if maximal number of iterations is reached
|
||||||
|
*/
|
||||||
|
double evaluate(double x, double epsilon, int maxIterations) {
|
||||||
|
const double small = 1e-50;
|
||||||
|
double hPrev = getA(0, x);
|
||||||
|
|
||||||
|
// use the value of small as epsilon criteria for zero checks
|
||||||
|
if (Precision::equals(hPrev, 0.0, small)) {
|
||||||
|
hPrev = small;
|
||||||
|
}
|
||||||
|
|
||||||
|
int n = 1;
|
||||||
|
double dPrev = 0.0;
|
||||||
|
double cPrev = hPrev;
|
||||||
|
double hN = hPrev;
|
||||||
|
|
||||||
|
while (n < maxIterations) {
|
||||||
|
const double a = getA(n, x);
|
||||||
|
const double b = getB(n, x);
|
||||||
|
|
||||||
|
double dN = a + b * dPrev;
|
||||||
|
if (Precision::equals(dN, 0.0, small)) {
|
||||||
|
dN = small;
|
||||||
|
}
|
||||||
|
double cN = a + b / cPrev;
|
||||||
|
if (Precision::equals(cN, 0.0, small)) {
|
||||||
|
cN = small;
|
||||||
|
}
|
||||||
|
|
||||||
|
dN = 1 / dN;
|
||||||
|
const double deltaN = cN * dN;
|
||||||
|
hN = hPrev * deltaN;
|
||||||
|
|
||||||
|
if (std::isinf(hN)) {
|
||||||
|
std::cerr << "ConvergenceException INFINITY" << std::endl;
|
||||||
|
exit(1);
|
||||||
|
// throw new ConvergenceException(LocalizedFormats.CONTINUED_FRACTION_INFINITY_DIVERGENCE, x);
|
||||||
|
}
|
||||||
|
if (std::isnan(hN)) {
|
||||||
|
std::cerr << "ConvergenceException NAN" << std::endl;
|
||||||
|
exit(1);
|
||||||
|
// throw new ConvergenceException(LocalizedFormats.CONTINUED_FRACTION_NAN_DIVERGENCE, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (std::abs(deltaN - 1.0) < epsilon) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
dPrev = dN;
|
||||||
|
cPrev = cN;
|
||||||
|
hPrev = hN;
|
||||||
|
n++;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n >= maxIterations) {
|
||||||
|
std::cerr << "MaxCountExceededException max iterations" << std::endl;
|
||||||
|
exit(1);
|
||||||
|
// throw new MaxCountExceededException(LocalizedFormats.NON_CONVERGENT_CONTINUED_FRACTION, maxIterations, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
return hN;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
#include "gamma.h"
|
||||||
|
|
||||||
|
// Lanczos coefficients for g = LANCZOS_G (607/128), used by the Lanczos
// approximation in Gamma::logGamma.  Values match Apache Commons Math's
// Gamma table (NSWC-derived) -- do not reformat or round.
const double Gamma::LANCZOS[] = {
    0.99999999999999709182, 57.156235665862923517, -59.597960355475491248, 14.136097974741747174, -0.49191381609762019978,
    .33994649984811888699e-4, .46523628927048575665e-4, -.98374475304879564677e-4, .15808870322491248884e-3, -.21026444172410488319e-3,
    .21743961811521264320e-3, -.16431810653676389022e-3, .84418223983852743293e-4, -.26190838401581408670e-4, .36899182659531622704e-5,
};

// Number of Lanczos coefficients above.
const int Gamma::LANCZOS_LEN = sizeof(LANCZOS) / sizeof(LANCZOS[0]);

// Precomputed 0.5 * ln(2*pi), reused by Gamma::logGamma.
const double Gamma::HALF_LOG_2_PI = 0.5 * std::log(TWO_PI);
|
||||||
|
|
@ -0,0 +1,644 @@
|
||||||
|
/*
|
||||||
|
Description: Gamma相关的工具函数
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/25
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <limits>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <limits.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include "math_const.h"
|
||||||
|
#include "continued_frac.h"
|
||||||
|
|
||||||
|
struct Gamma {
|
||||||
|
/**
|
||||||
|
* <a href="http://en.wikipedia.org/wiki/Euler-Mascheroni_constant">Euler-Mascheroni constant</a>
|
||||||
|
* @since 2.0
|
||||||
|
*/
|
||||||
|
static constexpr double GAMMA = 0.577215664901532860606512090082;
|
||||||
|
|
||||||
|
static constexpr double TWO_PI = 2 * (105414357.0 / 33554432.0 + 1.984187159361080883e-9);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The value of the {@code g} constant in the Lanczos approximation, see
|
||||||
|
* {@link #lanczos(double)}.
|
||||||
|
* @since 3.1
|
||||||
|
*/
|
||||||
|
static constexpr double LANCZOS_G = 607.0 / 128.0;
|
||||||
|
|
||||||
|
/** Maximum allowed numerical error. */
|
||||||
|
static constexpr double DEFAULT_EPSILON = 10e-15;
|
||||||
|
|
||||||
|
/** Lanczos coefficients */
|
||||||
|
static const double LANCZOS[];
|
||||||
|
static const int LANCZOS_LEN;
|
||||||
|
|
||||||
|
/** Avoid repeated computation of log of 2 PI in logGamma */
|
||||||
|
|
||||||
|
static const double HALF_LOG_2_PI;
|
||||||
|
|
||||||
|
/** The constant value of √(2π). */
|
||||||
|
|
||||||
|
static constexpr double SQRT_TWO_PI = 2.506628274631000502;
|
||||||
|
|
||||||
|
// limits for switching algorithm in digamma
|
||||||
|
/** C limit. */
|
||||||
|
|
||||||
|
static constexpr double C_LIMIT = 49;
|
||||||
|
|
||||||
|
/** S limit. */
|
||||||
|
|
||||||
|
static constexpr double S_LIMIT = 1e-5;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Constants for the computation of double invGamma1pm1(double).
|
||||||
|
* Copied from DGAM1 in the NSWC library.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/** The constant {@code A0} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_A0 = .611609510448141581788E-08;
|
||||||
|
|
||||||
|
/** The constant {@code A1} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_A1 = .624730830116465516210E-08;
|
||||||
|
|
||||||
|
/** The constant {@code B1} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_B1 = .203610414066806987300E+00;
|
||||||
|
|
||||||
|
/** The constant {@code B2} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_B2 = .266205348428949217746E-01;
|
||||||
|
|
||||||
|
/** The constant {@code B3} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_B3 = .493944979382446875238E-03;
|
||||||
|
|
||||||
|
/** The constant {@code B4} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_B4 = -.851419432440314906588E-05;
|
||||||
|
|
||||||
|
/** The constant {@code B5} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_B5 = -.643045481779353022248E-05;
|
||||||
|
|
||||||
|
/** The constant {@code B6} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_B6 = .992641840672773722196E-06;
|
||||||
|
|
||||||
|
/** The constant {@code B7} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_B7 = -.607761895722825260739E-07;
|
||||||
|
|
||||||
|
/** The constant {@code B8} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_B8 = .195755836614639731882E-09;
|
||||||
|
|
||||||
|
/** The constant {@code P0} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_P0 = .6116095104481415817861E-08;
|
||||||
|
|
||||||
|
/** The constant {@code P1} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_P1 = .6871674113067198736152E-08;
|
||||||
|
|
||||||
|
/** The constant {@code P2} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_P2 = .6820161668496170657918E-09;
|
||||||
|
|
||||||
|
/** The constant {@code P3} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_P3 = .4686843322948848031080E-10;
|
||||||
|
|
||||||
|
/** The constant {@code P4} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_P4 = .1572833027710446286995E-11;
|
||||||
|
|
||||||
|
/** The constant {@code P5} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_P5 = -.1249441572276366213222E-12;
|
||||||
|
|
||||||
|
/** The constant {@code P6} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_P6 = .4343529937408594255178E-14;
|
||||||
|
|
||||||
|
/** The constant {@code Q1} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_Q1 = .3056961078365221025009E+00;
|
||||||
|
|
||||||
|
/** The constant {@code Q2} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_Q2 = .5464213086042296536016E-01;
|
||||||
|
|
||||||
|
/** The constant {@code Q3} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_Q3 = .4956830093825887312020E-02;
|
||||||
|
|
||||||
|
/** The constant {@code Q4} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_Q4 = .2692369466186361192876E-03;
|
||||||
|
|
||||||
|
/** The constant {@code C} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C = -.422784335098467139393487909917598E+00;
|
||||||
|
|
||||||
|
/** The constant {@code C0} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C0 = .577215664901532860606512090082402E+00;
|
||||||
|
|
||||||
|
/** The constant {@code C1} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C1 = -.655878071520253881077019515145390E+00;
|
||||||
|
|
||||||
|
/** The constant {@code C2} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C2 = -.420026350340952355290039348754298E-01;
|
||||||
|
|
||||||
|
/** The constant {@code C3} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C3 = .166538611382291489501700795102105E+00;
|
||||||
|
|
||||||
|
/** The constant {@code C4} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C4 = -.421977345555443367482083012891874E-01;
|
||||||
|
|
||||||
|
/** The constant {@code C5} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C5 = -.962197152787697356211492167234820E-02;
|
||||||
|
|
||||||
|
/** The constant {@code C6} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C6 = .721894324666309954239501034044657E-02;
|
||||||
|
|
||||||
|
/** The constant {@code C7} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C7 = -.116516759185906511211397108401839E-02;
|
||||||
|
|
||||||
|
/** The constant {@code C8} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C8 = -.215241674114950972815729963053648E-03;
|
||||||
|
|
||||||
|
/** The constant {@code C9} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C9 = .128050282388116186153198626328164E-03;
|
||||||
|
|
||||||
|
/** The constant {@code C10} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C10 = -.201348547807882386556893914210218E-04;
|
||||||
|
|
||||||
|
/** The constant {@code C11} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C11 = -.125049348214267065734535947383309E-05;
|
||||||
|
|
||||||
|
/** The constant {@code C12} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C12 = .113302723198169588237412962033074E-05;
|
||||||
|
|
||||||
|
/** The constant {@code C13} defined in {@code DGAM1}. */
|
||||||
|
static constexpr double INV_GAMMA1P_M1_C13 = -.205633841697760710345015413002057E-06;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default constructor. Prohibit instantiation.
|
||||||
|
*/
|
||||||
|
private:
|
||||||
|
Gamma() {}
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
/**
|
||||||
|
* <p>
|
||||||
|
* Returns the value of log Γ(x) for x > 0.
|
||||||
|
* </p>
|
||||||
|
* <p>
|
||||||
|
* For x ≤ 8, the implementation is based on the double precision
|
||||||
|
* implementation in the <em>NSWC Library of Mathematics Subroutines</em>,
|
||||||
|
* {@code DGAMLN}. For x > 8, the implementation is based on
|
||||||
|
* </p>
|
||||||
|
* <ul>
|
||||||
|
* <li><a href="http://mathworld.wolfram.com/GammaFunction.html">Gamma
|
||||||
|
* Function</a>, equation (28).</li>
|
||||||
|
* <li><a href="http://mathworld.wolfram.com/LanczosApproximation.html">
|
||||||
|
* Lanczos Approximation</a>, equations (1) through (5).</li>
|
||||||
|
* <li><a href="http://my.fit.edu/~gabdo/gamma.txt">Paul Godfrey, A note on
|
||||||
|
* the computation of the convergent Lanczos complex Gamma
|
||||||
|
* approximation</a></li>
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param x Argument.
|
||||||
|
* @return the value of {@code log(Gamma(x))}, {@code Double.NaN} if
|
||||||
|
* {@code x <= 0.0}.
|
||||||
|
*/
|
||||||
|
static double logGamma(double x) {
|
||||||
|
double ret;
|
||||||
|
|
||||||
|
if (std::isnan(x) || (x <= 0.0)) {
|
||||||
|
ret = MathConst::DOUBLE_NAN;
|
||||||
|
} else if (x < 0.5) {
|
||||||
|
return logGamma1p(x) - std::log(x);
|
||||||
|
} else if (x <= 2.5) {
|
||||||
|
return logGamma1p((x - 0.5) - 0.5);
|
||||||
|
} else if (x <= 8.0) {
|
||||||
|
const int n = (int)std::floor(x - 1.5);
|
||||||
|
double prod = 1.0;
|
||||||
|
for (int i = 1; i <= n; i++) {
|
||||||
|
prod *= x - i;
|
||||||
|
}
|
||||||
|
return logGamma1p(x - (n + 1)) + std::log(prod);
|
||||||
|
} else {
|
||||||
|
double sum = lanczos(x);
|
||||||
|
double tmp = x + LANCZOS_G + .5;
|
||||||
|
ret = ((x + .5) * std::log(tmp)) - tmp + HALF_LOG_2_PI + std::log(sum / x);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the regularized gamma function P(a, x).
|
||||||
|
*
|
||||||
|
* @param a Parameter.
|
||||||
|
* @param x Value.
|
||||||
|
* @return the regularized gamma function P(a, x).
|
||||||
|
* @throws MaxCountExceededException if the algorithm fails to converge.
|
||||||
|
*/
|
||||||
|
    // Convenience overload: default tolerance, effectively unbounded iterations.
    static double regularizedGammaP(double a, double x) { return regularizedGammaP(a, x, DEFAULT_EPSILON, INT_MAX); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the regularized gamma function P(a, x).
|
||||||
|
*
|
||||||
|
* The implementation of this method is based on:
|
||||||
|
* <ul>
|
||||||
|
* <li>
|
||||||
|
* <a href="http://mathworld.wolfram.com/RegularizedGammaFunction.html">
|
||||||
|
* Regularized Gamma Function</a>, equation (1)
|
||||||
|
* </li>
|
||||||
|
* <li>
|
||||||
|
* <a href="http://mathworld.wolfram.com/IncompleteGammaFunction.html">
|
||||||
|
* Incomplete Gamma Function</a>, equation (4).
|
||||||
|
* </li>
|
||||||
|
* <li>
|
||||||
|
* <a href="http://mathworld.wolfram.com/ConfluentHypergeometricFunctionoftheFirstKind.html">
|
||||||
|
* Confluent Hypergeometric Function of the First Kind</a>, equation (1).
|
||||||
|
* </li>
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param a the a parameter.
|
||||||
|
* @param x the value.
|
||||||
|
* @param epsilon When the absolute value of the nth item in the
|
||||||
|
* series is less than epsilon the approximation ceases to calculate
|
||||||
|
* further elements in the series.
|
||||||
|
* @param maxIterations Maximum number of "iterations" to complete.
|
||||||
|
* @return the regularized gamma function P(a, x)
|
||||||
|
* @throws MaxCountExceededException if the algorithm fails to converge.
|
||||||
|
*/
|
||||||
|
    // Regularized gamma function P(a, x) -- see the javadoc above for the
    // series used.  NaN for invalid arguments; terminates the process via
    // exit(1) if the series fails to converge within maxIterations.
    static double regularizedGammaP(double a, double x, double epsilon, int maxIterations) {
        double ret;

        if (std::isnan(a) || std::isnan(x) || (a <= 0.0) || (x < 0.0)) {
            ret = std::numeric_limits<double>::quiet_NaN();
        } else if (x == 0.0) {
            ret = 0.0;
        } else if (x >= a + 1) {
            // use regularizedGammaQ because it should converge faster in this
            // case.
            ret = 1.0 - regularizedGammaQ(a, x, epsilon, maxIterations);
        } else {
            // calculate series
            double n = 0.0; // current element index
            double an = 1.0 / a; // n-th element in the series
            double sum = an; // partial sum
            // Accumulate terms until the relative contribution drops below
            // epsilon, the iteration cap is hit, or the sum overflows to inf.
            while (std::abs(an / sum) > epsilon && n < maxIterations && sum < std::numeric_limits<double>::infinity()) {
                // compute next element in the series
                n += 1.0;
                an *= x / (a + n);

                // update partial sum
                sum += an;
            }
            if (n >= maxIterations) {
                std::cerr << "Reach max iterations!" << std::endl;
                exit(1);
            } else if (std::isinf(sum)) {
                // Overflowed partial sum: P saturates at 1.
                ret = 1.0;
            } else {
                // Scale the series by x^a * e^-x / Gamma(a), in log space.
                ret = std::exp(-x + (a * std::log(x)) - logGamma(a)) * sum;
            }
        }
        return ret;
    }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the regularized gamma function Q(a, x) = 1 - P(a, x).
|
||||||
|
*
|
||||||
|
* @param a the a parameter.
|
||||||
|
* @param x the value.
|
||||||
|
* @return the regularized gamma function Q(a, x)
|
||||||
|
* @throws MaxCountExceededException if the algorithm fails to converge.
|
||||||
|
*/
|
||||||
|
static double regularizedGammaQ(double a, double x) { return regularizedGammaQ(a, x, DEFAULT_EPSILON, INT_MAX); }
|
||||||
|
|
||||||
|
// Continued-fraction coefficients for the regularized gamma function Q(a, x)
// (formula 06.08.10.0003 on functions.wolfram.com), consumed by
// ContinuedFraction::evaluate in regularizedGammaQ.
// NOTE(review): whether getA/getB override virtuals of ContinuedFraction
// depends on the base class declared elsewhere — confirm they match its
// signatures exactly.
struct ContinuedFractionInGamma : ContinuedFraction {
    double a;  // shape parameter of the gamma function being evaluated
    ContinuedFractionInGamma(double _a) : a(_a) {}
    // a_n coefficient of the continued fraction.
    double getA(int n, double x) { return ((2.0 * n) + 1.0) - a + x; };
    // b_n coefficient; x is unused by design for this fraction.
    double getB(int n, double x) { return n * (a - n); };
};
|
||||||
|
/**
 * Regularized (upper) incomplete gamma function Q(a, x) = 1 - P(a, x),
 * ported from Apache Commons Math. For x >= a + 1 it evaluates the
 * continued-fraction representation (formula 06.08.10.0003 on
 * functions.wolfram.com); otherwise it delegates to 1 - P(a, x),
 * which converges faster there.
 *
 * @param a             shape parameter (must be > 0).
 * @param x             evaluation point (must be >= 0).
 * @param epsilon       convergence tolerance for the series / fraction.
 * @param maxIterations maximum number of iterations to attempt.
 * @return Q(a, x); NaN for invalid arguments.
 */
static double regularizedGammaQ(const double a, double x, double epsilon, int maxIterations) {
    double ret;

    if (std::isnan(a) || std::isnan(x) || (a <= 0.0) || (x < 0.0)) {
        // NOTE(review): regularizedGammaP uses std::numeric_limits<double>::quiet_NaN()
        // for the same case; both are quiet NaNs, but consider using one spelling.
        ret = MathConst::DOUBLE_NAN;
    } else if (x == 0.0) {
        ret = 1.0;
    } else if (x < a + 1.0) {
        // use regularizedGammaP because it should converge faster in this
        // case.
        ret = 1.0 - regularizedGammaP(a, x, epsilon, maxIterations);
    } else {
        // create continued fraction
        ContinuedFractionInGamma cf(a);
        ret = 1.0 / cf.evaluate(x, epsilon, maxIterations);
        // Scale the fraction by the gamma prefactor exp(-x + a*log(x) - logGamma(a)).
        ret = std::exp(-x + (a * std::log(x)) - logGamma(a)) * ret;
    }

    return ret;
}
|
||||||
|
|
||||||
|
/**
 * Computes the digamma function psi(x) (derivative of logGamma), an
 * independently written implementation of Jose Bernardo, Algorithm AS 103,
 * Applied Statistics, 1976. Accurate to about 1e-8 (absolute for
 * x >= 1e-5, relative for x > 0).
 *
 * Performance for large negative x is O(|x|) because each recursion step
 * advances x by 1 via the reflection digamma(x) = digamma(x+1) - 1/x.
 *
 * @param x Argument.
 * @return digamma(x) to within 1e-8 relative or absolute error, whichever
 *         is smaller.
 */
static double digamma(double x) {
    if (x > 0 && x <= S_LIMIT) {
        // use method 5 from Bernardo AS103
        // accurate to O(x)
        return -GAMMA - 1 / x;
    }

    if (x >= C_LIMIT) {
        // use method 4 (accurate to O(1/x^8)
        double inv = 1 / (x * x);
        //            1       1        1         1
        // log(x) -  --- - ------ + ------- - -------
        //           2 x   12 x^2   120 x^4   252 x^6
        return std::log(x) - 0.5 / x - inv * ((1.0 / 12) + inv * (1.0 / 120 - inv / 252));
    }

    // Shift the argument up by 1 until one of the closed forms applies.
    return digamma(x + 1) - 1 / x;
}
|
||||||
|
|
||||||
|
/**
 * Computes the trigamma function psi'(x), derived by differentiating the
 * digamma implementation above (same recurrence structure: shift x up by 1
 * until an asymptotic/closed form applies).
 *
 * @param x Argument.
 * @return trigamma(x) to within 1e-8 relative or absolute error, whichever
 *         is smaller.
 */
static double trigamma(double x) {
    if (x > 0 && x <= S_LIMIT) {
        // Small-argument limit: trigamma(x) ~ 1/x^2.
        return 1 / (x * x);
    }

    if (x >= C_LIMIT) {
        double inv = 1 / (x * x);
        // 1    1      1       1       1
        // - + ---- + ---- - ----- + -----
        // x      2      3       5       7
        //     2 x    6 x    30 x    42 x
        return 1 / x + inv / 2 + inv / x * (1.0 / 6 - inv * (1.0 / 30 + inv / 42));
    }

    // Recurrence: trigamma(x) = trigamma(x + 1) + 1/x^2.
    return trigamma(x + 1) + 1 / (x * x);
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* <p>
|
||||||
|
* Returns the Lanczos approximation used to compute the gamma function.
|
||||||
|
* The Lanczos approximation is related to the Gamma function by the
|
||||||
|
* following equation
|
||||||
|
* <center>
|
||||||
|
* {@code gamma(x) = sqrt(2 * pi) / x * (x + g + 0.5) ^ (x + 0.5)
|
||||||
|
* * exp(-x - g - 0.5) * lanczos(x)},
|
||||||
|
* </center>
|
||||||
|
* where {@code g} is the Lanczos constant.
|
||||||
|
* </p>
|
||||||
|
*
|
||||||
|
* @param x Argument.
|
||||||
|
* @return The Lanczos approximation.
|
||||||
|
* @see <a href="http://mathworld.wolfram.com/LanczosApproximation.html">Lanczos Approximation</a>
|
||||||
|
* equations (1) through (5), and Paul Godfrey's
|
||||||
|
* <a href="http://my.fit.edu/~gabdo/gamma.txt">Note on the computation
|
||||||
|
* of the convergent Lanczos complex Gamma approximation</a>
|
||||||
|
* @since 3.1
|
||||||
|
*/
|
||||||
|
static double lanczos(const double x) {
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int i = LANCZOS_LEN - 1; i > 0; --i) {
|
||||||
|
sum += LANCZOS[i] / (x + i);
|
||||||
|
}
|
||||||
|
return sum + LANCZOS[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Returns 1 / Gamma(1 + x) - 1 for -0.5 <= x <= 1.5, based on the double
 * precision routine DGAM1 from the NSWC Library of Mathematics Subroutines.
 * Computing the "- 1" directly (instead of subtracting afterwards) keeps
 * full precision when the result is near zero.
 *
 * The Java original throws for out-of-range x; this port asserts instead.
 *
 * @param x Argument, must be in [-0.5, 1.5].
 * @return The value of 1.0 / Gamma(1.0 + x) - 1.0.
 */
static double invGamma1pm1(const double x) {
    assert(x >= -0.5 && x <= 1.5);

    double ret;
    // Reduce to t in [-0.5, 0.5]: t = x for x <= 0.5, else t = x - 1
    // (written as (x - 0.5) - 0.5 to match the reference bit-for-bit).
    double t = x <= 0.5 ? x : (x - 0.5) - 0.5;
    if (t < 0.0) {
        // Rational approximation a/b, then a Horner polynomial in c.
        double a = INV_GAMMA1P_M1_A0 + t * INV_GAMMA1P_M1_A1;
        double b = INV_GAMMA1P_M1_B8;
        b = INV_GAMMA1P_M1_B7 + t * b;
        b = INV_GAMMA1P_M1_B6 + t * b;
        b = INV_GAMMA1P_M1_B5 + t * b;
        b = INV_GAMMA1P_M1_B4 + t * b;
        b = INV_GAMMA1P_M1_B3 + t * b;
        b = INV_GAMMA1P_M1_B2 + t * b;
        b = INV_GAMMA1P_M1_B1 + t * b;
        b = 1.0 + t * b;

        double c = INV_GAMMA1P_M1_C13 + t * (a / b);
        c = INV_GAMMA1P_M1_C12 + t * c;
        c = INV_GAMMA1P_M1_C11 + t * c;
        c = INV_GAMMA1P_M1_C10 + t * c;
        c = INV_GAMMA1P_M1_C9 + t * c;
        c = INV_GAMMA1P_M1_C8 + t * c;
        c = INV_GAMMA1P_M1_C7 + t * c;
        c = INV_GAMMA1P_M1_C6 + t * c;
        c = INV_GAMMA1P_M1_C5 + t * c;
        c = INV_GAMMA1P_M1_C4 + t * c;
        c = INV_GAMMA1P_M1_C3 + t * c;
        c = INV_GAMMA1P_M1_C2 + t * c;
        c = INV_GAMMA1P_M1_C1 + t * c;
        c = INV_GAMMA1P_M1_C + t * c;
        if (x > 0.5) {
            ret = t * c / x;
        } else {
            ret = x * ((c + 0.5) + 0.5);
        }
    } else {
        // Rational approximation p/q, then a Horner polynomial in c.
        double p = INV_GAMMA1P_M1_P6;
        p = INV_GAMMA1P_M1_P5 + t * p;
        p = INV_GAMMA1P_M1_P4 + t * p;
        p = INV_GAMMA1P_M1_P3 + t * p;
        p = INV_GAMMA1P_M1_P2 + t * p;
        p = INV_GAMMA1P_M1_P1 + t * p;
        p = INV_GAMMA1P_M1_P0 + t * p;

        double q = INV_GAMMA1P_M1_Q4;
        q = INV_GAMMA1P_M1_Q3 + t * q;
        q = INV_GAMMA1P_M1_Q2 + t * q;
        q = INV_GAMMA1P_M1_Q1 + t * q;
        q = 1.0 + t * q;

        double c = INV_GAMMA1P_M1_C13 + (p / q) * t;
        c = INV_GAMMA1P_M1_C12 + t * c;
        c = INV_GAMMA1P_M1_C11 + t * c;
        c = INV_GAMMA1P_M1_C10 + t * c;
        c = INV_GAMMA1P_M1_C9 + t * c;
        c = INV_GAMMA1P_M1_C8 + t * c;
        c = INV_GAMMA1P_M1_C7 + t * c;
        c = INV_GAMMA1P_M1_C6 + t * c;
        c = INV_GAMMA1P_M1_C5 + t * c;
        c = INV_GAMMA1P_M1_C4 + t * c;
        c = INV_GAMMA1P_M1_C3 + t * c;
        c = INV_GAMMA1P_M1_C2 + t * c;
        c = INV_GAMMA1P_M1_C1 + t * c;
        c = INV_GAMMA1P_M1_C0 + t * c;

        if (x > 0.5) {
            ret = (t / x) * ((c - 0.5) - 0.5);
        } else {
            ret = x * c;
        }
    }

    return ret;
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the value of log Γ(1 + x) for -0.5 ≤ x ≤ 1.5.
|
||||||
|
* This implementation is based on the double precision implementation in
|
||||||
|
* the <em>NSWC Library of Mathematics Subroutines</em>, {@code DGMLN1}.
|
||||||
|
*
|
||||||
|
* @param x Argument.
|
||||||
|
* @return The value of {@code log(Gamma(1 + x))}.
|
||||||
|
* @throws NumberIsTooSmallException if {@code x < -0.5}.
|
||||||
|
* @throws NumberIsTooLargeException if {@code x > 1.5}.
|
||||||
|
* @since 3.1
|
||||||
|
*/
|
||||||
|
static double logGamma1p(const double x) {
|
||||||
|
// if (x < -0.5) {
|
||||||
|
// throw new NumberIsTooSmallException(x, -0.5, true);
|
||||||
|
// }
|
||||||
|
// if (x > 1.5) {
|
||||||
|
// throw new NumberIsTooLargeException(x, 1.5, true);
|
||||||
|
// }
|
||||||
|
assert(x >= -0.5 && x <= 1.5);
|
||||||
|
return -std::log1p(invGamma1pm1(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Returns Gamma(x), based on the NSWC double precision routine DGAMMA.
 * For |x| <= 20 it reduces the argument to [0.5, 2.5] via the recurrence
 * relation and uses invGamma1pm1; otherwise it uses the Lanczos
 * approximation (with the reflection formula for negative x).
 *
 * @param x Argument.
 * @return the value of Gamma(x); NaN for non-positive integers (poles).
 */
static double gamma(const double x) {
    if ((x == std::rint(x)) && (x <= 0.0)) {
        // Poles of the gamma function.
        return std::numeric_limits<double>::quiet_NaN();
    }

    double ret;
    const double absX = std::abs(x);
    if (absX <= 20.0) {
        if (x >= 1.0) {
            /*
             * From the recurrence relation
             * Gamma(x) = (x - 1) * ... * (x - n) * Gamma(x - n),
             * then
             * Gamma(t) = 1 / [1 + invGamma1pm1(t - 1)],
             * where t = x - n. This means that t must satisfy
             * -0.5 <= t - 1 <= 1.5.
             */
            double prod = 1.0;
            double t = x;
            while (t > 2.5) {
                t -= 1.0;
                prod *= t;
            }
            ret = prod / (1.0 + invGamma1pm1(t - 1.0));
        } else {
            /*
             * From the recurrence relation
             * Gamma(x) = Gamma(x + n + 1) / [x * (x + 1) * ... * (x + n)]
             * then
             * Gamma(x + n + 1) = 1 / [1 + invGamma1pm1(x + n)],
             * which requires -0.5 <= x + n <= 1.5.
             */
            double prod = x;
            double t = x;
            while (t < -0.5) {
                t += 1.0;
                prod *= t;
            }
            ret = 1.0 / (prod * (1.0 + invGamma1pm1(t)));
        }
    } else {
        // Lanczos approximation for |x| > 20.
        const double y = absX + LANCZOS_G + 0.5;
        const double gammaAbs = SQRT_TWO_PI / x * std::pow(y, absX + 0.5) * std::exp(-y) * lanczos(absX);
        if (x > 0.0) {
            ret = gammaAbs;
        } else {
            /*
             * From the reflection formula
             * Gamma(x) * Gamma(1 - x) * sin(pi * x) = pi,
             * and the recurrence relation
             * Gamma(1 - x) = -x * Gamma(-x),
             * it is found
             * Gamma(x) = -pi / [x * sin(pi * x) * Gamma(-x)].
             */
            ret = -MathConst::PI / (x * std::sin(MathConst::PI * x) * gammaAbs);
        }
    }
    return ret;
}
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
/*
|
||||||
|
Description: 数学常用的一些常数,如pi,e等
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/26
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <limits.h>
|
||||||
|
|
||||||
|
// Common mathematical constants (pi, NaN, infinities, ...), mirroring the
// Java constants used by the original GATK/commons-math code.
struct MathConst {
    // Closest double to pi, written as commons-math does: a dyadic
    // rational plus a small correction term, so the value is bit-identical
    // to Java's FastMath.PI.
    static constexpr double PI = 105414357.0 / 33554432.0 + 1.984187159361080883e-9;

    // 2*pi. Defined in terms of PI (instead of duplicating the whole
    // expression): multiplying by 2 is an exact power-of-two scaling, so
    // the value is unchanged.
    static constexpr double TWO_PI = 2.0 * PI;

    // Equivalent of Java's Double.NaN.
    static constexpr double DOUBLE_NAN = std::numeric_limits<double>::quiet_NaN();

    // Equivalent of Java's Double.NEGATIVE_INFINITY.
    static constexpr double DOUBLE_NEGATIVE_INFINITY = -std::numeric_limits<double>::infinity();

    // Equivalent of Java's Double.POSITIVE_INFINITY.
    static constexpr double DOUBLE_POSITIVE_INFINITY = std::numeric_limits<double>::infinity();

    // Equivalent of Java's Integer.MAX_VALUE.
    static constexpr int INT_MAX_VALUE = INT_MAX;
};
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
/*
|
||||||
|
Description: 数学常用的工具函数
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/26
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
// Bit-level reinterpretation helpers mirroring Java's
// Double.doubleToRawLongBits / Float.floatToRawIntBits family.
// memcpy is used for the type-pun because it is the only form with no
// undefined behavior in C++; static_asserts guard the size assumptions.
struct MathFunc {
    // IEEE-754 bit pattern of a double (Java Double.doubleToRawLongBits).
    static int64_t doubleToRawLongBits(double value) {
        static_assert(sizeof(double) == sizeof(int64_t), "double must be 64-bit");
        int64_t result;
        std::memcpy(&result, &value, sizeof(result));
        return result;
    }

    // Double whose IEEE-754 bit pattern is `bits` (Java Double.longBitsToDouble).
    static double longBitsToDouble(int64_t bits) {
        static_assert(sizeof(double) == sizeof(int64_t), "double must be 64-bit");
        double result;
        std::memcpy(&result, &bits, sizeof(result));
        return result;
    }

    // IEEE-754 bit pattern of a float (Java Float.floatToRawIntBits).
    static int floatToRawIntBits(float value) {
        static_assert(sizeof(float) == sizeof(int), "float and int must both be 32-bit");
        int result;
        std::memcpy(&result, &value, sizeof(result));
        return result;
    }

    // Float whose IEEE-754 bit pattern is `bits` (Java Float.intBitsToFloat).
    static float intBitsToFloat(int bits) {
        static_assert(sizeof(float) == sizeof(int), "float and int must both be 32-bit");
        float result;
        std::memcpy(&result, &bits, sizeof(result));
        return result;
    }
};
|
||||||
|
|
@ -0,0 +1,195 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <bit>
|
||||||
|
#include <cfloat>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
#include <limits>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
// 2.1 基础模板实现
|
||||||
|
// Computes the unit-in-the-last-place (ulp) of a floating-point value,
// mimicking Java's Math.ulp for the C++ port.
// NOTE(review): UIntType is `void` for long double, so instantiating the
// bit-level helpers of MathUlp<long double> (alias Ulp128 below) will not
// compile if they are ever called — confirm long double support is intended.
template <typename T>
class MathUlp {
    static_assert(std::is_floating_point_v<T>, "MathUlp only works with floating-point types");

private:
    // Unsigned integer type with the same width as T (float/double only).
    using UIntType = std::conditional_t<std::is_same_v<T, float>, uint32_t, std::conditional_t<std::is_same_v<T, double>, uint64_t, void> >;

    // IEEE-754 layout constants derived from T.
    static constexpr int BITS = sizeof(T) * 8;
    static constexpr int MANTISSA_BITS = std::numeric_limits<T>::digits - 1;
    static constexpr int EXPONENT_BITS = BITS - MANTISSA_BITS - 1;
    static constexpr UIntType SIGN_BIT = UIntType(1) << (BITS - 1);
    static constexpr UIntType EXPONENT_MASK = ((UIntType(1) << EXPONENT_BITS) - 1) << MANTISSA_BITS;
    static constexpr UIntType MANTISSA_MASK = (UIntType(1) << MANTISSA_BITS) - 1;
    static constexpr int EXPONENT_BIAS = (1 << (EXPONENT_BITS - 1)) - 1;
    static constexpr int MAX_EXPONENT = (1 << EXPONENT_BITS) - 1;

    // Raw bit pattern of a floating-point value (memcpy avoids UB).
    static UIntType get_bits(T value) {
        UIntType bits;
        std::memcpy(&bits, &value, sizeof(T));
        return bits;
    }

    // Floating-point value with the given raw bit pattern.
    static T from_bits(UIntType bits) {
        T value;
        std::memcpy(&value, &bits, sizeof(T));
        return value;
    }

    // Unbiased exponent extracted from the bit pattern
    // (returns -EXPONENT_BIAS for zeros and subnormals).
    static int get_exponent(T value) {
        UIntType bits = get_bits(value);
        UIntType exp_bits = (bits >> MANTISSA_BITS) & ((UIntType(1) << EXPONENT_BITS) - 1);
        return static_cast<int>(exp_bits) - EXPONENT_BIAS;
    }

    // Mantissa (fraction) field of the bit pattern.
    static UIntType get_mantissa(T value) {
        UIntType bits = get_bits(value);
        return bits & MANTISSA_MASK;
    }

    // True when the value is a normalized finite non-zero number.
    static bool is_normal(T value) {
        if (value == 0 || std::isinf(value) || std::isnan(value))
            return false;
        int exp = get_exponent(value);
        return exp > -EXPONENT_BIAS && exp <= EXPONENT_BIAS;
    }

    // True when the value is subnormal (denormalized): zero exponent
    // field with a non-zero mantissa.
    static bool is_subnormal(T value) {
        if (value == 0 || std::isinf(value) || std::isnan(value))
            return false;
        int exp = get_exponent(value);
        return exp == -EXPONENT_BIAS && get_mantissa(value) != 0;
    }

public:
    // Main entry: ulp(x), the spacing between x and the next representable
    // value of the same magnitude.
    static T ulp(T x) {
        // Special values propagate.
        if (std::isnan(x)) {
            return std::numeric_limits<T>::quiet_NaN();
        }

        if (std::isinf(x)) {
            return std::numeric_limits<T>::infinity();
        }

        // Zero (either sign).
        if (x == 0.0 || x == -0.0) {
            // NOTE(review): returns the smallest *normalized* value; Java's
            // Math.ulp(0) returns the smallest subnormal (Double.MIN_VALUE =
            // denorm_min). Confirm which semantics the BQSR port needs.
            return std::numeric_limits<T>::min();
        }

        T abs_x = std::fabs(x);

        // For subnormals the spacing is constant: the smallest subnormal.
        if (is_subnormal(abs_x)) {
            return std::numeric_limits<T>::denorm_min();
        }

        // Normalized values: ulp = 2^(exp - mantissa_bits).
        int exp = get_exponent(abs_x);

        // For float:  ulp = 2^(exp - 23)
        // For double: ulp = 2^(exp - 52)
        if constexpr (std::is_same_v<T, float>) {
            // ldexp computes the power of two exactly.
            return std::ldexp(1.0f, exp - 23);
        } else if constexpr (std::is_same_v<T, double>) {
            return std::ldexp(1.0, exp - 52);
        } else if constexpr (std::is_same_v<T, long double>) {
            // long double layout is platform-dependent; use digits from
            // numeric_limits rather than a hard-coded mantissa width.
            constexpr int LD_MANTISSA_BITS = std::numeric_limits<long double>::digits - 1;
            return std::ldexp(1.0L, exp - LD_MANTISSA_BITS);
        }
    }

    // Alternative implementation via nextafter (simpler, possibly slower).
    static T ulp_simple(T x) {
        // Special values propagate.
        if (std::isnan(x)) {
            return std::numeric_limits<T>::quiet_NaN();
        }

        if (std::isinf(x)) {
            return std::numeric_limits<T>::infinity();
        }

        // Zero (either sign).
        if (x == 0.0 || x == -0.0) {
            return std::numeric_limits<T>::min();
        }

        // Distance to the adjacent representable values on each side.
        T next = std::nextafter(std::fabs(x), std::numeric_limits<T>::infinity());
        T prev = std::nextafter(std::fabs(x), -std::numeric_limits<T>::infinity());

        T diff1 = next - std::fabs(x);
        T diff2 = std::fabs(x) - prev;

        // Return the smaller positive difference.
        if (diff1 > 0 && diff2 > 0) {
            return std::min(diff1, diff2);
        } else if (diff1 > 0) {
            return diff1;
        } else {
            return diff2;
        }
    }

    // Debug helper: dump the IEEE-754 fields and both ulp computations.
    // NOTE(review): for long double this falls into the 64-bit branch and
    // memcpy's only sizeof(double) bytes — the printed fields would be wrong;
    // confirm this is only used with float/double.
    static void print_float_info(T value) {
        std::cout << "Value: " << std::setprecision(15) << value << "\n";
        std::cout << "Hex bits: ";

        if constexpr (std::is_same_v<T, float>) {
            uint32_t bits;
            std::memcpy(&bits, &value, sizeof(float));
            std::cout << std::hex << bits << std::dec;
            std::cout << "\nSign: " << ((bits >> 31) & 1);
            std::cout << "\nExponent: " << ((bits >> 23) & 0xFF);
            std::cout << " (bias-adjusted: " << ((bits >> 23) & 0xFF) - 127 << ")";
            std::cout << "\nMantissa: " << (bits & 0x7FFFFF);
        } else {
            uint64_t bits;
            std::memcpy(&bits, &value, sizeof(double));
            std::cout << std::hex << bits << std::dec;
            std::cout << "\nSign: " << ((bits >> 63) & 1);
            std::cout << "\nExponent: " << ((bits >> 52) & 0x7FF);
            std::cout << " (bias-adjusted: " << ((bits >> 52) & 0x7FF) - 1023 << ")";
            std::cout << "\nMantissa: " << (bits & 0xFFFFFFFFFFFFF);
        }

        std::cout << "\nIs normal: " << std::boolalpha << is_normal(value);
        std::cout << "\nIs subnormal: " << is_subnormal(value);
        std::cout << "\nULP (our): " << ulp(value);
        std::cout << "\nULP (simple): " << ulp_simple(value);
        std::cout << "\n---\n";
    }
};
|
||||||
|
|
||||||
|
// 2.2 User-facing convenience wrappers around MathUlp<T>.
template <typename T>
inline T math_ulp(T x) {
    return MathUlp<T>::ulp(x);
}

template <typename T>
inline T math_ulp_simple(T x) {
    return MathUlp<T>::ulp_simple(x);
}

// 2.3 Type aliases for the common floating-point widths.
using Ulp32 = MathUlp<float>;
using Ulp64 = MathUlp<double>;
using Ulp128 = MathUlp<long double>;
|
||||||
|
|
@ -0,0 +1,138 @@
|
||||||
|
/*
|
||||||
|
Description: 数学计算相关的工具函数
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/25
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <climits>
|
||||||
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "binomial_dist.h"
|
||||||
|
#include "math_const.h"
|
||||||
|
#include "math_ulp.h"
|
||||||
|
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
struct MathUtils {
|
||||||
|
|
||||||
|
// 初始化math库里所有需要静态初始化的类静态变量
|
||||||
|
static void StaticInit() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rounds the double to the given number of decimal places.
|
||||||
|
* For example, rounding 3.1415926 to 3 places would give 3.142.
|
||||||
|
* The requirement is that it works exactly as writing a number down with string.format and reading back in.
|
||||||
|
*/
|
||||||
|
static double RoundToNDecimalPlaces(const double in, const int n) {
|
||||||
|
assert(n > 0);
|
||||||
|
const double mult = std::pow(10, n);
|
||||||
|
return std::round((in + Ulp64::ulp(in)) * mult) / mult;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* binomial Probability(int, int, double) with log applied to result
|
||||||
|
*/
|
||||||
|
static double logBinomialProbability(const int n, const int k, const double p) {
|
||||||
|
BinomialDistribution binomial(n, p);
|
||||||
|
return binomial.logProbability(k);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ArrayType>
|
||||||
|
static int maxElementIndex(const ArrayType& array, const int start, const int endIndex) {
|
||||||
|
int maxI = start;
|
||||||
|
for (int i = (start + 1); i < endIndex; i++) {
|
||||||
|
if (array[i] > array[maxI])
|
||||||
|
maxI = i;
|
||||||
|
}
|
||||||
|
return maxI;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int fastRound(double d) { return (d > 0.0) ? (int)(d + 0.5) : (int)(d - 0.5); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute in a numerically correct way the quantity log10(1-x)
|
||||||
|
*
|
||||||
|
* Uses the approximation log10(1-x) = log10(1/x - 1) + log10(x) to avoid very quick underflow
|
||||||
|
* in 1-x when x is very small
|
||||||
|
*
|
||||||
|
* log10(1-x) = log10( x * (1-x) / x )
|
||||||
|
* = log10(x) + log10( (1-x) / x)
|
||||||
|
* = log10(x) + log10(1/x - 1)
|
||||||
|
*
|
||||||
|
* @param x a positive double value between 0.0 and 1.0
|
||||||
|
* @return an estimate of log10(1-x)
|
||||||
|
*/
|
||||||
|
static double log10OneMinusX(const double x) {
|
||||||
|
if (x == 1.0)
|
||||||
|
return MathConst::DOUBLE_NEGATIVE_INFINITY;
|
||||||
|
else if (x == 0.0)
|
||||||
|
return 0.0;
|
||||||
|
else {
|
||||||
|
const double d = std::log10(1 / x - 1) + std::log10(x);
|
||||||
|
return std::isinf(d) || d > 0.0 ? 0.0 : d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static double log10SumLog10(const vector<double>& log10Values, const int start) {
|
||||||
|
return log10SumLog10(log10Values, start, log10Values.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
static double log10SumLog10(const vector<double>& log10Values) { return log10SumLog10(log10Values, 0); }
|
||||||
|
|
||||||
|
static double log10SumLog10(const vector<double>& log10Values, const int start, const int finish) {
|
||||||
|
// Utils.nonNull(log10Values);
|
||||||
|
if (start >= finish) {
|
||||||
|
return MathConst::DOUBLE_NEGATIVE_INFINITY;
|
||||||
|
}
|
||||||
|
const int maxElementIndex = MathUtils::maxElementIndex(log10Values, start, finish);
|
||||||
|
|
||||||
|
|
||||||
|
const double maxValue = log10Values[maxElementIndex];
|
||||||
|
if (maxValue == MathConst::DOUBLE_NEGATIVE_INFINITY) {
|
||||||
|
return maxValue;
|
||||||
|
}
|
||||||
|
double sum = 1.0;
|
||||||
|
for (int i = start; i < finish; i++) {
|
||||||
|
const double curVal = log10Values[i];
|
||||||
|
if (i == maxElementIndex || curVal == MathConst::DOUBLE_NEGATIVE_INFINITY) {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
const double scaled_val = curVal - maxValue;
|
||||||
|
sum += std::pow(10.0, scaled_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (std::isinf(sum) || sum == MathConst::DOUBLE_POSITIVE_INFINITY) {
|
||||||
|
// throw new IllegalArgumentException("log10 p: Values must be non-infinite and non-NAN");
|
||||||
|
}
|
||||||
|
return maxValue + (sum != 1.0 ? std::log10(sum) : 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static double log10SumLog10(const double a, const double b) {
|
||||||
|
return a > b ? a + std::log10(1 + std::pow(10.0, b - a)) : b + std::log10(1 + std::pow(10.0, a - b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Do the log-sum trick for three double values.
|
||||||
|
* @param a
|
||||||
|
* @param b
|
||||||
|
* @param c
|
||||||
|
* @return the sum... perhaps NaN or infinity if it applies.
|
||||||
|
*/
|
||||||
|
static double log10SumLog10(const double a, const double b, const double c) {
|
||||||
|
if (a >= b && a >= c) {
|
||||||
|
return a + std::log10(1 + std::pow(10.0, b - a) + std::pow(10.0, c - a));
|
||||||
|
} else if (b >= c) {
|
||||||
|
return b + std::log10(1 + std::pow(10.0, a - b) + std::pow(10.0, c - b));
|
||||||
|
} else {
|
||||||
|
return c + std::log10(1 + std::pow(10.0, a - c) + std::pow(10.0, b - c));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
/*
|
||||||
|
Description: 正态分布相关函数
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/26
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cmath>
#include <math.h>
|
||||||
|
|
||||||
|
// Normal (Gaussian) distribution, ported from Apache Commons Math
// NormalDistribution; only the pieces BQSR needs (logDensity).
struct NormalDistribution {
    /**
     * Default inverse cumulative probability accuracy.
     * @since 2.1
     */
    static constexpr double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1e-9;
    /** sqrt(2) */
    static constexpr double SQRT2 = 1.414213562373095;

    // Same bit-exact decomposition of pi as commons-math / MathConst.
    static constexpr double PI = 105414357.0 / 33554432.0 + 1.984187159361080883e-9;

    /** Mean of this distribution. */
    double mean;

    /** Standard deviation of this distribution. */
    double standardDeviation;

    /** The value of {@code log(sd) + 0.5*log(2*pi)} stored for faster computation. */
    double logStandardDeviationPlusHalfLog2Pi;

    /** Inverse cumulative probability accuracy. */
    double solverAbsoluteAccuracy;

    /**
     * Builds a normal distribution with the given mean and standard
     * deviation; precomputes the constant part of the log-density.
     *
     * @param mean distribution mean.
     * @param sd   standard deviation (expected > 0; not validated here).
     */
    NormalDistribution(double mean, double sd)
        : mean(mean),
          standardDeviation(sd),
          logStandardDeviationPlusHalfLog2Pi(std::log(sd) + 0.5 * std::log(2 * PI)),
          solverAbsoluteAccuracy(DEFAULT_INVERSE_ABSOLUTE_ACCURACY) {}

    /**
     * Natural log of the probability density at x:
     * -((x - mean)/sd)^2 / 2 - log(sd) - log(2*pi)/2.
     */
    double logDensity(double x) const {
        const double x0 = x - mean;
        const double x1 = x0 / standardDeviation;
        return -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi;
    }
};
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
#include "precision.h"
|
||||||
|
|
||||||
|
// EPSILON = 2^-53 (machine epsilon for round-to-nearest): a double whose
// biased exponent field is 1023 - 53 = 970 and whose mantissa is zero.
const double Precision::EPSILON = MathFunc::longBitsToDouble((EXPONENT_OFFSET - 53l) << 52);

// SAFE_MIN = 2^-1022, the smallest normalized double; 1 / SAFE_MIN does
// not overflow.
const double Precision::SAFE_MIN = MathFunc::longBitsToDouble((EXPONENT_OFFSET - 1022l) << 52);

// Cached raw bit patterns of the signed zeros, used for sign-aware
// floating-point comparisons.
const int64_t Precision::POSITIVE_ZERO_DOUBLE_BITS = MathFunc::doubleToRawLongBits(+0.0);

const int64_t Precision::NEGATIVE_ZERO_DOUBLE_BITS = MathFunc::doubleToRawLongBits(-0.0);

const int Precision::POSITIVE_ZERO_FLOAT_BITS = MathFunc::floatToRawIntBits(+0.0f);

const int Precision::NEGATIVE_ZERO_FLOAT_BITS = MathFunc::floatToRawIntBits(-0.0f);
|
||||||
|
|
@ -0,0 +1,122 @@
|
||||||
|
/*
|
||||||
|
Description: 高精度相关的函数,如比较两个高精度浮点是否相等
|
||||||
|
|
||||||
|
Copyright : All right reserved by ICT
|
||||||
|
|
||||||
|
Author : Zhang Zhonghai
|
||||||
|
Date : 2025/12/26
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
|
||||||
|
#include "math_func.h"
|
||||||
|
|
||||||
|
// ULP-based floating-point comparison utilities (port of Apache Commons Math Precision).
struct Precision {
    /**
     * Largest double-precision floating-point number such that
     * {@code 1 + EPSILON} is numerically equal to 1. This value is an upper
     * bound on the relative error due to rounding real numbers to double
     * precision floating-point numbers.
     *
     * In IEEE 754 arithmetic, this is 2<sup>-53</sup>.
     *
     * @see <a href="http://en.wikipedia.org/wiki/Machine_epsilon">Machine epsilon</a>
     */
    static const double EPSILON;

    /**
     * Safe minimum, such that {@code 1 / SAFE_MIN} does not overflow.
     * In IEEE 754 arithmetic, this is also the smallest normalized
     * number 2<sup>-1022</sup>.
     */
    static const double SAFE_MIN;

    /** Exponent offset in IEEE754 representation. */
    static constexpr int64_t EXPONENT_OFFSET = 1023l;

    /** Sign-bit mask; offset to order signed double numbers lexicographically. */
    static constexpr int64_t SGN_MASK = 0x8000000000000000L;
    /** Sign-bit mask; offset to order signed float numbers lexicographically. */
    static constexpr int SGN_MASK_FLOAT = 0x80000000;
    /** Positive zero. */
    static constexpr double POSITIVE_ZERO = 0.0;
    /** Raw IEEE-754 bits of +0.0 (double); defined in precision.cpp. */
    static const int64_t POSITIVE_ZERO_DOUBLE_BITS;
    /** Raw IEEE-754 bits of -0.0 (double). */
    static const int64_t NEGATIVE_ZERO_DOUBLE_BITS;
    /** Raw IEEE-754 bits of +0.0f (float). */
    static const int POSITIVE_ZERO_FLOAT_BITS;
    /** Raw IEEE-754 bits of -0.0f (float). */
    static const int NEGATIVE_ZERO_FLOAT_BITS;

    /**
     * Returns {@code true} if there is no double value strictly between the
     * arguments (i.e. they are adjacent floating point numbers) or the
     * difference between them is within the range of allowed error (inclusive).
     *
     * @param x First value.
     * @param y Second value.
     * @param eps Amount of allowed absolute error.
     * @return {@code true} if the values are two adjacent floating point
     * numbers or they are within range of each other.
     */
    static bool equals(double x, double y, double eps) { return equals(x, y, 1) || std::abs(y - x) <= eps; }
    /**
     * Returns true if both arguments are equal or within the range of allowed
     * error (inclusive).
     *
     * Two numbers are considered equal if there are {@code (maxUlps - 1)}
     * (or fewer) floating point numbers between them, i.e. two adjacent
     * floating point numbers are considered equal.
     *
     * Adapted from Bruce Dawson,
     * "Comparing Floating Point Numbers, 2012 Edition".
     *
     * @param x first value
     * @param y second value
     * @param maxUlps {@code (maxUlps - 1)} is the number of floating point
     * values between {@code x} and {@code y}.
     * @return {@code true} if there are fewer than {@code maxUlps} floating
     * point values between {@code x} and {@code y}; always false when either
     * argument is NaN.
     */
    static bool equals(const double x, const double y, const int maxUlps) {
        const int64_t xInt = MathFunc::doubleToRawLongBits(x);
        const int64_t yInt = MathFunc::doubleToRawLongBits(y);

        bool isEqual;

        if (((xInt ^ yInt) & SGN_MASK) == 0l) {
            // numbers have the same sign, there is no risk of overflow
            isEqual = std::abs(xInt - yInt) <= maxUlps;
        } else {
            // numbers have opposite signs, take care of overflow:
            // measure each operand's ULP distance to its own signed zero.
            int64_t deltaPlus;
            int64_t deltaMinus;
            if (xInt < yInt) {
                deltaPlus = yInt - POSITIVE_ZERO_DOUBLE_BITS;
                deltaMinus = xInt - NEGATIVE_ZERO_DOUBLE_BITS;
            } else {
                deltaPlus = xInt - POSITIVE_ZERO_DOUBLE_BITS;
                deltaMinus = yInt - NEGATIVE_ZERO_DOUBLE_BITS;
            }

            if (deltaPlus > maxUlps) {
                isEqual = false;
            } else {
                isEqual = deltaMinus <= (maxUlps - deltaPlus);
            }
        }

        // NaN never compares equal to anything, including itself.
        return isEqual && !std::isnan(x) && !std::isnan(y);
    }
};
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
#include "saddle_pe.h"

// 0.5 * log(2*pi), evaluated once at static-initialization time.
const double SaddlePointExpansion::HALF_LOG_2_PI = 0.5 * std::log(TWO_PI);

// Exact Stirling-series error, tabulated at half-integer steps z = 0.0 .. 15.0
// (index i holds the value for z = i/2); see getStirlingError().
const double SaddlePointExpansion::EXACT_STIRLING_ERRORS[31] = {
    0.0, /* 0.0 */
    0.1534264097200273452913848, /* 0.5 */
    0.0810614667953272582196702, /* 1.0 */
    0.0548141210519176538961390, /* 1.5 */
    0.0413406959554092940938221, /* 2.0 */
    0.03316287351993628748511048, /* 2.5 */
    0.02767792568499833914878929, /* 3.0 */
    0.02374616365629749597132920, /* 3.5 */
    0.02079067210376509311152277, /* 4.0 */
    0.01848845053267318523077934, /* 4.5 */
    0.01664469118982119216319487, /* 5.0 */
    0.01513497322191737887351255, /* 5.5 */
    0.01387612882307074799874573, /* 6.0 */
    0.01281046524292022692424986, /* 6.5 */
    0.01189670994589177009505572, /* 7.0 */
    0.01110455975820691732662991, /* 7.5 */
    0.010411265261972096497478567, /* 8.0 */
    0.009799416126158803298389475, /* 8.5 */
    0.009255462182712732917728637, /* 9.0 */
    0.008768700134139385462952823, /* 9.5 */
    0.008330563433362871256469318, /* 10.0 */
    0.007934114564314020547248100, /* 10.5 */
    0.007573675487951840794972024, /* 11.0 */
    0.007244554301320383179543912, /* 11.5 */
    0.006942840107209529865664152, /* 12.0 */
    0.006665247032707682442354394, /* 12.5 */
    0.006408994188004207068439631, /* 13.0 */
    0.006171712263039457647532867, /* 13.5 */
    0.005951370112758847735624416, /* 14.0 */
    0.005746216513010115682023589, /* 14.5 */
    0.005554733551962801371038690 /* 15.0 */
};
|
||||||
|
|
@ -0,0 +1,128 @@
|
||||||
|
/*
    Description: saddle point expansion helpers for binomial log-probabilities

    Copyright : All right reserved by ICT

    Author : Zhang Zhonghai

    Date : 2025/12/26
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
#include <limits>
|
||||||
|
#include "gamma.h"
|
||||||
|
|
||||||
|
struct SaddlePointExpansion {
|
||||||
|
static constexpr double TWO_PI = 2 * (105414357.0 / 33554432.0 + 1.984187159361080883e-9);
|
||||||
|
|
||||||
|
/** 1/2 * log(2 π). */
|
||||||
|
static const double HALF_LOG_2_PI;
|
||||||
|
|
||||||
|
/** exact Stirling expansion error for certain values. */
|
||||||
|
static const double EXACT_STIRLING_ERRORS[31];
|
||||||
|
/**
|
||||||
|
* A part of the deviance portion of the saddle point approximation.
|
||||||
|
* <p>
|
||||||
|
* References:
|
||||||
|
* <ol>
|
||||||
|
* <li>Catherine Loader (2000). "Fast and Accurate Computation of Binomial
|
||||||
|
* Probabilities.". <a target="_blank"
|
||||||
|
* href="http://www.herine.net/stat/papers/dbinom.pdf">
|
||||||
|
* http://www.herine.net/stat/papers/dbinom.pdf</a></li>
|
||||||
|
* </ol>
|
||||||
|
* </p>
|
||||||
|
*
|
||||||
|
* @param x the x value.
|
||||||
|
* @param mu the average.
|
||||||
|
* @return a part of the deviance.
|
||||||
|
*/
|
||||||
|
static double getDeviancePart(double x, double mu) {
|
||||||
|
double ret;
|
||||||
|
if (std::abs(x - mu) < 0.1 * (x + mu)) {
|
||||||
|
double d = x - mu;
|
||||||
|
double v = d / (x + mu);
|
||||||
|
double s1 = v * d;
|
||||||
|
double s = std::numeric_limits<double>::quiet_NaN();
|
||||||
|
double ej = 2.0 * x * v;
|
||||||
|
v *= v;
|
||||||
|
int j = 1;
|
||||||
|
while (s1 != s) {
|
||||||
|
s = s1;
|
||||||
|
ej *= v;
|
||||||
|
s1 = s + ej / ((j * 2) + 1);
|
||||||
|
++j;
|
||||||
|
}
|
||||||
|
ret = s1;
|
||||||
|
} else {
|
||||||
|
ret = x * std::log(x / mu) + mu - x;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute the error of Stirling's series at the given value.
|
||||||
|
* <p>
|
||||||
|
* References:
|
||||||
|
* <ol>
|
||||||
|
* <li>Eric W. Weisstein. "Stirling's Series." From MathWorld--A Wolfram Web
|
||||||
|
* Resource. <a target="_blank"
|
||||||
|
* href="http://mathworld.wolfram.com/StirlingsSeries.html">
|
||||||
|
* http://mathworld.wolfram.com/StirlingsSeries.html</a></li>
|
||||||
|
* </ol>
|
||||||
|
* </p>
|
||||||
|
*
|
||||||
|
* @param z the value.
|
||||||
|
* @return the Striling's series error.
|
||||||
|
*/
|
||||||
|
static double getStirlingError(double z) {
|
||||||
|
double ret;
|
||||||
|
if (z < 15.0) {
|
||||||
|
double z2 = 2.0 * z;
|
||||||
|
if (std::floor(z2) == z2) {
|
||||||
|
ret = EXACT_STIRLING_ERRORS[(int)z2];
|
||||||
|
} else {
|
||||||
|
ret = Gamma::logGamma(z + 1.0) - (z + 0.5) * std::log(z) + z - HALF_LOG_2_PI;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
double z2 = z * z;
|
||||||
|
ret = (0.083333333333333333333 -
|
||||||
|
(0.00277777777777777777778 -
|
||||||
|
(0.00079365079365079365079365 - (0.000595238095238095238095238 - 0.0008417508417508417508417508 / z2) / z2) / z2) /
|
||||||
|
z2) /
|
||||||
|
z;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute the logarithm of the PMF for a binomial distribution
|
||||||
|
* using the saddle point expansion.
|
||||||
|
*
|
||||||
|
* @param x the value at which the probability is evaluated.
|
||||||
|
* @param n the number of trials.
|
||||||
|
* @param p the probability of success.
|
||||||
|
* @param q the probability of failure (1 - p).
|
||||||
|
* @return log(p(x)).
|
||||||
|
*/
|
||||||
|
static double logBinomialProbability(int x, int n, double p, double q) {
|
||||||
|
double ret;
|
||||||
|
if (x == 0) {
|
||||||
|
if (p < 0.1) {
|
||||||
|
ret = -getDeviancePart(n, n * q) - n * p;
|
||||||
|
} else {
|
||||||
|
ret = n * std::log(q);
|
||||||
|
}
|
||||||
|
} else if (x == n) {
|
||||||
|
if (q < 0.1) {
|
||||||
|
ret = -getDeviancePart(n, n * p) - n * q;
|
||||||
|
} else {
|
||||||
|
ret = n * std::log(p);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ret = getStirlingError(n) - getStirlingError(x) - getStirlingError(n - x) - getDeviancePart(x, n * p) - getDeviancePart(n - x, n * q);
|
||||||
|
double f = (TWO_PI * x * (n - x)) / n;
|
||||||
|
ret = -0.5 * std::log(f) + ret;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
@ -57,6 +57,7 @@ int DisplayProfiling(int nthread) {
|
||||||
PRINT_GP(read_ref);
|
PRINT_GP(read_ref);
|
||||||
PRINT_GP(read_vcf);
|
PRINT_GP(read_vcf);
|
||||||
PRINT_GP(covariate);
|
PRINT_GP(covariate);
|
||||||
|
PRINT_GP(update_info);
|
||||||
//PRINT_GP(markdup);
|
//PRINT_GP(markdup);
|
||||||
//PRINT_GP(intersect);
|
//PRINT_GP(intersect);
|
||||||
// PRINT_GP(merge_result);
|
// PRINT_GP(merge_result);
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ enum {
|
||||||
GP_covariate,
|
GP_covariate,
|
||||||
GP_read_ref,
|
GP_read_ref,
|
||||||
GP_read_vcf,
|
GP_read_vcf,
|
||||||
|
GP_update_info,
|
||||||
GP_gen_wait,
|
GP_gen_wait,
|
||||||
GP_sort_wait,
|
GP_sort_wait,
|
||||||
GP_markdup_wait,
|
GP_markdup_wait,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
#include "report_table.h"

// Error messages emitted when parsing an existing GATK report table fails.
const string ReportTable::COULD_NOT_READ_HEADER = "Could not read the header of this file -- ";
const string ReportTable::COULD_NOT_READ_COLUMN_NAMES = "Could not read the column names of this file -- ";
const string ReportTable::COULD_NOT_READ_DATA_LINE = "Could not read a data line of this table -- ";
const string ReportTable::COULD_NOT_READ_EMPTY_LINE = "Could not read the last empty line of this table -- ";
const string ReportTable::OLD_GATK_TABLE_VERSION = "We no longer support older versions of the GATK Tables";
// Marker that prefixes every table header line in a GATK report.
const char* ReportTable::TABLE_PREFIX = "GATKTable";
|
||||||
|
|
@ -0,0 +1,253 @@
|
||||||
|
/*
    Description: Generate GATK-style data report tables

    Copyright : All right reserved by ICT

    Author : Zhang Zhonghai

    Date : 2025/12/27
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cassert>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <sstream>
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
|
using std::string;
|
||||||
|
using std::stringstream;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
#define REPORT_HEADER_VERSION "#:GATKReport.v1.1:5"
|
||||||
|
|
||||||
|
// Renders typed values as the strings used inside GATK report tables.
struct ReportUtil {
    // Booleans print as lowercase literals.
    static std::string ToString(const bool val) {
        if (val) {
            return "true";
        }
        return "false";
    }
    // A single character becomes a one-character string.
    static std::string ToString(const char val) { return std::string(1, val); }
    // Empty strings are rendered as the literal "null".
    static std::string ToString(const std::string& val) { return val.empty() ? "null" : val; }
    // Fixed-point rendering with the requested number of decimal places.
    static std::string ToString(const double val, int precise) {
        std::ostringstream out;
        out << std::fixed << std::setprecision(precise) << val;
        return out.str();
    }
    // Fallback for integral types: defer to std::to_string.
    template<typename T>
    static std::string ToString(const T val) { return std::to_string(val); }
};
|
||||||
|
|
||||||
|
// In-memory representation of one GATK report table: column metadata plus row
// data, with write() emitting the text format GATK tools parse.
struct ReportTable {
    // Static data (defined in report_table.cpp).

    static const string COULD_NOT_READ_HEADER;
    static const string COULD_NOT_READ_COLUMN_NAMES;
    static const string COULD_NOT_READ_DATA_LINE;
    static const string COULD_NOT_READ_EMPTY_LINE;
    static const string OLD_GATK_TABLE_VERSION;
    static const char* TABLE_PREFIX;

    // Table sort order. In practice SORT_BY_COLUMN is used: rows are ordered by
    // comparing each column value left to right (see sort()). SORT_BY_ROW would
    // order by row id. NOTE(review): only SORT_BY_COLUMN-style ordering is
    // actually implemented by sort(); sortingWay is currently not consulted.
    enum struct Sorting { SORT_BY_ROW, SORT_BY_COLUMN, DO_NOT_SORT };
    // Data type of a column's values.
    enum struct DataType {
        /**
         * The null type should not be used.
         */
        NullType,
        /**
         * The default value when a format string is not present
         */
        Unknown,
        /**
         * Used for boolean values. Will display as true or false in the table.
         */
        Boolean,
        /**
         * Used for char values. Will display as a char so use printable values!
         */
        Character,
        /**
         * Used for float and double values. Will output a decimal with format %.8f unless otherwise specified.
         */
        Decimal,
        /**
         * Used for int, byte, short, and long values. Will display the full number by default.
         */
        Integer,
        /**
         * Used for string values. Displays the string itself.
         */
        String
    };
    // One column of the table: name, printf-style format, and layout info.
    struct Column {
        // Only columns whose content is entirely numeric (or nan/inf/null) may be
        // right-aligned; chars and strings are always left-aligned.
        enum struct Alignment {LEFT, RIGHT};
        string columnName;
        string format;
        DataType dataType;
        Alignment alignment = Alignment::RIGHT;
        int maxWidth = 0;   // widest cell seen so far, used for padding on output

        Column() {}
        Column(const string& name) : Column(name, "%s") {}
        Column(const string& name, const string& _format) : Column(name, _format, DataType::String) {}
        Column(const string& name, const string& _format, const DataType _dataType) : Column(name, _format, _dataType, Alignment::RIGHT) {}
        Column(const string& name, const string& _format, const DataType _dataType, const Alignment align) { init(name, _format, _dataType, align); }

        // Shared constructor body: store the fields and reconcile format/dataType.
        void init(const string& name, const string &_format, const DataType _dataType, const Alignment align) {
            columnName = name;
            format = _format;
            dataType = _dataType;
            alignment = align;
            maxWidth = name.size();
            updateFormat(*this); // keep format and dataType consistent
        }

        // Grow maxWidth if this cell is wider than anything seen so far.
        void updateMaxWidth(const string &value) {
            if (!value.empty()) {
                maxWidth = std::max(maxWidth, (int)value.size());
            }
        }

        // Derive dataType and alignment from the printf conversion in col.format.
        static void updateFormat(Column &col) {
            if (col.format == "") {
                col.dataType = DataType::Unknown;
                col.format = "%s";
            } else if (col.format.find('s') != string::npos) {
                col.dataType = DataType::String;
            } else if (col.format.find('d') != string::npos) {
                col.dataType = DataType::Integer;
            } else if (col.format.find('f') != string::npos) {
                col.dataType = DataType::Decimal;
            }
            // spdlog::info("type: {}, align {}", col.dataType == DataType::Integer, col.alignment == Alignment::LEFT);
            // Only numeric columns may be right-aligned.
            if (col.dataType != DataType::Decimal && col.dataType != DataType::Integer)
                col.alignment = Column::Alignment::LEFT;
        }

        // printf format for the column header (always left-aligned, padded to maxWidth).
        string getNameFormat() { return "%-" + std::to_string(maxWidth) + "s";}

        // printf format for cell values; all values have already been converted to strings.
        string getValueFormat() {
            string align = alignment == Alignment::LEFT ? "-" : "";
            return "%" + align + std::to_string(maxWidth) + "s";
        }
    };

    using RowData = vector<string>; // one row of cell values, pre-rendered as strings

    string tableName;               // table name
    string tableDescription;        // table description
    Sorting sortingWay;             // requested sort order
    vector<Column> columnInfo;      // metadata for every column
    vector<RowData> underlyingData; // all row data
    vector<size_t> idxArr;          // indices into underlyingData, permuted by sort()

    ReportTable() {}
    ReportTable(const string& _tableName) : ReportTable(_tableName, "") {}
    ReportTable(const string& _tableName, const string& _tableDescription) : ReportTable(_tableName, _tableDescription, Sorting::SORT_BY_COLUMN) {}
    ReportTable(const string& _tableName, const string& _tableDescription, const Sorting sorting) { init(_tableName, _tableDescription, sorting); }

    // Initialize name, description and sort order.
    void init(const string& _tableName, const string& _tableDescription, const Sorting sorting) {
        tableName = _tableName;
        tableDescription = _tableDescription;
        sortingWay = sorting;
    }

    // Append a column; its format/dataType/alignment are reconciled first.
    void addColumn(Column col) {
        // spdlog::info("col align: {}", col.alignment == Column::Alignment::RIGHT);
        Column::updateFormat(col);
        columnInfo.push_back(col);
    }

    // Append one row (must have exactly one cell per column) and update column widths.
    void addRowData(const RowData &dat) {
        assert(dat.size() == columnInfo.size());
        idxArr.push_back(underlyingData.size());
        underlyingData.push_back(dat);
        for (int i = 0; i < dat.size(); ++i) {
            columnInfo[i].updateMaxWidth(dat[i]);
        }
    }

    // Order idxArr by comparing rows column-by-column, left to right. Numeric
    // columns (per their format) compare as numbers, everything else as strings;
    // ties fall back to insertion order, keeping the sort stable.
    void sort() {
        std::sort(idxArr.begin(), idxArr.end(), [&](size_t a, size_t b) {
            auto& r1 = underlyingData[a];
            auto& r2 = underlyingData[b];
            for (int i = 0; i < r1.size(); ++i) {
                if (r1[i] != r2[i]) {
                    if (columnInfo[i].format.find('d') != string::npos) // could also branch on dataType
                        return std::stoll(r1[i]) < std::stoll(r2[i]);
                    else if (columnInfo[i].format.find('f') != string::npos) {
                        return std::stod(r1[i]) < std::stod(r2[i]);
                    } else {
                        return r1[i] < r2[i];
                    }
                }
            }
            return a < b;
        });
    }

    // Emit the table in GATK report text format: header line, description line,
    // padded column names, then the sorted rows, ending with a blank line.
    void write(FILE *fpout) {
        // Header: prefix, column count, row count, then each column's format.
        fprintf(fpout, "#:%s:%d:%d:", TABLE_PREFIX, (int)columnInfo.size(), (int)underlyingData.size());
        for (auto& col : columnInfo) fprintf(fpout, "%s:", col.format.c_str());
        fprintf(fpout, ";\n");
        // Table name and description.
        fprintf(fpout, "#:%s:%s:%s\n", TABLE_PREFIX, tableName.c_str(), tableDescription.c_str());
        // Column names, separated by two spaces.
        bool needPadding = false;
        for (auto& col : columnInfo) {
            if (needPadding)
                fprintf(fpout, "  ");
            needPadding = true;
            fprintf(fpout, col.getNameFormat().c_str(), col.columnName.c_str());
        }
        fprintf(fpout, "\n");
        // All row data, sorted first.
        sort();
        for (auto idx : idxArr) {
            auto &row = underlyingData[idx];
            bool needPadding = false;
            for (int i = 0; i < row.size(); ++i) {
                if (needPadding)
                    fprintf(fpout, "  ");
                needPadding = true;
                fprintf(fpout, columnInfo[i].getValueFormat().c_str(), row[i].c_str());
            }
            fprintf(fpout, "\n");
        }
        fprintf(fpout, "\n");
    }
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
// 输出参数相关的数据
|
||||||
|
struct ArgReportTable : ReportTable {
|
||||||
|
using ReportTable::ReportTable;
|
||||||
|
};
|
||||||
|
|
||||||
|
// 输出quantized quality相关的数据
|
||||||
|
struct QuantReportTable : ReportTable {
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
// 输出协变量相关的数据
|
||||||
|
struct CovariateReportTable : ReportTable {
|
||||||
|
|
||||||
|
};
|
||||||
|
*/
|
||||||
|
|
@ -1,4 +1,11 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
using std::string;
|
||||||
|
|
||||||
// Report an error and exit -1
|
// Report an error and exit -1
|
||||||
void error(const char* format, ...);
|
void error(const char* format, ...);
|
||||||
|
|
||||||
|
struct Utils {
|
||||||
|
};
|
||||||
1090
test/stats.table
1090
test/stats.table
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue