10x speedup of recalibration walker

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@954 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
depristo 2009-06-09 15:39:40 +00:00
parent a62bc6b05d
commit 7fa84ea157
8 changed files with 205 additions and 40 deletions

View File

@ -271,7 +271,7 @@ public class ReferenceOrderedData<ROD extends ReferenceOrderedDatum> implements
do {
final String line = parser.next();
//System.out.printf("Line is %s%n", line);
//System.out.printf("Line is '%s'%n", line);
String parts[] = line.split(fieldDelimiter);
try {

View File

@ -181,6 +181,9 @@ public class rodDbSNP extends BasicReferenceOrderedDatum implements AllelicVaria
} catch( MalformedGenomeLocException ex ) {
// Just rethrow malformed genome locs; the ROD system itself will deal with these.
throw ex;
} catch( ArrayIndexOutOfBoundsException ex ) {
// Just rethrow malformed genome locs; the ROD system itself will deal with these.
throw new RuntimeException("Badly formed dbSNP line: " + ex);
} catch ( RuntimeException e ) {
System.out.printf(" Exception caught during parsing DBSNP line %s%n", Utils.join(" <=> ", parts));
throw e;

View File

@ -19,9 +19,21 @@ public class LogisticRecalibrationWalker extends ReadWalker<SAMRecord, SAMFileWr
@Argument(shortName="outputBAM", doc="output BAM file", required=false)
public String outputBamFile = null;
@Argument(shortName="useCache", doc="If true, uses high-performance caching of logistic regress results. Experimental", required=false)
public boolean useLogisticCache = true;
Map<Pair<String,String>, LogisticRegressor> regressors = new HashMap<Pair<String,String>, LogisticRegressor>();
private static Logger logger = Logger.getLogger(LogisticRecalibrationWalker.class);
// maps from [readGroup] -> [prevBase x base -> [cycle, qual, new qual]]
HashMap<String, HashMap<String, byte[][]>> cache = new HashMap<String, HashMap<String, byte[][]>>();
private static byte MAX_Q_SCORE = 64;
@Argument(shortName="maxReadLen", doc="Maximum allowed read length to allow during recalibration, needed for recalibration table allocation", required=false)
public static int maxReadLen = 125;
public void initialize() {
try {
List<String> lines = new xReadLines(new File(logisticParamsFile)).readLines();
@ -46,12 +58,19 @@ public class LogisticRecalibrationWalker extends ReadWalker<SAMRecord, SAMFileWr
}
}
if ( useLogisticCache ) System.out.printf("Building recalibration cache%n");
for ( Map.Entry<Pair<String,String>, LogisticRegressor> e : regressors.entrySet() ) {
String readGroup = e.getKey().first;
String dinuc = e.getKey().second;
LogisticRegressor regressor = e.getValue();
logger.debug(String.format("Regressor: %s,%s => %s", readGroup, dinuc, regressor));
if ( useLogisticCache ) {
addToLogisticCache(readGroup, dinuc, regressor);
}
}
if ( useLogisticCache ) System.out.printf("done%n");
//System.exit(1);
} catch ( FileNotFoundException e ) {
@ -89,8 +108,51 @@ public class LogisticRecalibrationWalker extends ReadWalker<SAMRecord, SAMFileWr
return mapping;
}
/**
 * Precompute recalibrated quality scores for one (readGroup, dinuc) regressor so that
 * per-read recalibration becomes a constant-time table lookup instead of a regression.
 *
 * The table is indexed as dataTable[cycle][qual]; quals >= MAX_Q_SCORE are not
 * representable and will cause an ArrayIndexOutOfBoundsException at lookup time.
 *
 * @param readGroup read group this regressor was trained on
 * @param dinuc     two-base (previous, current) context, e.g. "AC"
 * @param regressor trained logistic regressor for this context
 */
private void addToLogisticCache(final String readGroup, final String dinuc, LogisticRegressor regressor) {
    // Progress marker for the (potentially slow) cache-building pass run from initialize().
    System.out.printf("%s x %s ", readGroup, dinuc);

    byte[][] dataTable = new byte[maxReadLen][MAX_Q_SCORE];

    // Cycle 0 has no dinuc context and is never recalibrated by the mappers; store the
    // identity mapping so an accidental cycle-0 lookup returns the input qual instead of 0.
    for ( byte qual = 0; qual < MAX_Q_SCORE; qual++ ) {
        dataTable[0][qual] = qual;
    }

    for ( int cycle = 1; cycle < maxReadLen; cycle++ ) {
        for ( byte qual = 0; qual < MAX_Q_SCORE; qual++ ) {
            dataTable[cycle][qual] = regressor2newQual(regressor, cycle, qual);
        }
    }

    // Install under [readGroup][dinuc], creating the per-read-group map on first use.
    HashMap<String, byte[][]> lookup1 = cache.get(readGroup);
    if ( lookup1 == null ) {
        lookup1 = new HashMap<String, byte[][]>();
        cache.put(readGroup, lookup1);
    }
    lookup1.put(dinuc, dataTable);
}
/**
 * Recalibrate one read's base qualities, dispatching to the cached fast path when
 * useLogisticCache is enabled and to the original regression path otherwise.
 *
 * @param ref  reference bases covering this read
 * @param read the read to recalibrate (modified in place by the mappers)
 * @return the recalibrated read
 */
public SAMRecord map(char[] ref, SAMRecord read) {
    return useLogisticCache ? mapCached(ref, read) : mapOriginal(ref, read);
}
/**
 * Look up the recalibrated quality for one base from the precomputed cache.
 *
 * @param readGroup read group of the read (used only for diagnostics)
 * @param RGcache   per-read-group cache mapping dinuc -> [cycle][qual] table
 * @param prevBase  previous base in machine order
 * @param base      current base
 * @param regressor unused here; retained so callers can pass it for cross-checking
 * @param cycle     machine cycle (1-based position within the read)
 * @param qual      original quality score; must be < MAX_Q_SCORE
 * @return the cached recalibrated quality, or the input qual when the dinuc context
 *         is unknown (e.g. contains an N and so has no table)
 */
private byte cache2newQual(final String readGroup, HashMap<String, byte[][]> RGcache, byte prevBase, byte base, LogisticRegressor regressor, int cycle, byte qual) {
    // Key the per-read-group cache by the two-base (previous, current) context.
    final byte[] contextBases = { prevBase, base };
    final byte[][] table = RGcache.get(new String(contextBases));

    if ( table == null )
        return qual;    // unknown dinuc: leave the quality unchanged

    return table[cycle][qual];
}
public SAMRecord mapCached(char[] ref, SAMRecord read) {
if ( read.getReadLength() > maxReadLen ) {
throw new RuntimeException("Expectedly long read, please increase maxium read len with maxReadLen parameter: " + read.format());
}
SAMRecord recalRead = read;
byte[] bases = read.getReadBases();
byte[] quals = read.getBaseQualities();
@ -102,34 +164,20 @@ public class LogisticRecalibrationWalker extends ReadWalker<SAMRecord, SAMFileWr
quals = BaseUtils.reverse(quals);
}
String readGroup = read.getAttribute("RG").toString();
String readGroup = read.getAttribute("RG").toString();
HashMap<String, byte[][]> RGcache = cache.get(readGroup);
int numBases = read.getReadLength();
recalQuals[0] = quals[0]; // can't change the first -- no dinuc
//recalQuals[numBases-1] = quals[numBases-1]; // can't change last -- no dinuc
for ( int cycle = 1; cycle < numBases; cycle++ ) { // skip first and last base, qual already set because no dinuc
// Take into account that previous base is the next base in terms of machine chemistry if this is a negative strand
//int cycle = i; //read.getReadNegativeStrandFlag() ? numBases - i - 1 : i;
String dinuc = String.format("%c%c", bases[cycle - 1], bases[cycle]);
// Take into account that previous base is the next base in terms of machine chemistry if
// this is a negative strand
byte qual = quals[cycle];
LogisticRegressor regressor = regressors.get(new Pair<String,String>(readGroup,dinuc));
byte newQual;
if ( regressor != null ) { // no N or some other unexpected bp in the stream
double gamma = regressor.regress((double)cycle+1, (double)qual);
double expGamma = Math.exp(gamma);
double finalP = expGamma / (1+expGamma);
newQual = QualityUtils.probToQual(1-finalP);
//newQual = -10 * Math.round(logPOver1minusP)
/*double POver1minusP = Math.pow(10, logPOver1minusP);
P = POver1minusP / (1 + POver1minusP);*/
//newQual = QualityUtils.probToQual(P);
//newQual = (byte)Math.min(Math.round(-10*logPOver1minusP),63);
//System.out.printf("Recal %s %d %d %d%n", dinuc, cycle, qual, newQual);
}else{
newQual = qual;
}
//LogisticRegressor regressor = getLogisticRegressor(readGroup, bases[cycle - 1], bases[cycle]);
LogisticRegressor regressor = null;
byte newQual = cache2newQual(readGroup, RGcache, bases[cycle - 1], bases[cycle], regressor, cycle, qual);
recalQuals[cycle] = newQual;
}
@ -141,6 +189,112 @@ public class LogisticRecalibrationWalker extends ReadWalker<SAMRecord, SAMFileWr
return recalRead;
}
// ----------------------------------------------------------------------------------------------------
//
// Old-style, expensive recalibrator
//
// ----------------------------------------------------------------------------------------------------
/**
 * Fetch the logistic regressor trained for this read group and dinucleotide context.
 *
 * @param readGroup read group of the read
 * @param prevBase  previous base in machine order
 * @param base      current base
 * @return the matching regressor, or null when none was trained for this context
 */
