From 7fa84ea15721ed82187700ec410e684520ba42b9 Mon Sep 17 00:00:00 2001 From: depristo Date: Tue, 9 Jun 2009 15:39:40 +0000 Subject: [PATCH] 10x speedup of recalibration walker git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@954 348d0f76-0448-11de-a6fe-93d51630548a --- .../gatk/refdata/ReferenceOrderedData.java | 2 +- .../sting/gatk/refdata/rodDbSNP.java | 3 + .../walkers/LogisticRecalibrationWalker.java | 202 +++++++++++++++--- .../varianteval/PairwiseDistanceAnalysis.java | 2 +- python/Gelis2PopSNPs.py | 2 +- python/MergeBAMBatch.py | 7 +- python/MergeBAMsUtils.py | 4 +- python/picard_utils.py | 23 +- 8 files changed, 205 insertions(+), 40 deletions(-) diff --git a/java/src/org/broadinstitute/sting/gatk/refdata/ReferenceOrderedData.java b/java/src/org/broadinstitute/sting/gatk/refdata/ReferenceOrderedData.java index 952e1e303..543148d95 100644 --- a/java/src/org/broadinstitute/sting/gatk/refdata/ReferenceOrderedData.java +++ b/java/src/org/broadinstitute/sting/gatk/refdata/ReferenceOrderedData.java @@ -271,7 +271,7 @@ public class ReferenceOrderedData implements do { final String line = parser.next(); - //System.out.printf("Line is %s%n", line); + //System.out.printf("Line is '%s'%n", line); String parts[] = line.split(fieldDelimiter); try { diff --git a/java/src/org/broadinstitute/sting/gatk/refdata/rodDbSNP.java b/java/src/org/broadinstitute/sting/gatk/refdata/rodDbSNP.java index 1f8a30e57..f238596dd 100644 --- a/java/src/org/broadinstitute/sting/gatk/refdata/rodDbSNP.java +++ b/java/src/org/broadinstitute/sting/gatk/refdata/rodDbSNP.java @@ -181,6 +181,9 @@ public class rodDbSNP extends BasicReferenceOrderedDatum implements AllelicVaria } catch( MalformedGenomeLocException ex ) { // Just rethrow malformed genome locs; the ROD system itself will deal with these. throw ex; + } catch( ArrayIndexOutOfBoundsException ex ) { + // Just rethrow malformed genome locs; the ROD system itself will deal with these. + throw new RuntimeException("Badly formed dbSNP line: " + ex); } catch ( RuntimeException e ) { System.out.printf(" Exception caught during parsing DBSNP line %s%n", Utils.join(" <=> ", parts)); throw e; diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/LogisticRecalibrationWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/LogisticRecalibrationWalker.java index 42cc6e0ea..ebac1f1ad 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/LogisticRecalibrationWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/LogisticRecalibrationWalker.java @@ -19,9 +19,21 @@ public class LogisticRecalibrationWalker extends ReadWalker, LogisticRegressor> regressors = new HashMap, LogisticRegressor>(); private static Logger logger = Logger.getLogger(LogisticRecalibrationWalker.class); + // maps from [readGroup] -> [prevBase x base -> [cycle, qual, new qual]] + HashMap> cache = new HashMap>(); + + private static byte MAX_Q_SCORE = 64; + + + @Argument(shortName="maxReadLen", doc="Maximum allowed read length to allow during recalibration, needed for recalibration table allocation", required=false) + public static int maxReadLen = 125; + public void initialize() { try { List lines = new xReadLines(new File(logisticParamsFile)).readLines(); @@ -46,12 +58,19 @@ public class LogisticRecalibrationWalker extends ReadWalker, LogisticRegressor> e : regressors.entrySet() ) { String readGroup = e.getKey().first; String dinuc = e.getKey().second; LogisticRegressor regressor = e.getValue(); logger.debug(String.format("Regressor: %s,%s => %s", readGroup, dinuc, regressor)); + + if ( useLogisticCache ) { + addToLogisticCache(readGroup, dinuc, regressor); + } } + if ( useLogisticCache ) System.out.printf("done%n"); //System.exit(1); } catch ( FileNotFoundException e ) { @@ -89,8 +108,51 @@ public class LogisticRecalibrationWalker extends ReadWalker lookup1 = cache.containsKey(readGroup) ? cache.get(readGroup) : new HashMap(); + lookup1.put(dinuc, dataTable); + cache.put(readGroup, lookup1); + } public SAMRecord map(char[] ref, SAMRecord read) { + if ( useLogisticCache ) + return mapCached(ref, read); + else + return mapOriginal(ref, read); + } + + private byte cache2newQual(final String readGroup, HashMap RGcache, byte prevBase, byte base, LogisticRegressor regressor, int cycle, byte qual) { + //System.out.printf("Lookup %s %c %c %d %d%n", readGroup, prevBase, base, cycle, qual); + //String dinuc = String.format("%c%c", (char)prevBase, (char)base); + byte[] bp = {prevBase, base}; + String dinuc = new String(bp); + + //byte newQualCalc = regressor2newQual(regressor, cycle, qual); + byte[][] dataTable = RGcache.get(dinuc); + + byte newQualCached = dataTable != null ? dataTable[cycle][qual] : qual; + //if ( newQualCached != newQualCalc ) { + // throw new RuntimeException(String.format("Inconsistent quals between the cache and calculation for RG=%s: %s %d %d : %d <> %d", + // readGroup, dinuc, cycle, qual, newQualCalc, newQualCached)); + //} + + return newQualCached; + } + + public SAMRecord mapCached(char[] ref, SAMRecord read) { + if ( read.getReadLength() > maxReadLen ) { + throw new RuntimeException("Expectedly long read, please increase maxium read len with maxReadLen parameter: " + read.format()); + } + SAMRecord recalRead = read; byte[] bases = read.getReadBases(); byte[] quals = read.getBaseQualities(); @@ -102,34 +164,20 @@ public class LogisticRecalibrationWalker extends ReadWalker RGcache = cache.get(readGroup); + int numBases = read.getReadLength(); recalQuals[0] = quals[0]; // can't change the first -- no dinuc - //recalQuals[numBases-1] = quals[numBases-1]; // can't change last -- no dinuc + for ( int cycle = 1; cycle < numBases; cycle++ ) { // skip first and last base, qual already set because no dinuc - // Take into account that previous base is the next base in terms of machine chemistry if this is a negative strand - //int cycle = i; //read.getReadNegativeStrandFlag() ? numBases - i - 1 : i; - String dinuc = String.format("%c%c", bases[cycle - 1], bases[cycle]); + // Take into account that previous base is the next base in terms of machine chemistry if + // this is a negative strand byte qual = quals[cycle]; - LogisticRegressor regressor = regressors.get(new Pair(readGroup,dinuc)); - byte newQual; - - if ( regressor != null ) { // no N or some other unexpected bp in the stream - double gamma = regressor.regress((double)cycle+1, (double)qual); - double expGamma = Math.exp(gamma); - double finalP = expGamma / (1+expGamma); - newQual = QualityUtils.probToQual(1-finalP); - //newQual = -10 * Math.round(logPOver1minusP) - /*double POver1minusP = Math.pow(10, logPOver1minusP); - P = POver1minusP / (1 + POver1minusP);*/ - //newQual = QualityUtils.probToQual(P); - - //newQual = (byte)Math.min(Math.round(-10*logPOver1minusP),63); - //System.out.printf("Recal %s %d %d %d%n", dinuc, cycle, qual, newQual); - }else{ - newQual = qual; - } - + //LogisticRegressor regressor = getLogisticRegressor(readGroup, bases[cycle - 1], bases[cycle]); + LogisticRegressor regressor = null; + byte newQual = cache2newQual(readGroup, RGcache, bases[cycle - 1], bases[cycle], regressor, cycle, qual); recalQuals[cycle] = newQual; } @@ -141,6 +189,112 @@ public class LogisticRecalibrationWalker extends ReadWalker(readGroup,dinuc)); + } + + private byte regressor2newQual(LogisticRegressor regressor, int cycle, byte qual) { + byte newQual = qual; + if ( regressor != null ) { // no N or some other unexpected bp in the stream + double gamma = regressor.regress((double)cycle+1, (double)qual); + double expGamma = Math.exp(gamma); + double finalP = expGamma / (1+expGamma); + newQual = QualityUtils.probToQual(1-finalP); + } + return newQual; + } + + public SAMRecord mapOriginal(char[] ref, SAMRecord read) { + SAMRecord recalRead = read; + byte[] bases = read.getReadBases(); + byte[] quals = read.getBaseQualities(); + byte[] recalQuals = new byte[quals.length]; + + // Since we want machine direction reads not corrected positive strand reads, rev comp any negative strand reads + if (read.getReadNegativeStrandFlag()) { + bases = BaseUtils.simpleReverseComplement(bases); + quals = BaseUtils.reverse(quals); + } + + String readGroup = read.getAttribute("RG").toString(); + int numBases = read.getReadLength(); + recalQuals[0] = quals[0]; // can't change the first -- no dinuc + + for ( int cycle = 1; cycle < numBases; cycle++ ) { // skip first and last base, qual already set because no dinuc + // Take into account that previous base is the next base in terms of machine chemistry if + // this is a negative strand + byte qual = quals[cycle]; + LogisticRegressor regressor = getLogisticRegressor(readGroup, bases[cycle - 1], bases[cycle]); + byte newQual = regressor2newQual(regressor, cycle, qual); + recalQuals[cycle] = newQual; + } + + if (read.getReadNegativeStrandFlag()) + recalQuals = BaseUtils.reverse(quals); + //System.out.printf("OLD: %s%n", read.format()); + read.setBaseQualities(recalQuals); + //System.out.printf("NEW: %s%n", read.format()); + return recalRead; + } + +// public SAMRecord mapOriginalUnmodified(char[] ref, SAMRecord read) { +// SAMRecord recalRead = read; +// byte[] bases = read.getReadBases(); +// byte[] quals = read.getBaseQualities(); +// byte[] recalQuals = new byte[quals.length]; +// +// // Since we want machine direction reads not corrected positive strand reads, rev comp any negative strand reads +// if (read.getReadNegativeStrandFlag()) { +// bases = BaseUtils.simpleReverseComplement(bases); +// quals = BaseUtils.reverse(quals); +// } +// +// String readGroup = read.getAttribute("RG").toString(); +// int numBases = read.getReadLength(); +// recalQuals[0] = quals[0]; // can't change the first -- no dinuc +// //recalQuals[numBases-1] = quals[numBases-1]; // can't change last -- no dinuc +// for ( int cycle = 1; cycle < numBases; cycle++ ) { // skip first and last base, qual already set because no dinuc +// // Take into account that previous base is the next base in terms of machine chemistry if this is a negative strand +// //int cycle = i; //read.getReadNegativeStrandFlag() ? numBases - i - 1 : i; +// String dinuc = String.format("%c%c", bases[cycle - 1], bases[cycle]); +// byte qual = quals[cycle]; +// LogisticRegressor regressor = regressors.get(new Pair(readGroup,dinuc)); +// byte newQual; +// +// if ( regressor != null ) { // no N or some other unexpected bp in the stream +// double gamma = regressor.regress((double)cycle+1, (double)qual); +// double expGamma = Math.exp(gamma); +// double finalP = expGamma / (1+expGamma); +// newQual = QualityUtils.probToQual(1-finalP); +// //newQual = -10 * Math.round(logPOver1minusP) +// /*double POver1minusP = Math.pow(10, logPOver1minusP); +// P = POver1minusP / (1 + POver1minusP);*/ +// //newQual = QualityUtils.probToQual(P); +// +// //newQual = (byte)Math.min(Math.round(-10*logPOver1minusP),63); +// //System.out.printf("Recal %s %d %d %d%n", dinuc, cycle, qual, newQual); +// }else{ +// newQual = qual; +// } +// +// recalQuals[cycle] = newQual; +// } +// +// if (read.getReadNegativeStrandFlag()) +// recalQuals = BaseUtils.reverse(quals); +// //System.out.printf("OLD: %s%n", read.format()); +// read.setBaseQualities(recalQuals); +// //System.out.printf("NEW: %s%n", read.format()); +// return recalRead; +// } + + public void onTraversalDone(SAMFileWriter output) { if ( output != null ) { output.close(); diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/varianteval/PairwiseDistanceAnalysis.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/varianteval/PairwiseDistanceAnalysis.java index 97dc37375..c5344775f 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/varianteval/PairwiseDistanceAnalysis.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/varianteval/PairwiseDistanceAnalysis.java @@ -47,7 +47,7 @@ public class PairwiseDistanceAnalysis extends BasicVariantAnalysis { //out.printf("# Excluding %d %s %s vs. %s %s%n", d, eL, interval, lvL, lastVariantInterval); } else { pairWiseDistances.add(d); - r = String.format("Pairwise-distance %d %s %s%n", d, eL, lvL); + r = String.format("Pairwise-distance %d %s %s", d, eL, lvL); } } } diff --git a/python/Gelis2PopSNPs.py b/python/Gelis2PopSNPs.py index ab3234430..61e479969 100755 --- a/python/Gelis2PopSNPs.py +++ b/python/Gelis2PopSNPs.py @@ -50,7 +50,7 @@ def main(): (OPTIONS, args) = parser.parse_args() if len(args) != 2: - parser.error("incorrect number of arguments") + parser.error("incorrect number of arguments: " + str(args)) lines = [line.split() for line in open(args[0])] nIndividuals = int(args[1]) diff --git a/python/MergeBAMBatch.py b/python/MergeBAMBatch.py index 7e7889f79..8ef1f4fbb 100755 --- a/python/MergeBAMBatch.py +++ b/python/MergeBAMBatch.py @@ -28,6 +28,9 @@ if __name__ == "__main__": parser.add_option("-i", "--ignoreExistingFiles", dest="ignoreExistingFiles", action='store_true', default=False, help="Ignores already written files, if present") + parser.add_option("-s", "--useSamtools", dest="useSamtools", + action='store_true', default=False, + help="If present, uses samtools to perform the merge") parser.add_option("-m", "--mergeBin", dest="mergeBin", type="string", default=None, help="Path to merge binary") @@ -61,8 +64,8 @@ if __name__ == "__main__": jobid = None if OPTIONS.ignoreExistingFiles or not os.path.exists(spec.getMergedBAM()): output = spec.getMergedBase() + '.stdout' - cmd = spec.mergeCmd(OPTIONS.mergeBin) - print cmd + cmd = spec.mergeCmd(OPTIONS.mergeBin, useSamtools = OPTIONS.useSamtools) + #print cmd jobid = farm_commands.cmd(cmd, OPTIONS.farmQueue, output, just_print_commands = OPTIONS.dry) if OPTIONS.ignoreExistingFiles or not os.path.exists(spec.getMergedBAMIndex()): diff --git a/python/MergeBAMsUtils.py b/python/MergeBAMsUtils.py index f520521b0..6aa5066d7 100755 --- a/python/MergeBAMsUtils.py +++ b/python/MergeBAMsUtils.py @@ -90,11 +90,11 @@ class MergeFilesSpec: sizes = map(greek, sizes) return sizes - def mergeCmd(self, mergeBin = None, MSD = False): + def mergeCmd(self, mergeBin = None, MSD = True, useSamtools = False): if mergeBin == None: mergeBin = MERGE_BIN - return picard_utils.mergeBAMCmd(self.getMergedBAM(), self.sources(), mergeBin, MSD = MSD) + return picard_utils.mergeBAMCmd(self.getMergedBAM(), self.sources(), mergeBin, MSD = MSD, useSamtools = useSamtools) def getIndexCmd(self): return "samtools index " + self.getMergedBAM() diff --git a/python/picard_utils.py b/python/picard_utils.py index cbdaa406d..d8df8f4b9 100755 --- a/python/picard_utils.py +++ b/python/picard_utils.py @@ -19,6 +19,7 @@ ref = "/seq/references/Homo_sapiens_assembly18/v0/Homo_sapiens_assembly18.fasta" analysis = "CombineDuplicates" MERGE_BIN = '/seq/software/picard/current/bin/MergeSamFiles.jar' +SAMTOOLS_MERGE_BIN = '/seq/dirseq/samtools/current/samtools merge' CALL_GENOTYPES_BIN = '/seq/software/picard/current/bin/CallGenotypes.jar' def CollectDbSnpMatchesCmd(inputGeli, outputFile, lod): @@ -152,15 +153,19 @@ def aggregateGeliCalls( sortedGeliCalls ): #return [[loc, list(sharedCallsGroup)] for (loc, sharedCallsGroup) in itertools.groupby(sortedGeliCalls, call2loc)] return [[loc, list(sharedCallsGroup)] for (loc, sharedCallsGroup) in itertools.groupby(sortedGeliCalls, call2loc)] -def mergeBAMCmd( output_filename, inputFiles, mergeBin = MERGE_BIN, MSD = True ): - if type(inputFiles) <> list: - inputFiles = list(inputFiles) - - MSDStr = '' - if MSD: MSDStr = 'MSD=true' - - return 'java -Xmx4096m -jar ' + mergeBin + ' ' + MSDStr + ' AS=true SO=coordinate O=' + output_filename + ' VALIDATION_STRINGENCY=SILENT ' + ' I=' + (' I='.join(inputFiles)) - #return 'java -Xmx4096m -jar ' + mergeBin + ' AS=true SO=coordinate O=' + output_filename + ' VALIDATION_STRINGENCY=SILENT ' + ' I=' + (' I='.join(inputFiles)) +def mergeBAMCmd( output_filename, inputFiles, mergeBin = MERGE_BIN, MSD = True, useSamtools = False ): + if useSamtools: + return SAMTOOLS_MERGE_BIN + ' ' + output_filename + ' ' + ' '.join(inputFiles) + else: + # use picard + if type(inputFiles) <> list: + inputFiles = list(inputFiles) + + MSDStr = '' + if MSD: MSDStr = 'MSD=true' + + return 'java -Xmx4096m -jar ' + mergeBin + ' ' + MSDStr + ' AS=true SO=coordinate O=' + output_filename + ' VALIDATION_STRINGENCY=SILENT ' + ' I=' + (' I='.join(inputFiles)) + #return 'java -Xmx4096m -jar ' + mergeBin + ' AS=true SO=coordinate O=' + output_filename + ' VALIDATION_STRINGENCY=SILENT ' + ' I=' + (' I='.join(inputFiles)) def getPicardPath(lane, picardRoot = '/seq/picard/'): flowcell, laneNo = lane.split('.')