diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java index d0bcd0eb3..ac280b70e 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java @@ -25,13 +25,11 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; * OTHER DEALINGS IN THE SOFTWARE. */ -import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; -import org.broadinstitute.sting.utils.BaseUtils; import org.broadinstitute.sting.utils.classloader.ProtectedPackageSource; -import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.recalibration.EventType; import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; +import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import org.broadinstitute.sting.utils.threading.ThreadLocalArray; @@ -47,13 +45,12 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp @Override public void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) { + final ReadCovariates readCovariates = covariateKeySetFrom(read); + byte[] tempQualArray = threadLocalTempQualArray.get(); + double[] tempFractionalErrorArray = threadLocalTempFractionalErrorArray.get(); + for( int offset = 0; offset < read.getReadBases().length; offset++ ) { if( !skip[offset] ) { - final ReadCovariates readCovariates = covariateKeySetFrom(read); - - byte[] tempQualArray = threadLocalTempQualArray.get(); - double[] tempFractionalErrorArray = threadLocalTempFractionalErrorArray.get(); 
- tempQualArray[EventType.BASE_SUBSTITUTION.index] = read.getBaseQualities()[offset]; tempFractionalErrorArray[EventType.BASE_SUBSTITUTION.index] = snpErrors[offset]; tempQualArray[EventType.BASE_INSERTION.index] = read.getBaseInsertionQualities()[offset]; @@ -67,8 +64,6 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp final byte qual = tempQualArray[eventIndex]; final double isError = tempFractionalErrorArray[eventIndex]; - combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { diff --git a/public/java/src/org/broadinstitute/sting/gatk/datasources/reads/GATKBAMIndex.java b/public/java/src/org/broadinstitute/sting/gatk/datasources/reads/GATKBAMIndex.java index e3a1b61bd..abaa0b226 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/datasources/reads/GATKBAMIndex.java +++ b/public/java/src/org/broadinstitute/sting/gatk/datasources/reads/GATKBAMIndex.java @@ -23,17 +23,17 @@ */ package org.broadinstitute.sting.gatk.datasources.reads; +import org.broad.tribble.util.SeekableBufferedStream; +import org.broad.tribble.util.SeekableFileStream; + import net.sf.samtools.*; -import org.broadinstitute.sting.gatk.CommandLineGATK; + import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.UserException; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.nio.channels.FileChannel; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -68,6 +68,9 @@ public class GATKBAMIndex { private final File mFile; + //TODO: figure out a good value for this buffer size + private final int BUFFERED_STREAM_BUFFER_SIZE=8192; + /** * 
Number of sequences stored in this index. */ @@ -78,8 +81,8 @@ public class GATKBAMIndex { */ private final long[] sequenceStartCache; - private FileInputStream fileStream; - private FileChannel fileChannel; + private SeekableFileStream fileStream; + private SeekableBufferedStream bufferedStream; public GATKBAMIndex(final File file) { mFile = file; @@ -277,12 +280,11 @@ public class GATKBAMIndex { for (int i = sequenceIndex; i < referenceSequence; i++) { sequenceStartCache[i] = position(); - // System.out.println("# Sequence TID: " + i); final int nBins = readInteger(); // System.out.println("# nBins: " + nBins); for (int j = 0; j < nBins; j++) { - skipInteger(); + final int bin = readInteger(); final int nChunks = readInteger(); // System.out.println("# bin[" + j + "] = " + bin + ", nChunks = " + nChunks); skipBytes(16 * nChunks); @@ -290,15 +292,18 @@ public class GATKBAMIndex { final int nLinearBins = readInteger(); // System.out.println("# nLinearBins: " + nLinearBins); skipBytes(8 * nLinearBins); + } sequenceStartCache[referenceSequence] = position(); } + + private void openIndexFile() { try { - fileStream = new FileInputStream(mFile); - fileChannel = fileStream.getChannel(); + fileStream = new SeekableFileStream(mFile); + bufferedStream = new SeekableBufferedStream(fileStream,BUFFERED_STREAM_BUFFER_SIZE); } catch (IOException exc) { throw new ReviewedStingException("Unable to open index file (" + exc.getMessage() +")" + mFile, exc); @@ -307,7 +312,7 @@ public class GATKBAMIndex { private void closeIndexFile() { try { - fileChannel.close(); + bufferedStream.close(); fileStream.close(); } catch (IOException exc) { @@ -334,10 +339,6 @@ public class GATKBAMIndex { return buffer.getInt(); } - private void skipInteger() { - skipBytes(INT_SIZE_IN_BYTES); - } - /** * Reads an array of longs from the file channel, returning the results as an array. * @param count Number of longs to read. 
@@ -356,7 +357,9 @@ public class GATKBAMIndex { private void read(final ByteBuffer buffer) { try { int bytesExpected = buffer.limit(); - int bytesRead = fileChannel.read(buffer); + //BufferedInputStream cannot read directly into a byte buffer, so we read into an array + //and put the result into the bytebuffer after the if statement. + int bytesRead = bufferedStream.read(byteArray,0,bytesExpected); // We have a rigid expectation here to read in exactly the number of bytes we've limited // our buffer to -- if we read in fewer bytes than this, or encounter EOF (-1), the index @@ -367,6 +370,7 @@ public class GATKBAMIndex { "Please try re-indexing the corresponding BAM file.", mFile)); } + buffer.put(byteArray,0,bytesRead); } catch(IOException ex) { throw new ReviewedStingException("Index: unable to read bytes from index file " + mFile); @@ -380,10 +384,13 @@ public class GATKBAMIndex { */ private ByteBuffer buffer = null; + //BufferedStream don't read into ByteBuffers, so we need this temporary array + private byte[] byteArray=null; private ByteBuffer getBuffer(final int size) { if(buffer == null || buffer.capacity() < size) { // Allocate a new byte buffer. For now, make it indirect to make sure it winds up on the heap for easier debugging. buffer = ByteBuffer.allocate(size); + byteArray = new byte[size]; buffer.order(ByteOrder.LITTLE_ENDIAN); } buffer.clear(); @@ -393,7 +400,13 @@ public class GATKBAMIndex { private void skipBytes(final int count) { try { - fileChannel.position(fileChannel.position() + count); + + //try to skip forward the requested amount. 
+ long skipped = bufferedStream.skip(count); + + if( skipped != count ) { //if not managed to skip the requested amount + throw new ReviewedStingException("Index: unable to reposition file channel of index file " + mFile); + } } catch(IOException ex) { throw new ReviewedStingException("Index: unable to reposition file channel of index file " + mFile); @@ -402,7 +415,8 @@ public class GATKBAMIndex { private void seek(final long position) { try { - fileChannel.position(position); + //to seek a new position, move the fileChannel, and reposition the bufferedStream + bufferedStream.seek(position); } catch(IOException ex) { throw new ReviewedStingException("Index: unable to reposition of file channel of index file " + mFile); @@ -415,7 +429,7 @@ public class GATKBAMIndex { */ private long position() { try { - return fileChannel.position(); + return bufferedStream.position(); } catch (IOException exc) { throw new ReviewedStingException("Unable to read position from index file " + mFile, exc); diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java index 735f62ca3..cd0198a29 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java @@ -33,14 +33,13 @@ import org.broadinstitute.sting.gatk.datasources.providers.ReadShardDataProvider import org.broadinstitute.sting.gatk.datasources.providers.ReadView; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.ReadWalker; -import org.broadinstitute.sting.utils.GenomeLoc; import org.broadinstitute.sting.utils.nanoScheduler.NSMapFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NSProgressFunction; import org.broadinstitute.sting.utils.nanoScheduler.NSReduceFunction; import 
org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; -import java.util.LinkedList; -import java.util.List; +import java.util.Iterator; /** * A nano-scheduling version of TraverseReads. @@ -60,6 +59,13 @@ public class TraverseReadsNano extends TraversalEngine, public TraverseReadsNano(int nThreads) { nanoScheduler = new NanoScheduler(nThreads); + nanoScheduler.setProgressFunction(new NSProgressFunction() { + @Override + public void progress(MapData lastProcessedMap) { + if ( lastProcessedMap.refContext != null ) + printProgress(lastProcessedMap.refContext.getLocus()); + } + }); } @Override @@ -78,7 +84,8 @@ public class TraverseReadsNano extends TraversalEngine, public T traverse(ReadWalker walker, ReadShardDataProvider dataProvider, T sum) { - logger.debug(String.format("TraverseReadsNano.traverse Covered dataset is %s", dataProvider)); + if ( logger.isDebugEnabled() ) + logger.debug(String.format("TraverseReadsNano.traverse Covered dataset is %s", dataProvider)); if( !dataProvider.hasReads() ) throw new IllegalArgumentException("Unable to traverse reads; no read data is available."); @@ -87,14 +94,10 @@ public class TraverseReadsNano extends TraversalEngine, final TraverseReadsMap myMap = new TraverseReadsMap(walker); final TraverseReadsReduce myReduce = new TraverseReadsReduce(walker); - final List aggregatedInputs = aggregateMapData(dataProvider); - final T result = nanoScheduler.execute(aggregatedInputs.iterator(), myMap, sum, myReduce); - - final GATKSAMRecord lastRead = aggregatedInputs.get(aggregatedInputs.size() - 1).read; - final GenomeLoc locus = engine.getGenomeLocParser().createGenomeLoc(lastRead); + final Iterator aggregatedInputs = aggregateMapData(dataProvider); + final T result = nanoScheduler.execute(aggregatedInputs, myMap, sum, myReduce); updateCumulativeMetrics(dataProvider.getShard()); - printProgress(locus); return result; } @@ -107,29 +110,37 @@ public class TraverseReadsNano 
extends TraversalEngine, * @return a linked list of MapData objects holding the read, ref, and ROD info for every map/reduce * should execute */ - private List aggregateMapData(final ReadShardDataProvider dataProvider) { - final ReadView reads = new ReadView(dataProvider); - final ReadReferenceView reference = new ReadReferenceView(dataProvider); - final ReadBasedReferenceOrderedView rodView = new ReadBasedReferenceOrderedView(dataProvider); + private Iterator aggregateMapData(final ReadShardDataProvider dataProvider) { + return new Iterator() { + final ReadView reads = new ReadView(dataProvider); + final ReadReferenceView reference = new ReadReferenceView(dataProvider); + final ReadBasedReferenceOrderedView rodView = new ReadBasedReferenceOrderedView(dataProvider); + final Iterator readIterator = reads.iterator(); - final List mapData = new LinkedList(); - for ( final SAMRecord read : reads ) { - final ReferenceContext refContext = ! read.getReadUnmappedFlag() - ? reference.getReferenceContext(read) - : null; + @Override public boolean hasNext() { return readIterator.hasNext(); } - // if the read is mapped, create a metadata tracker - final RefMetaDataTracker tracker = read.getReferenceIndex() >= 0 - ? rodView.getReferenceOrderedDataForRead(read) - : null; + @Override + public MapData next() { + final SAMRecord read = readIterator.next(); + final ReferenceContext refContext = ! read.getReadUnmappedFlag() + ? reference.getReferenceContext(read) + : null; - // update the number of reads we've seen - dataProvider.getShard().getReadMetrics().incrementNumIterations(); + // if the read is mapped, create a metadata tracker + final RefMetaDataTracker tracker = read.getReferenceIndex() >= 0 + ? 
rodView.getReferenceOrderedDataForRead(read) + : null; - mapData.add(new MapData((GATKSAMRecord)read, refContext, tracker)); - } + // update the number of reads we've seen + dataProvider.getShard().getReadMetrics().incrementNumIterations(); - return mapData; + return new MapData((GATKSAMRecord)read, refContext, tracker); + } + + @Override public void remove() { + throw new UnsupportedOperationException("Remove not supported"); + } + }; } @Override diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/annotator/IndelType.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/annotator/VariantType.java similarity index 52% rename from public/java/src/org/broadinstitute/sting/gatk/walkers/annotator/IndelType.java rename to public/java/src/org/broadinstitute/sting/gatk/walkers/annotator/VariantType.java index c67d829c2..a5c2b32f0 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/annotator/IndelType.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/annotator/VariantType.java @@ -15,9 +15,9 @@ import org.broadinstitute.sting.utils.variantcontext.VariantContext; import java.util.*; /** - * Rough category of indel type (insertion, deletion, multi-allelic, other) + * Assigns a roughly correct category of the variant type (SNP, MNP, insertion, deletion, etc.) 
*/ -public class IndelType extends InfoFieldAnnotation implements ExperimentalAnnotation { +public class VariantType extends InfoFieldAnnotation implements ExperimentalAnnotation { public Map annotate(final RefMetaDataTracker tracker, final AnnotatorCompatible walker, @@ -26,41 +26,35 @@ public class IndelType extends InfoFieldAnnotation implements ExperimentalAnnota final VariantContext vc, final Map stratifiedPerReadAlleleLikelihoodMap) { - int run; - if (vc.isMixed()) { - Map map = new HashMap(); - map.put(getKeyNames().get(0), String.format("%s", "MIXED")); - return map; - - } - else if ( vc.isIndel() ) { - String type=""; - if (!vc.isBiallelic()) - type = "MULTIALLELIC_INDEL"; - else { - if (vc.isSimpleInsertion()) - type = "INS."; - else if (vc.isSimpleDeletion()) - type = "DEL."; - else - type = "OTHER."; - ArrayList inds = IndelUtils.findEventClassificationIndex(vc, ref); - for (int k : inds) { - type = type+ IndelUtils.getIndelClassificationName(k)+"."; - } - } - Map map = new HashMap(); - map.put(getKeyNames().get(0), String.format("%s", type)); - return map; + StringBuffer type = new StringBuffer(""); + if ( vc.isVariant() && !vc.isBiallelic() ) + type.append("MULTIALLELIC_"); + if ( !vc.isIndel() ) { + type.append(vc.getType().toString()); } else { - return null; + if (vc.isSimpleInsertion()) + type.append("INSERTION."); + else if (vc.isSimpleDeletion()) + type.append("DELETION."); + else + type.append("COMPLEX."); + ArrayList inds = IndelUtils.findEventClassificationIndex(vc, ref); + type.append(IndelUtils.getIndelClassificationName(inds.get(0))); + + for (int i = 1; i < inds.size(); i++ ) { + type.append("."); + type.append(IndelUtils.getIndelClassificationName(inds.get(i))); + } } + Map map = new HashMap(); + map.put(getKeyNames().get(0), String.format("%s", type)); + return map; } - public List getKeyNames() { return Arrays.asList("IndelType"); } + public List getKeyNames() { return Arrays.asList("VariantType"); } - public List getDescriptions() { 
return Arrays.asList(new VCFInfoHeaderLine("IndelType", 1, VCFHeaderLineType.String, "Indel type description")); } + public List getDescriptions() { return Arrays.asList(new VCFInfoHeaderLine("VariantType", 1, VCFHeaderLineType.String, "Variant type description")); } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java index 7ce98cf1d..4d7dbc912 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java @@ -45,7 +45,6 @@ import org.broadinstitute.sting.utils.clipping.ReadClipper; import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.UserException; -import org.broadinstitute.sting.utils.fasta.CachingIndexedFastaSequenceFile; import org.broadinstitute.sting.utils.help.DocumentedGATKFeature; import org.broadinstitute.sting.utils.recalibration.*; import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; @@ -53,7 +52,6 @@ import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import org.broadinstitute.sting.utils.sam.ReadUtils; import java.io.File; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; import java.lang.reflect.Constructor; @@ -194,14 +192,7 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche recalibrationEngine.initialize(requestedCovariates, recalibrationTables); minimumQToUse = getToolkit().getArguments().PRESERVE_QSCORES_LESS_THAN; - - try { - // fasta reference reader for use with BAQ calculation - referenceReader = new CachingIndexedFastaSequenceFile(getToolkit().getArguments().referenceFile); - } catch( FileNotFoundException e ) { - throw new 
UserException.CouldNotReadInputFile(getToolkit().getArguments().referenceFile, e); - } - + referenceReader = getToolkit().getReferenceDataSource().getReference(); } private RecalibrationEngine initializeRecalibrationEngine() { @@ -425,6 +416,7 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche } private byte[] calculateBAQArray( final GATKSAMRecord read ) { + // todo -- it would be good to directly use the BAQ qualities rather than encoding and decoding the result and using the special @ value baq.baqRead(read, referenceReader, BAQ.CalculationMode.RECALCULATE, BAQ.QualityMode.ADD_TAG); return BAQ.getBAQTag(read); } @@ -452,6 +444,8 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche @Override public void onTraversalDone(Long result) { + recalibrationEngine.finalizeData(); + logger.info("Calculating quantized quality scores..."); quantizeQualityScores(); diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java index 962d62d5e..35375eb1d 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java @@ -1,8 +1,7 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; -import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; -import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; +import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; /* @@ -34,4 +33,6 @@ public interface RecalibrationEngine { public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables); public void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] 
insertionErrors, final double[] deletionErrors); + + public void finalizeData(); } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java index 6031aa955..1e166dfd0 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java @@ -25,15 +25,13 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; * OTHER DEALINGS IN THE SOFTWARE. */ -import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; -import org.broadinstitute.sting.utils.BaseUtils; import org.broadinstitute.sting.utils.classloader.PublicPackageSource; import org.broadinstitute.sting.utils.collections.NestedIntegerArray; -import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.recalibration.EventType; import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.utils.recalibration.RecalDatum; import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; +import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; public class StandardRecalibrationEngine implements RecalibrationEngine, PublicPackageSource { @@ -58,8 +56,6 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP final int[] keys = readCovariates.getKeySet(offset, EventType.BASE_SUBSTITUTION); final int eventIndex = EventType.BASE_SUBSTITUTION.index; - combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { @@ -93,6 +89,34 @@ public class 
StandardRecalibrationEngine implements RecalibrationEngine, PublicP return (ReadCovariates) read.getTemporaryAttribute(BaseRecalibrator.COVARS_ATTRIBUTE); } + /** + * Create derived recalibration data tables + * + * Assumes that all of the principal tables (by quality score) have been completely updated, + * and walks over this data to create summary data tables like by read group table. + */ + @Override + public void finalizeData() { + final NestedIntegerArray byReadGroupTable = recalibrationTables.getReadGroupTable(); + final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); + + // iterate over all values in the qual table + for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { + final int rgKey = leaf.keys[0]; + final int eventIndex = leaf.keys[2]; + final RecalDatum rgDatum = byReadGroupTable.get(rgKey, eventIndex); + final RecalDatum qualDatum = leaf.value; + + if ( rgDatum == null ) { + // create a copy of qualDatum, and initialize byReadGroup table with it + byReadGroupTable.put(new RecalDatum(qualDatum), rgKey, eventIndex); + } else { + // combine the qual datum with the existing datum in the byReadGroup table + rgDatum.combine(qualDatum); + } + } + } + /** * Increments the RecalDatum at the specified position in the specified table, or put a new item there * if there isn't already one. @@ -121,34 +145,4 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP existingDatum.increment(1.0, isError); } } - - /** - * Combines the RecalDatum at the specified position in the specified table with a new RecalDatum, or put a - * new item there if there isn't already one. - * - * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() - * to return false if another thread inserts a new item at our position in the middle of our put operation. 
- * - * @param table the table that holds/will hold our item - * @param qual qual for this event - * @param isError error value for this event - * @param keys location in table of our item - */ - protected void combineDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final double isError, final int... keys ) { - final RecalDatum existingDatum = table.get(keys); - final RecalDatum newDatum = createDatumObject(qual, isError); - - if ( existingDatum == null ) { - // No existing item, try to put a new one - if ( ! table.put(newDatum, keys) ) { - // Failed to put a new item because another thread came along and put an item here first. - // Get the newly-put item and combine it with our item (item is guaranteed to exist at this point) - table.get(keys).combine(newDatum); - } - } - else { - // Easy case: already an item here, so combine it with our item - existingDatum.combine(newDatum); - } - } } diff --git a/public/java/src/org/broadinstitute/sting/utils/IndelUtils.java b/public/java/src/org/broadinstitute/sting/utils/IndelUtils.java index c6ca39f4b..9b1cc9733 100755 --- a/public/java/src/org/broadinstitute/sting/utils/IndelUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/IndelUtils.java @@ -122,9 +122,9 @@ public class IndelUtils { ArrayList inds = new ArrayList(); if ( vc.isSimpleInsertion() ) { - indelAlleleString = vc.getAlternateAllele(0).getDisplayString(); + indelAlleleString = vc.getAlternateAllele(0).getDisplayString().substring(1); } else if ( vc.isSimpleDeletion() ) { - indelAlleleString = vc.getReference().getDisplayString(); + indelAlleleString = vc.getReference().getDisplayString().substring(1); } else { inds.add(IND_FOR_OTHER_EVENT); diff --git a/public/java/src/org/broadinstitute/sting/utils/LRUCache.java b/public/java/src/org/broadinstitute/sting/utils/LRUCache.java new file mode 100644 index 000000000..a3514c95f --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/utils/LRUCache.java @@ -0,0 +1,20 @@ 
+package org.broadinstitute.sting.utils; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * An LRU cache implemented as an extension to LinkedHashMap + */ +public class LRUCache extends LinkedHashMap { + private int capacity; // Maximum number of items in the cache. + + public LRUCache(int capacity) { + super(capacity+1, 1.0f, true); // Pass 'true' for accessOrder. + this.capacity = capacity; + } + + protected boolean removeEldestEntry(final Map.Entry entry) { + return (size() > this.capacity); + } +} diff --git a/public/java/src/org/broadinstitute/sting/utils/baq/BAQ.java b/public/java/src/org/broadinstitute/sting/utils/baq/BAQ.java index 3966434c0..51753ecef 100644 --- a/public/java/src/org/broadinstitute/sting/utils/baq/BAQ.java +++ b/public/java/src/org/broadinstitute/sting/utils/baq/BAQ.java @@ -6,6 +6,7 @@ import net.sf.samtools.CigarElement; import net.sf.samtools.CigarOperator; import net.sf.samtools.SAMRecord; import net.sf.samtools.SAMUtils; +import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.UserException; @@ -37,6 +38,7 @@ import org.broadinstitute.sting.utils.sam.ReadUtils; state[i] being wrong. */ public class BAQ { + private final static Logger logger = Logger.getLogger(BAQ.class); private final static boolean DEBUG = false; public enum CalculationMode { @@ -179,8 +181,7 @@ public class BAQ { /*** initialization ***/ // change coordinates - int l_ref = ref.length; - + final int l_ref = ref.length; // set band width int bw2, bw = l_ref > l_query? 
l_ref : l_query; @@ -266,26 +267,6 @@ public class BAQ { s[l_query+1] = sum; // the last scaling factor } - //gdbebug+ -/* - double cac=0.; - // undo scaling of forward probabilities to obtain plain probability of observation given model - double[] su = new double[f[l_query].length]; - { - double sum = 0.; - double[] logs = new double[s.length]; - for (k=0; k < logs.length; k++) { - logs[k] = Math.log10(s[k]); - sum += logs[k]; - } - for (k=0; k < f[l_query].length; k++) - su[k]= Math.log10(f[l_query][k])+ sum; - - cac = MathUtils.softMax(su); - } - System.out.format("s:%f\n",cac); - // gdebug- - */ /*** backward ***/ // b[l_query] (b[l_query+1][0]=1 and thus \tilde{b}[][]=1/s[l_query+1]; this is where s[l_query+1] comes from) for (k = 1; k <= l_ref; ++k) { @@ -305,8 +286,8 @@ public class BAQ { for (k = end; k >= beg; --k) { int u, v11, v01, v10; u = set_u(bw, i, k); v11 = set_u(bw, i+1, k+1); v10 = set_u(bw, i+1, k); v01 = set_u(bw, i, k+1); - double e = (k >= l_ref? 0 : calcEpsilon(ref[k], qyi1, _iqual[qstart+i])) * bi1[v11]; - bi[u+0] = e * m[0] + EI * m[1] * bi1[v10+1] + m[2] * bi[v01+2]; // bi1[v11] has been foled into e. + final double e = (k >= l_ref? 0 : calcEpsilon(ref[k], qyi1, _iqual[qstart+i])) * bi1[v11]; + bi[u+0] = e * m[0] + EI * m[1] * bi1[v10+1] + m[2] * bi[v01+2]; // bi1[v11] has been folded into e. bi[u+1] = e * m[3] + EI * m[4] * bi1[v10+1]; bi[u+2] = (e * m[6] + m[8] * bi[v01+2]) * y; } @@ -332,12 +313,12 @@ public class BAQ { /*** MAP ***/ for (i = 1; i <= l_query; ++i) { double sum = 0., max = 0.; - double[] fi = f[i], bi = b[i]; + final double[] fi = f[i], bi = b[i]; int beg = 1, end = l_ref, x, max_k = -1; x = i - bw; beg = beg > x? beg : x; x = i + bw; end = end < x? 
end : x; for (k = beg; k <= end; ++k) { - int u = set_u(bw, i, k); + final int u = set_u(bw, i, k); double z; sum += (z = fi[u+0] * bi[u+0]); if (z > max) { max = z; max_k = (k-1)<<2 | 0; } sum += (z = fi[u+1] * bi[u+1]); if (z > max) { max = z; max_k = (k-1)<<2 | 1; } @@ -531,7 +512,11 @@ public class BAQ { } } +// final SimpleTimer total = new SimpleTimer(); +// final SimpleTimer local = new SimpleTimer(); +// int n = 0; public BAQCalculationResult calcBAQFromHMM(byte[] ref, byte[] query, byte[] quals, int queryStart, int queryEnd ) { +// total.restart(); if ( queryStart < 0 ) throw new ReviewedStingException("BUG: queryStart < 0: " + queryStart); if ( queryEnd < 0 ) throw new ReviewedStingException("BUG: queryEnd < 0: " + queryEnd); if ( queryEnd < queryStart ) throw new ReviewedStingException("BUG: queryStart < queryEnd : " + queryStart + " end =" + queryEnd); @@ -539,7 +524,12 @@ public class BAQ { // note -- assumes ref is offset from the *CLIPPED* start BAQCalculationResult baqResult = new BAQCalculationResult(query, quals, ref); int queryLen = queryEnd - queryStart; +// local.restart(); hmm_glocal(baqResult.refBases, baqResult.readBases, queryStart, queryLen, baqResult.rawQuals, baqResult.state, baqResult.bq); +// local.stop(); +// total.stop(); +// if ( n++ % 100000 == 0 ) +// logger.info("n = " + n + ": Total " + total.getElapsedTimeNano() + " local " + local.getElapsedTimeNano()); return baqResult; } diff --git a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java index 050ed52ac..890a9b488 100755 --- a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java +++ b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java @@ -58,13 +58,20 @@ public class NestedIntegerArray { int dimensionsToPreallocate = Math.min(dimensions.length, NUM_DIMENSIONS_TO_PREALLOCATE); - 
logger.info(String.format("Creating NestedIntegerArray with dimensions %s", Arrays.toString(dimensions))); - logger.info(String.format("Pre-allocating first %d dimensions", dimensionsToPreallocate)); + if ( logger.isDebugEnabled() ) logger.debug(String.format("Creating NestedIntegerArray with dimensions %s", Arrays.toString(dimensions))); + if ( logger.isDebugEnabled() ) logger.debug(String.format("Pre-allocating first %d dimensions", dimensionsToPreallocate)); data = new Object[dimensions[0]]; preallocateArray(data, 0, dimensionsToPreallocate); - logger.info(String.format("Done pre-allocating first %d dimensions", dimensionsToPreallocate)); + if ( logger.isDebugEnabled() ) logger.debug(String.format("Done pre-allocating first %d dimensions", dimensionsToPreallocate)); + } + + /** + * @return the dimensions of this nested integer array. DO NOT MODIFY + */ + public int[] getDimensions() { + return dimensions; } /** @@ -174,23 +181,23 @@ public class NestedIntegerArray { } } - public static class Leaf { + public static class Leaf { public final int[] keys; - public final Object value; + public final T value; - public Leaf(final int[] keys, final Object value) { + public Leaf(final int[] keys, final T value) { this.keys = keys; this.value = value; } } - public List getAllLeaves() { - final List result = new ArrayList(); + public List> getAllLeaves() { + final List> result = new ArrayList>(); fillAllLeaves(data, new int[0], result); return result; } - private void fillAllLeaves(final Object[] array, final int[] path, final List result) { + private void fillAllLeaves(final Object[] array, final int[] path, final List> result) { for ( int key = 0; key < array.length; key++ ) { final Object value = array[key]; if ( value == null ) @@ -199,7 +206,7 @@ public class NestedIntegerArray { if ( value instanceof Object[] ) { fillAllLeaves((Object[]) value, newPath, result); } else { - result.add(new Leaf(newPath, value)); + result.add(new Leaf(newPath, (T)value)); } } } diff 
--git a/public/java/src/org/broadinstitute/sting/utils/progressmeter/ProgressMeter.java b/public/java/src/org/broadinstitute/sting/utils/progressmeter/ProgressMeter.java index b69283b9d..161335957 100755 --- a/public/java/src/org/broadinstitute/sting/utils/progressmeter/ProgressMeter.java +++ b/public/java/src/org/broadinstitute/sting/utils/progressmeter/ProgressMeter.java @@ -145,7 +145,7 @@ public class ProgressMeter { private final SimpleTimer timer = new SimpleTimer(); private GenomeLoc maxGenomeLoc = null; - private String positionMessage = "starting"; + private Position position = new Position(PositionStatus.STARTING); private long nTotalRecordsProcessed = 0; final ProgressMeterDaemon progressMeterDaemon; @@ -234,9 +234,65 @@ public class ProgressMeter { this.nTotalRecordsProcessed = Math.max(this.nTotalRecordsProcessed, nTotalRecordsProcessed); // a pretty name for our position - this.positionMessage = maxGenomeLoc == null - ? "unmapped reads" - : String.format("%s:%d", maxGenomeLoc.getContig(), maxGenomeLoc.getStart()); + this.position = maxGenomeLoc == null ? new Position(PositionStatus.IN_UNMAPPED_READS) : new Position(maxGenomeLoc); + } + + /** + * Describes the status of this position marker, such as starting up, done, in the unmapped reads, + * or somewhere on the genome + */ + private enum PositionStatus { + STARTING("Starting"), + DONE("done"), + IN_UNMAPPED_READS("unmapped reads"), + ON_GENOME(null); + + public final String message; + + private PositionStatus(String message) { + this.message = message; + } + } + + /** + * A pair of position status and the genome loc, if necessary. 
Used to get a + * status update message as needed, without the computational cost of formatting + * the genome loc string every time a progress notification happens (which is almost + * always not printed) + */ + private class Position { + final PositionStatus type; + final GenomeLoc maybeLoc; + + /** + * Create a position object of any type != ON_GENOME + * @param type + */ + @Requires({"type != null", "type != PositionStatus.ON_GENOME"}) + private Position(PositionStatus type) { + this.type = type; + this.maybeLoc = null; + } + + /** + * Create a position object of type ON_GENOME at genomeloc loc + * @param loc + */ + @Requires("loc != null") + private Position(GenomeLoc loc) { + this.type = PositionStatus.ON_GENOME; + this.maybeLoc = loc; + } + + /** + * @return a human-readable representation of this position + */ + private String getMessage() { + if ( type == PositionStatus.ON_GENOME ) + return maxGenomeLoc.getContig() + ":" + maxGenomeLoc.getStart(); + else + return type.message; + } } /** @@ -267,7 +323,7 @@ public class ProgressMeter { updateLoggerPrintFrequency(estTotalRuntime.getTimeInSeconds()); logger.info(String.format("%15s %5.2e %s %s %5.1f%% %s %s", - positionMessage, progressData.getUnitsProcessed()*1.0, elapsed, unitRate, + position.getMessage(), progressData.getUnitsProcessed()*1.0, elapsed, unitRate, 100*fractionGenomeTargetCompleted, estTotalRuntime, timeToCompletion)); } @@ -317,7 +373,7 @@ public class ProgressMeter { public void notifyDone(final long nTotalRecordsProcessed) { // print out the progress meter this.nTotalRecordsProcessed = nTotalRecordsProcessed; - this.positionMessage = "done"; + this.position = new Position(PositionStatus.DONE); printProgress(true); logger.info(String.format("Total runtime %.2f secs, %.2f min, %.2f hours", diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java index 
5d4020a07..567514f8c 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java @@ -27,6 +27,7 @@ package org.broadinstitute.sting.utils.recalibration; import net.sf.samtools.SAMTag; import net.sf.samtools.SAMUtils; +import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.QualityUtils; import org.broadinstitute.sting.utils.collections.NestedIntegerArray; @@ -44,7 +45,8 @@ import java.io.File; */ public class BaseRecalibration { - private final static int MAXIMUM_RECALIBRATED_READ_LENGTH = 5000; + private static Logger logger = Logger.getLogger(BaseRecalibration.class); + private final static boolean TEST_CACHING = false; private final QuantizationInfo quantizationInfo; // histogram containing the map for qual quantization (calculated after recalibration is done) private final RecalibrationTables recalibrationTables; @@ -54,12 +56,8 @@ public class BaseRecalibration { private final int preserveQLessThan; private final boolean emitOriginalQuals; - // TODO -- was this supposed to be used somewhere? -// private static final NestedHashMap[] qualityScoreByFullCovariateKey = new NestedHashMap[EventType.values().length]; // Caches the result of performSequentialQualityCalculation(..) for all sets of covariate values. 
-// static { -// for (int i = 0; i < EventType.values().length; i++) -// qualityScoreByFullCovariateKey[i] = new NestedHashMap(); -// } + private final NestedIntegerArray globalDeltaQs; + private final NestedIntegerArray deltaQReporteds; /** @@ -84,6 +82,44 @@ public class BaseRecalibration { this.disableIndelQuals = disableIndelQuals; this.preserveQLessThan = preserveQLessThan; this.emitOriginalQuals = emitOriginalQuals; + + logger.info("Calculating cached tables..."); + + // + // Create a NestedIntegerArray that maps from rgKey x errorModel -> double, + // where the double is the result of this calculation. The entire calculation can + // be done upfront, on initialization of this BaseRecalibration structure + // + final NestedIntegerArray byReadGroupTable = recalibrationTables.getReadGroupTable(); + globalDeltaQs = new NestedIntegerArray( byReadGroupTable.getDimensions() ); + logger.info("Calculating global delta Q table..."); + for ( NestedIntegerArray.Leaf leaf : byReadGroupTable.getAllLeaves() ) { + final int rgKey = leaf.keys[0]; + final int eventIndex = leaf.keys[1]; + final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex)); + globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex); + } + + + // The calculation of the deltaQ report is constant. key[0] and key[1] are the read group and qual, respectively + // and globalDeltaQ is a constant for the read group. So technically the delta Q reported is simply a lookup + // into a matrix indexed by rgGroup, qual, and event type. + // the code below actually creates this cache with a NestedIntegerArray calling into the actual + // calculateDeltaQReported code. 
+ final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); + deltaQReporteds = new NestedIntegerArray( byQualTable.getDimensions() ); + logger.info("Calculating delta Q reported table..."); + for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { + final int rgKey = leaf.keys[0]; + final int qual = leaf.keys[1]; + final int eventIndex = leaf.keys[2]; + final EventType event = EventType.eventFrom(eventIndex); + final double globalDeltaQ = getGlobalDeltaQ(rgKey, event); + final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual); + deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex); + } + + logger.info("done calculating cache"); } /** @@ -91,6 +127,18 @@ public class BaseRecalibration { * * It updates the base qualities of the read with the new recalibrated qualities (for all event types) * + * Implements a serial recalibration of the reads using the combinational table. + * First, we perform a positional recalibration, and then a subsequent dinuc correction. + * + * Given the full recalibration table, we perform the following preprocessing steps: + * + * - calculate the global quality score shift across all data [DeltaQ] + * - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift + * -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual + * - The final shift equation is: + * + * Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... 
) + * * @param read the read to recalibrate */ public void recalibrateRead(final GATKSAMRecord read) { @@ -103,6 +151,7 @@ public class BaseRecalibration { } final ReadCovariates readCovariates = RecalUtils.computeCovariates(read, requestedCovariates); + final int readLength = read.getReadLength(); for (final EventType errorModel : EventType.values()) { // recalibrate all three quality strings if (disableIndelQuals && errorModel != EventType.BASE_SUBSTITUTION) { @@ -111,58 +160,88 @@ public class BaseRecalibration { } final byte[] quals = read.getBaseQualities(errorModel); - final int[][] fullReadKeySet = readCovariates.getKeySet(errorModel); // get the keyset for this base using the error model - final int readLength = read.getReadLength(); + // get the keyset for this base using the error model + final int[][] fullReadKeySet = readCovariates.getKeySet(errorModel); + + // the rg key is constant over the whole read, the global deltaQ is too + final int rgKey = fullReadKeySet[0][0]; + + final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel); + for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read + final byte origQual = quals[offset]; - final byte originalQualityScore = quals[offset]; + // only recalibrate usable qualities (the original quality will come from the instrument -- reported quality) + if ( origQual >= preserveQLessThan ) { + // get the keyset for this base using the error model + final int[] keySet = fullReadKeySet[offset]; + final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ); + final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual); + + // calculate the recalibrated qual using the BQSR formula + double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates; + + // recalibrated quality is bound between 1 and MAX_QUAL + final byte recalibratedQual = 
QualityUtils.boundQual(MathUtils.fastRound(recalibratedQualDouble), QualityUtils.MAX_RECALIBRATED_Q_SCORE); + + // return the quantized version of the recalibrated quality + final byte recalibratedQualityScore = quantizationInfo.getQuantizedQuals().get(recalibratedQual); - if (originalQualityScore >= preserveQLessThan) { // only recalibrate usable qualities (the original quality will come from the instrument -- reported quality) - final int[] keySet = fullReadKeySet[offset]; // get the keyset for this base using the error model - final byte recalibratedQualityScore = performSequentialQualityCalculation(keySet, errorModel); // recalibrate the base quals[offset] = recalibratedQualityScore; } } + + // finally update the base qualities in the read read.setBaseQualities(quals, errorModel); } } + private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) { + final Double cached = globalDeltaQs.get(rgKey, errorModel.index); - /** - * Implements a serial recalibration of the reads using the combinational table. - * First, we perform a positional recalibration, and then a subsequent dinuc correction. - * - * Given the full recalibration table, we perform the following preprocessing steps: - * - * - calculate the global quality score shift across all data [DeltaQ] - * - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift - * -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual - * - The final shift equation is: - * - * Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... 
) - * - * @param key The list of Comparables that were calculated from the covariates - * @param errorModel the event type - * @return A recalibrated quality score as a byte - */ - private byte performSequentialQualityCalculation(final int[] key, final EventType errorModel) { + if ( TEST_CACHING ) { + final double calcd = calculateGlobalDeltaQ(rgKey, errorModel); + if ( calcd != cached ) + throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel); + } - final byte qualFromRead = (byte)(long)key[1]; - final double globalDeltaQ = calculateGlobalDeltaQ(recalibrationTables.getReadGroupTable(), key, errorModel); - final double deltaQReported = calculateDeltaQReported(recalibrationTables.getQualityScoreTable(), key, errorModel, globalDeltaQ, qualFromRead); - final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, key, errorModel, globalDeltaQ, deltaQReported, qualFromRead); - - double recalibratedQual = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates; // calculate the recalibrated qual using the BQSR formula - recalibratedQual = QualityUtils.boundQual(MathUtils.fastRound(recalibratedQual), QualityUtils.MAX_RECALIBRATED_Q_SCORE); // recalibrated quality is bound between 1 and MAX_QUAL - - return quantizationInfo.getQuantizedQuals().get((int) recalibratedQual); // return the quantized version of the recalibrated quality + return cachedWithDefault(cached); } - private double calculateGlobalDeltaQ(final NestedIntegerArray table, final int[] key, final EventType errorModel) { + private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) { + final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.index); + + if ( TEST_CACHING ) { + final double calcd = calculateDeltaQReported(rgKey, qualKey, errorModel, globalDeltaQ, (byte)qualKey); + if ( calcd != cached ) + throw new 
IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + qualKey + " / " + errorModel); + } + + return cachedWithDefault(cached); + } + + /** + * @param d a Double (that may be null) that is the result of a delta Q calculation + * @return a double == d if d != null, or 0.0 if it is + */ + private double cachedWithDefault(final Double d) { + return d == null ? 0.0 : d; + } + + /** + * Note that this calculation is a constant for each rgKey and errorModel. We need only + * compute this value once for all data. + * + * @param rgKey + * @param errorModel + * @return + */ + private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) { double result = 0.0; - final RecalDatum empiricalQualRG = table.get(key[0], errorModel.index); + final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.index); + if (empiricalQualRG != null) { final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality(); final double aggregrateQReported = empiricalQualRG.getEstimatedQReported(); @@ -172,10 +251,10 @@ public class BaseRecalibration { return result; } - private double calculateDeltaQReported(final NestedIntegerArray table, final int[] key, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) { + private double calculateDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) { double result = 0.0; - final RecalDatum empiricalQualQS = table.get(key[0], key[1], errorModel.index); + final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.index); if (empiricalQualQS != null) { final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality(); result = deltaQReportedEmpirical - qualFromRead - globalDeltaQ; @@ -192,12 +271,28 @@ public class BaseRecalibration { if (key[i] < 0) continue; - final 
RecalDatum empiricalQualCO = recalibrationTables.getTable(i).get(key[0], key[1], key[i], errorModel.index); - if (empiricalQualCO != null) { - final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(); - result += (deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported)); - } + result += calculateDeltaQCovariate(recalibrationTables.getTable(i), + key[0], key[1], key[i], errorModel, + globalDeltaQ, deltaQReported, qualFromRead); } + return result; } + + private double calculateDeltaQCovariate(final NestedIntegerArray table, + final int rgKey, + final int qualKey, + final int tableKey, + final EventType errorModel, + final double globalDeltaQ, + final double deltaQReported, + final byte qualFromRead) { + final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.index); + if (empiricalQualCO != null) { + final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(); + return deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported); + } else { + return 0.0; + } + } } diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/ReadCovariates.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/ReadCovariates.java index 2b682f84b..4ddcb2b92 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/ReadCovariates.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/ReadCovariates.java @@ -1,6 +1,7 @@ package org.broadinstitute.sting.utils.recalibration; -import java.util.Arrays; +import org.apache.log4j.Logger; +import org.broadinstitute.sting.utils.LRUCache; /** * The object temporarily held by a read that describes all of it's covariates. 
@@ -11,12 +12,47 @@ import java.util.Arrays; * @since 2/8/12 */ public class ReadCovariates { + private final static Logger logger = Logger.getLogger(ReadCovariates.class); + + /** + * How big should we let the LRU cache grow + */ + private static final int LRU_CACHE_SIZE = 500; + + /** + * Use an LRU cache to keep cache of keys (int[][][]) arrays for each read length we've seen. + * The cache allows us to avoid the expense of recreating these arrays for every read. The LRU + * keeps the total number of cached arrays to less than LRU_CACHE_SIZE. + * + * This is a thread local variable, so the total memory required may grow to N_THREADS x LRU_CACHE_SIZE + */ + private final static ThreadLocal> keysCache = new ThreadLocal>() { + @Override protected LRUCache initialValue() { + return new LRUCache(LRU_CACHE_SIZE); + } + }; + + /** + * Our keys, indexed by event type x read length x covariate + */ private final int[][][] keys; + /** + * The index of the current covariate, used by addCovariate + */ private int currentCovariateIndex = 0; public ReadCovariates(final int readLength, final int numberOfCovariates) { - keys = new int[EventType.values().length][readLength][numberOfCovariates]; + final LRUCache cache = keysCache.get(); + final int[][][] cachedKeys = cache.get(readLength); + if ( cachedKeys == null ) { + // There's no cached value for read length so we need to create a new int[][][] array + if ( logger.isDebugEnabled() ) logger.debug("Keys cache miss for length " + readLength + " cache size " + cache.size()); + keys = new int[EventType.values().length][readLength][numberOfCovariates]; + cache.put(readLength, keys); + } else { + keys = cachedKeys; + } } public void setCovariateIndex(final int index) { @@ -24,22 +60,26 @@ public class ReadCovariates { } /** - * Necessary due to bug in BaseRecalibration recalibrateRead function. 
It is clearly seeing space it's not supposed to - * @return + * Update the keys for mismatch, insertion, and deletion for the current covariate at read offset + * + * @param mismatch the mismatch key value + * @param insertion the insertion key value + * @param deletion the deletion key value + * @param readOffset the read offset, must be >= 0 and <= the read length used to create this ReadCovariates */ - public ReadCovariates clear() { - for ( int i = 0; i < keys.length; i++ ) - for ( int j = 0; j < keys[i].length; j++) - Arrays.fill(keys[i][j], 0); - return this; - } - public void addCovariate(final int mismatch, final int insertion, final int deletion, final int readOffset) { keys[EventType.BASE_SUBSTITUTION.index][readOffset][currentCovariateIndex] = mismatch; keys[EventType.BASE_INSERTION.index][readOffset][currentCovariateIndex] = insertion; keys[EventType.BASE_DELETION.index][readOffset][currentCovariateIndex] = deletion; } + /** + * Get the keys for all covariates at read position for error model + * + * @param readPosition + * @param errorModel + * @return + */ public int[] getKeySet(final int readPosition, final EventType errorModel) { return keys[errorModel.index][readPosition]; } @@ -48,21 +88,12 @@ public class ReadCovariates { return keys[errorModel.index]; } - public int[] getMismatchesKeySet(final int readPosition) { - return keys[EventType.BASE_SUBSTITUTION.index][readPosition]; - } + // ---------------------------------------------------------------------- + // + // routines for testing + // + // ---------------------------------------------------------------------- - public int[] getInsertionsKeySet(final int readPosition) { - return keys[EventType.BASE_INSERTION.index][readPosition]; - } - - public int[] getDeletionsKeySet(final int readPosition) { - return keys[EventType.BASE_DELETION.index][readPosition]; - } - - /** - * Testing routines - */ protected int[][] getMismatchesKeySet() { return keys[EventType.BASE_SUBSTITUTION.index]; } @@ -74,4 
+105,16 @@ public class ReadCovariates { protected int[][] getDeletionsKeySet() { return keys[EventType.BASE_DELETION.index]; } + + protected int[] getMismatchesKeySet(final int readPosition) { + return getKeySet(readPosition, EventType.BASE_SUBSTITUTION); + } + + protected int[] getInsertionsKeySet(final int readPosition) { + return getKeySet(readPosition, EventType.BASE_INSERTION); + } + + protected int[] getDeletionsKeySet(final int readPosition) { + return getKeySet(readPosition, EventType.BASE_DELETION); + } } diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java index 207988749..4cacc26c4 100755 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java @@ -191,9 +191,9 @@ public class RecalDatum { return (byte)(Math.round(getEmpiricalQuality())); } - //--------------------------------------------------------------------------------------------------------------- + //--------------------------------------------------------------------------------------------------------------- // - // increment methods + // toString methods // //--------------------------------------------------------------------------------------------------------------- @@ -206,22 +206,6 @@ public class RecalDatum { return String.format("%s,%.2f,%.2f", toString(), getEstimatedQReported(), getEmpiricalQuality() - getEstimatedQReported()); } -// /** -// * We don't compare the estimated quality reported because it may be different when read from -// * report tables. -// * -// * @param o the other recal datum -// * @return true if the two recal datums have the same number of observations, errors and empirical quality. 
-// */ -// @Override -// public boolean equals(Object o) { -// if (!(o instanceof RecalDatum)) -// return false; -// RecalDatum other = (RecalDatum) o; -// return super.equals(o) && -// MathUtils.compareDoubles(this.empiricalQuality, other.empiricalQuality, 0.001) == 0; -// } - //--------------------------------------------------------------------------------------------------------------- // // increment methods @@ -264,15 +248,14 @@ public class RecalDatum { @Requires({"incObservations >= 0", "incMismatches >= 0"}) @Ensures({"numObservations == old(numObservations) + incObservations", "numMismatches == old(numMismatches) + incMismatches"}) public synchronized void increment(final double incObservations, final double incMismatches) { - incrementNumObservations(incObservations); - incrementNumMismatches(incMismatches); + numObservations += incObservations; + numMismatches += incMismatches; + empiricalQuality = UNINITIALIZED; } @Ensures({"numObservations == old(numObservations) + 1", "numMismatches >= old(numMismatches)"}) public synchronized void increment(final boolean isError) { - incrementNumObservations(1); - if ( isError ) - incrementNumMismatches(1); + increment(1, isError ? 
1 : 0.0); } // ------------------------------------------------------------------------------------- @@ -286,7 +269,7 @@ public class RecalDatum { */ @Requires("empiricalQuality == UNINITIALIZED") @Ensures("empiricalQuality != UNINITIALIZED") - private synchronized final void calcEmpiricalQuality() { + private synchronized void calcEmpiricalQuality() { final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate()); empiricalQuality = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE); } diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ContextCovariate.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ContextCovariate.java index 5e470b35f..b586a1607 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ContextCovariate.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ContextCovariate.java @@ -26,13 +26,13 @@ package org.broadinstitute.sting.utils.recalibration.covariates; import org.apache.log4j.Logger; -import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.gatk.walkers.bqsr.RecalibrationArgumentCollection; import org.broadinstitute.sting.utils.BaseUtils; import org.broadinstitute.sting.utils.clipping.ClippingRepresentation; import org.broadinstitute.sting.utils.clipping.ReadClipper; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.UserException; +import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import java.util.ArrayList; @@ -99,9 +99,21 @@ public class ContextCovariate implements StandardCovariate { final ArrayList indelKeys = contextWith(bases, indelsContextSize, indelsKeyMask); final int readLength = bases.length; + + // this is necessary to ensure that we don't keep historical data in the ReadCovariates values + 
// since the context covariate may not span the entire set of values in read covariates + // due to the clipping of the low quality bases + if ( readLength != originalBases.length ) { + // don't bother zeroing out if we are going to overwrite the whole array + for ( int i = 0; i < originalBases.length; i++ ) + // this base has been clipped off, so zero out the covariate values here + values.addCovariate(0, 0, 0, i); + } + for (int i = 0; i < readLength; i++) { + final int readOffset = (negativeStrand ? readLength - i - 1 : i); final int indelKey = indelKeys.get(i); - values.addCovariate(mismatchKeys.get(i), indelKey, indelKey, (negativeStrand ? readLength - i - 1 : i)); + values.addCovariate(mismatchKeys.get(i), indelKey, indelKey, readOffset); } // put the original bases back in diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java index 29c15adf7..47f11312a 100755 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java @@ -1,11 +1,13 @@ package org.broadinstitute.sting.utils.recalibration.covariates; -import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.gatk.walkers.bqsr.RecalibrationArgumentCollection; +import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.utils.sam.GATKSAMReadGroupRecord; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import java.util.HashMap; +import java.util.Map; +import java.util.Set; /* * Copyright (c) 2009 The Broad Institute @@ -77,6 +79,14 @@ public class ReadGroupCovariate implements RequiredCovariate { return keyForReadGroup((String) value); } + /** + * Get the mapping from read group names to integer key values for all read groups in this 
covariate + * @return a set of mappings from read group names -> integer key values + */ + public Set<Map.Entry<String, Integer>> getKeyMap() { + return readGroupLookupTable.entrySet(); + } + private int keyForReadGroup(final String readGroupId) { // Rather than synchronize this entire method (which would be VERY expensive for walkers like the BQSR), // synchronize only the table updates. diff --git a/public/java/test/org/broadinstitute/sting/gatk/walkers/PileupWalkerIntegrationTest.java b/public/java/test/org/broadinstitute/sting/gatk/walkers/PileupWalkerIntegrationTest.java index e16ef3125..b457698e9 100644 --- a/public/java/test/org/broadinstitute/sting/gatk/walkers/PileupWalkerIntegrationTest.java +++ b/public/java/test/org/broadinstitute/sting/gatk/walkers/PileupWalkerIntegrationTest.java @@ -6,6 +6,9 @@ import org.testng.annotations.Test; import java.util.Arrays; public class PileupWalkerIntegrationTest extends WalkerTest { + String gatkSpeedupArgs="-T Pileup -I " + validationDataLocation + "NA12878.HiSeq.WGS.bwa.cleaned.recal.hg19.20.bam " + + "-R " + hg19Reference + " -o %s "; + @Test public void testGnarleyFHSPileup() { String gatk_args = "-T Pileup -I " + validationDataLocation + "FHS_Pileup_Test.bam " @@ -39,4 +42,31 @@ public class PileupWalkerIntegrationTest extends WalkerTest { WalkerTestSpec spec = new WalkerTestSpec(gatk_args, 1, Arrays.asList(SingleReadAligningOffChromosome1MD5)); executeTest("Testing single read spanning off chromosome 1 unindexed", spec); } + + /************************/ + + //testing speedup to GATKBAMIndex + + + @Test + public void testPileupOnLargeBamChr20(){ + WalkerTestSpec spec = new WalkerTestSpec(gatkSpeedupArgs + "-L 20:1-76,050", 1, Arrays.asList("8702701350de11a6d28204acefdc4775")); + executeTest("Testing single on big BAM at start of chromosome 20", spec); + } + @Test + public void testPileupOnLargeBamMid20(){ + WalkerTestSpec spec = new WalkerTestSpec(gatkSpeedupArgs + "-L 20:10,000,000-10,001,100", 1, 
Arrays.asList("818cf5a8229efe6f89fc1cd8145ccbe3")); + executeTest("Testing single on big BAM somewhere in chromosome 20", spec); + } + @Test + public void testPileupOnLargeBamEnd20(){ + WalkerTestSpec spec = new WalkerTestSpec(gatkSpeedupArgs + "-L 20:62,954,114-63,025,520", 1, Arrays.asList("22471ea4a12e5139aef62bf8ff2a5b63")); + executeTest("Testing single at end of chromosome 20", spec); + } + @Test + public void testPileupOnLargeBam20Many(){ + WalkerTestSpec spec = new WalkerTestSpec(gatkSpeedupArgs + "-L 20:1-76,050 -L 20:20,000,000-20,000,100 -L 20:40,000,000-40,000,100 -L 20:30,000,000-30,000,100 -L 20:50,000,000-50,000,100 -L 20:62,954,114-63,025,520 ", + 1, Arrays.asList("08d899ed7c5a76ef3947bf67338acda1")); + executeTest("Testing single on big BAM many places", spec); + } } diff --git a/public/scala/qscript/org/broadinstitute/sting/queue/qscripts/GATKResourcesBundle.scala b/public/scala/qscript/org/broadinstitute/sting/queue/qscripts/GATKResourcesBundle.scala index dc6cae197..3bd7514f2 100755 --- a/public/scala/qscript/org/broadinstitute/sting/queue/qscripts/GATKResourcesBundle.scala +++ b/public/scala/qscript/org/broadinstitute/sting/queue/qscripts/GATKResourcesBundle.scala @@ -44,7 +44,7 @@ class GATKResourcesBundle extends QScript { var exampleFASTA: Reference = _ var refs: List[Reference] = _ - class Resource(val file: File, val name: String, val ref: Reference, val useName: Boolean = true, val makeSites: Boolean = true ) { + class Resource(val file: File, val name: String, val ref: Reference, val useName: Boolean = true, val makeSites: Boolean = true, val makeCallsIfBam: Boolean = true ) { def destname(target: Reference): String = { if ( useName ) return name + "." + target.name + "." 
+ getExtension(file) @@ -68,6 +68,7 @@ class GATKResourcesBundle extends QScript { def isVCF(file: File) = file.getName.endsWith(".vcf") def isBAM(file: File) = file.getName.endsWith(".bam") + def isOUT(file: File) = file.getName.endsWith(".out") def isFASTA(file: File) = file.getName.endsWith(".fasta") var RESOURCES: List[Resource] = Nil @@ -94,7 +95,7 @@ class GATKResourcesBundle extends QScript { addResource(new Resource(DATAROOT + "dbsnp_132_b37.vcf", "dbsnp_132", b37, true, false)) addResource(new Resource(exampleFASTA.file, "exampleFASTA", exampleFASTA, false)) - addResource(new Resource("public/testdata/exampleBAM.bam", "exampleBAM", exampleFASTA, false)) + addResource(new Resource("public/testdata/exampleBAM.bam", "exampleBAM", exampleFASTA, false, false, false)) } def initializeStandardDataFiles() = { @@ -172,7 +173,7 @@ class GATKResourcesBundle extends QScript { // exampleFASTA file // addResource(new Resource(exampleFASTA.file, "exampleFASTA", exampleFASTA, false)) - addResource(new Resource("public/testdata/exampleBAM.bam", "exampleBAM", exampleFASTA, false)) + addResource(new Resource("public/testdata/exampleBAM.bam", "exampleBAM", exampleFASTA, false, false, false)) } def createBundleDirectories(dir: File) = { @@ -184,6 +185,15 @@ class GATKResourcesBundle extends QScript { } } + def createCurrentLink(bundleDir: File) = { + + val currentLink = new File(BUNDLE_ROOT + "/current") + + if ( currentLink.exists ) currentLink.delete() + + add(new linkFile(bundleDir, currentLink)) + } + def script = { if ( TEST ) initializeTestDataFiles(); @@ -201,8 +211,10 @@ class GATKResourcesBundle extends QScript { } else if ( isBAM(resource.file) ) { val f = copyBundleFile(resource, resource.ref) add(new IndexBAM(f)) - @Output val outvcf: File = swapExt(f.getParent, f, ".bam", ".vcf") - add(new UG(resource.file, resource.ref.file, outvcf)) + if ( resource.makeCallsIfBam ) { + @Output val outvcf: File = swapExt(f.getParent, f, ".bam", ".vcf") + add(new UG(resource.file, 
resource.ref.file, outvcf)) + } } else if ( isVCF(resource.file) ) { for ( destRef <- refs ) { val out = destFile(BUNDLE_DIR, destRef, resource.destname(destRef)) @@ -240,6 +252,9 @@ class GATKResourcesBundle extends QScript { //throw new ReviewedStingException("Unknown file type: " + resource) } } + + createCurrentLink(BUNDLE_DIR) + } else { createBundleDirectories(DOWNLOAD_DIR) createDownloadsFromBundle(BUNDLE_DIR, DOWNLOAD_DIR) @@ -249,7 +264,6 @@ class GATKResourcesBundle extends QScript { def createDownloadsFromBundle(in: File, out: File) { Console.printf("Visiting %s%n", in) - // todo -- ignore some of the other files too (e.g. *.out); will test next time we make a bundle if (! in.getName.startsWith(".")) { if ( in.isDirectory ) { out.mkdirs @@ -261,7 +275,7 @@ class GATKResourcesBundle extends QScript { if ( isBAM(in) ) { add(new cpFile(in, out)) add(new md5sum(out)) - } else { + } else if ( !isOUT(in) ) { add(new GzipFile(in, out + ".gz")) add(new md5sum(out + ".gz")) } @@ -299,6 +313,10 @@ class GATKResourcesBundle extends QScript { def commandLine = "cp %s %s".format(in.getAbsolutePath, out.getAbsolutePath) } + class linkFile(@Input val in: File, @Output val out: File) extends CommandLineFunction { + def commandLine = "ln -s %s %s".format(in.getAbsolutePath, out.getAbsolutePath) + } + class md5sum(@Input val in: File) extends CommandLineFunction { @Output val o: File = new File(in.getAbsolutePath + ".md5") def commandLine = "md5sum %s > %s".format(in.getAbsolutePath, o)