private LogisticRegressor getLogisticRegressor(final String readGroup, byte prevBase, byte base) {
    // Regressors are keyed by (read group, two-character dinuc string).
    final byte[] contextBases = { prevBase, base };
    return regressors.get(new Pair<String,String>(readGroup, new String(contextBases)));
}
/**
 * Apply one logistic regressor to a (cycle, qual) pair to produce a recalibrated quality.
 *
 * The regressor's output gamma is treated as log-odds of the base being correct:
 * p = exp(gamma) / (1 + exp(gamma)), and the new qual encodes the error probability 1 - p.
 *
 * @param regressor trained regressor, or null when this context has no model
 * @param cycle     machine cycle (the regression is fed cycle + 1)
 * @param qual      original quality score
 * @return the recalibrated quality, or the input qual when regressor is null
 */
private byte regressor2newQual(LogisticRegressor regressor, int cycle, byte qual) {
    if ( regressor == null )
        return qual;    // no model for this context (e.g. dinucs containing N)

    final double gamma = regressor.regress((double)cycle + 1, (double)qual);
    final double expGamma = Math.exp(gamma);
    return QualityUtils.probToQual(1 - expGamma / (1 + expGamma));
}
/**
 * Original (per-base regression) recalibration path: recompute each base's quality
 * from the logistic regressor for its (read group, dinuc) context.
 *
 * The read's qualities are modified in place via setBaseQualities.
 *
 * @param ref  reference bases covering this read (unused)
 * @param read the read to recalibrate
 * @return the same read object, with recalibrated base qualities
 */
public SAMRecord mapOriginal(char[] ref, SAMRecord read) {
    SAMRecord recalRead = read;
    byte[] bases = read.getReadBases();
    byte[] quals = read.getBaseQualities();
    byte[] recalQuals = new byte[quals.length];

    // We recalibrate in machine order, so rev-comp negative-strand reads first.
    if (read.getReadNegativeStrandFlag()) {
        bases = BaseUtils.simpleReverseComplement(bases);
        quals = BaseUtils.reverse(quals);
    }

    String readGroup = read.getAttribute("RG").toString();
    int numBases = read.getReadLength();

    // Guard zero-length reads before touching index 0.
    if ( numBases > 0 )
        recalQuals[0] = quals[0]; // can't change the first -- no dinuc context

    for ( int cycle = 1; cycle < numBases; cycle++ ) { // skip first base, qual already set because no dinuc
        byte qual = quals[cycle];
        LogisticRegressor regressor = getLogisticRegressor(readGroup, bases[cycle - 1], bases[cycle]);
        recalQuals[cycle] = regressor2newQual(regressor, cycle, qual);
    }

    if (read.getReadNegativeStrandFlag())
        // BUGFIX: reverse the RECALIBRATED quals back into read orientation. The previous
        // code reversed the input quals instead, silently discarding all recalibration
        // for negative-strand reads.
        recalQuals = BaseUtils.reverse(recalQuals);

    read.setBaseQualities(recalQuals);
    return recalRead;
}
// public SAMRecord mapOriginalUnmodified(char[] ref, SAMRecord read) {
// SAMRecord recalRead = read;
// byte[] bases = read.getReadBases();
// byte[] quals = read.getBaseQualities();
// byte[] recalQuals = new byte[quals.length];
//
// // Since we want machine direction reads not corrected positive strand reads, rev comp any negative strand reads
// if (read.getReadNegativeStrandFlag()) {
// bases = BaseUtils.simpleReverseComplement(bases);
// quals = BaseUtils.reverse(quals);
// }
//
// String readGroup = read.getAttribute("RG").toString();
// int numBases = read.getReadLength();
// recalQuals[0] = quals[0]; // can't change the first -- no dinuc
// //recalQuals[numBases-1] = quals[numBases-1]; // can't change last -- no dinuc
// for ( int cycle = 1; cycle < numBases; cycle++ ) { // skip first and last base, qual already set because no dinuc
// // Take into account that previous base is the next base in terms of machine chemistry if this is a negative strand
// //int cycle = i; //read.getReadNegativeStrandFlag() ? numBases - i - 1 : i;
// String dinuc = String.format("%c%c", bases[cycle - 1], bases[cycle]);
// byte qual = quals[cycle];
// LogisticRegressor regressor = regressors.get(new Pair<String,String>(readGroup,dinuc));
// byte newQual;
//
// if ( regressor != null ) { // no N or some other unexpected bp in the stream
// double gamma = regressor.regress((double)cycle+1, (double)qual);
// double expGamma = Math.exp(gamma);
// double finalP = expGamma / (1+expGamma);
// newQual = QualityUtils.probToQual(1-finalP);
// //newQual = -10 * Math.round(logPOver1minusP)
// /*double POver1minusP = Math.pow(10, logPOver1minusP);
// P = POver1minusP / (1 + POver1minusP);*/
// //newQual = QualityUtils.probToQual(P);
//
// //newQual = (byte)Math.min(Math.round(-10*logPOver1minusP),63);
// //System.out.printf("Recal %s %d %d %d%n", dinuc, cycle, qual, newQual);
// }else{
// newQual = qual;
// }
//
// recalQuals[cycle] = newQual;
// }
//
// if (read.getReadNegativeStrandFlag())
// recalQuals = BaseUtils.reverse(quals);
// //System.out.printf("OLD: %s%n", read.format());
// read.setBaseQualities(recalQuals);
// //System.out.printf("NEW: %s%n", read.format());
// return recalRead;
// }
public void onTraversalDone(SAMFileWriter output) {
if ( output != null ) {
output.close();

View File

@ -47,7 +47,7 @@ public class PairwiseDistanceAnalysis extends BasicVariantAnalysis {
//out.printf("# Excluding %d %s %s vs. %s %s%n", d, eL, interval, lvL, lastVariantInterval);
} else {
pairWiseDistances.add(d);
r = String.format("Pairwise-distance %d %s %s%n", d, eL, lvL);
r = String.format("Pairwise-distance %d %s %s", d, eL, lvL);
}
}
}

View File

@ -50,7 +50,7 @@ def main():
(OPTIONS, args) = parser.parse_args()
if len(args) != 2:
parser.error("incorrect number of arguments")
parser.error("incorrect number of arguments: " + str(args))
lines = [line.split() for line in open(args[0])]
nIndividuals = int(args[1])

View File

@ -28,6 +28,9 @@ if __name__ == "__main__":
parser.add_option("-i", "--ignoreExistingFiles", dest="ignoreExistingFiles",
action='store_true', default=False,
help="Ignores already written files, if present")
parser.add_option("-s", "--useSamtools", dest="useSamtools",
action='store_true', default=False,
help="If present, uses samtools to perform the merge")
parser.add_option("-m", "--mergeBin", dest="mergeBin",
type="string", default=None,
help="Path to merge binary")
@ -61,8 +64,8 @@ if __name__ == "__main__":
jobid = None
if OPTIONS.ignoreExistingFiles or not os.path.exists(spec.getMergedBAM()):
output = spec.getMergedBase() + '.stdout'
cmd = spec.mergeCmd(OPTIONS.mergeBin)
print cmd
cmd = spec.mergeCmd(OPTIONS.mergeBin, useSamtools = OPTIONS.useSamtools)
#print cmd
jobid = farm_commands.cmd(cmd, OPTIONS.farmQueue, output, just_print_commands = OPTIONS.dry)
if OPTIONS.ignoreExistingFiles or not os.path.exists(spec.getMergedBAMIndex()):

View File

@ -90,11 +90,11 @@ class MergeFilesSpec:
sizes = map(greek, sizes)
return sizes
def mergeCmd(self, mergeBin = None, MSD = False):
def mergeCmd(self, mergeBin = None, MSD = True, useSamtools = False):
if mergeBin == None:
mergeBin = MERGE_BIN
return picard_utils.mergeBAMCmd(self.getMergedBAM(), self.sources(), mergeBin, MSD = MSD)
return picard_utils.mergeBAMCmd(self.getMergedBAM(), self.sources(), mergeBin, MSD = MSD, useSamtools = useSamtools)
def getIndexCmd(self):
    """Shell command that indexes the merged BAM with samtools."""
    return ' '.join(['samtools', 'index', self.getMergedBAM()])

View File

@ -19,6 +19,7 @@ ref = "/seq/references/Homo_sapiens_assembly18/v0/Homo_sapiens_assembly18.fasta"
analysis = "CombineDuplicates"
MERGE_BIN = '/seq/software/picard/current/bin/MergeSamFiles.jar'
SAMTOOLS_MERGE_BIN = '/seq/dirseq/samtools/current/samtools merge'
CALL_GENOTYPES_BIN = '/seq/software/picard/current/bin/CallGenotypes.jar'
def CollectDbSnpMatchesCmd(inputGeli, outputFile, lod):
@ -152,15 +153,19 @@ def aggregateGeliCalls( sortedGeliCalls ):
#return [[loc, list(sharedCallsGroup)] for (loc, sharedCallsGroup) in itertools.groupby(sortedGeliCalls, call2loc)]
return [[loc, list(sharedCallsGroup)] for (loc, sharedCallsGroup) in itertools.groupby(sortedGeliCalls, call2loc)]
def mergeBAMCmd( output_filename, inputFiles, mergeBin = MERGE_BIN, MSD = True ):
if type(inputFiles) <> list:
inputFiles = list(inputFiles)
MSDStr = ''
if MSD: MSDStr = 'MSD=true'
return 'java -Xmx4096m -jar ' + mergeBin + ' ' + MSDStr + ' AS=true SO=coordinate O=' + output_filename + ' VALIDATION_STRINGENCY=SILENT ' + ' I=' + (' I='.join(inputFiles))
#return 'java -Xmx4096m -jar ' + mergeBin + ' AS=true SO=coordinate O=' + output_filename + ' VALIDATION_STRINGENCY=SILENT ' + ' I=' + (' I='.join(inputFiles))
def mergeBAMCmd( output_filename, inputFiles, mergeBin = MERGE_BIN, MSD = True, useSamtools = False ):
    """Build the shell command that merges inputFiles into output_filename.

    Uses 'samtools merge' when useSamtools is set; otherwise invokes the Picard
    merge jar (mergeBin), optionally passing MSD=true when MSD is set.
    """
    if useSamtools:
        return ' '.join([SAMTOOLS_MERGE_BIN, output_filename] + list(inputFiles))

    # use picard
    if not isinstance(inputFiles, list):
        inputFiles = list(inputFiles)
    MSDStr = 'MSD=true' if MSD else ''
    return ('java -Xmx4096m -jar %s %s AS=true SO=coordinate O=%s VALIDATION_STRINGENCY=SILENT  I=%s'
            % (mergeBin, MSDStr, output_filename, ' I='.join(inputFiles)))
def getPicardPath(lane, picardRoot = '/seq/picard/'):
flowcell, laneNo = lane.split('.')