From 75ceddf9e5edff5d281aae31ed484a6194a0f02a Mon Sep 17 00:00:00 2001 From: Eric Banks Date: Thu, 31 Jan 2013 09:46:38 -0500 Subject: [PATCH] Adding new unit tests for RR. These tests took a frustratingly long time to get to pass, but now we have a framework for testing the adding of reads into the SlidingWindow plus consensus creation. Will flesh these out more after I take care of some other items on my plate. --- .../reducereads/CompressionStash.java | 8 +- .../reducereads/SlidingWindow.java | 7 +- .../reducereads/SlidingWindowUnitTest.java | 251 +++++++++++------- 3 files changed, 170 insertions(+), 96 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/CompressionStash.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/CompressionStash.java index e0e49cba3..bd7bdfe89 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/CompressionStash.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/CompressionStash.java @@ -74,7 +74,7 @@ public class CompressionStash extends TreeSet { * @return true if the loc, or it's merged version, wasn't present in the list before. */ @Override - public boolean add(FinishedGenomeLoc insertLoc) { + public boolean add(final FinishedGenomeLoc insertLoc) { TreeSet removedLocs = new TreeSet(); for (FinishedGenomeLoc existingLoc : this) { if (existingLoc.isPast(insertLoc)) { @@ -87,10 +87,10 @@ public class CompressionStash extends TreeSet { removedLocs.add(existingLoc); // list the original loc for merging } } - for (GenomeLoc loc : removedLocs) { - this.remove(loc); // remove all locs that will be merged - } + + this.removeAll(removedLocs); // remove all locs that will be merged removedLocs.add(insertLoc); // add the new loc to the list of locs that will be merged + return super.add(new FinishedGenomeLoc(GenomeLoc.merge(removedLocs), insertLoc.isFinished())); } diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/SlidingWindow.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/SlidingWindow.java index e2f8b6682..985fbba57 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/SlidingWindow.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/SlidingWindow.java @@ -141,7 +141,10 @@ public class SlidingWindow { protected SlidingWindow(final String contig, final int contigIndex, final int startLocation) { this.contig = contig; this.contigIndex = contigIndex; + + contextSize = 10; nContigs = 1; + this.windowHeader = new LinkedList(); windowHeader.addFirst(new HeaderElement(startLocation)); this.readsInWindow = new TreeSet(); @@ -293,7 +296,7 @@ public class SlidingWindow { } - private final class MarkedSites { + protected final class MarkedSites { private boolean[] siteIsVariant = new boolean[0]; private int startLocation = 0; @@ -302,6 +305,8 @@ public class SlidingWindow { public boolean[] getVariantSiteBitSet() { return siteIsVariant; } + protected int getStartLocation() { return startLocation; } + /** * Updates the variant site bitset given the new startlocation and size of the region to mark. * diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/compression/reducereads/SlidingWindowUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/compression/reducereads/SlidingWindowUnitTest.java index d9b55963d..91dfb94a9 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/compression/reducereads/SlidingWindowUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/compression/reducereads/SlidingWindowUnitTest.java @@ -46,23 +46,48 @@ package org.broadinstitute.sting.gatk.walkers.compression.reducereads; +import net.sf.picard.reference.IndexedFastaSequenceFile; +import net.sf.samtools.SAMFileHeader; import org.broadinstitute.sting.BaseTest; +import org.broadinstitute.sting.utils.GenomeLoc; +import org.broadinstitute.sting.utils.UnvalidatingGenomeLoc; +import org.broadinstitute.sting.utils.Utils; +import org.broadinstitute.sting.utils.collections.Pair; +import org.broadinstitute.sting.utils.fasta.CachingIndexedFastaSequenceFile; +import org.broadinstitute.sting.utils.sam.ArtificialSAMUtils; +import org.broadinstitute.sting.utils.sam.GATKSAMReadGroupRecord; +import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import org.testng.Assert; +import org.testng.annotations.BeforeClass; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.io.File; +import java.io.FileNotFoundException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Set; public class SlidingWindowUnitTest extends BaseTest { + private static final int variantRegionLength = 1000; + private static final int globalStartPosition = 1000000; + + private static boolean[] createBitset(final List locs) { + final boolean[] variantRegionBitset = new boolean[variantRegionLength]; + for ( FinishedGenomeLoc loc : locs ) { + final int stop = loc.getStop() - globalStartPosition; + for ( int i = loc.getStart() - globalStartPosition; i <= stop; i++ ) + variantRegionBitset[i] = true; + } + return variantRegionBitset; + } + ////////////////////////////////////////////////////////////////////////////////////// //// This section tests the findVariantRegions() method and related functionality //// ////////////////////////////////////////////////////////////////////////////////////// - private static final int variantRegionLength = 1000; - private static final int globalStartPosition = 1000000; private static final FinishedGenomeLoc loc90to95 = new FinishedGenomeLoc("1", 0, 1000090, 1000095, false); private static final FinishedGenomeLoc loc96to99 = new FinishedGenomeLoc("1", 0, 1000096, 1000099, false); private static final FinishedGenomeLoc loc100to110 = new FinishedGenomeLoc("1", 0, 1000100, 1000110, false); @@ -85,16 +110,6 @@ public class SlidingWindowUnitTest extends BaseTest { } } - private static boolean[] createBitset(final List locs) { - boolean[] variantRegionBitset = new boolean[variantRegionLength]; - for ( FinishedGenomeLoc loc : locs ) { - final int stop = loc.getStop() - globalStartPosition; - for ( int i = loc.getStart() - globalStartPosition; i <= stop; i++ ) - variantRegionBitset[i] = true; - } - return variantRegionBitset; - } - @DataProvider(name = "findVariantRegions") public Object[][] createFindVariantRegionsData() { List tests = new ArrayList(); @@ -127,115 +142,169 @@ public class SlidingWindowUnitTest extends BaseTest { } + ///////////////////////////////////////////////////////////////////////////// + //// This section tests the markSites() method and related functionality //// + ///////////////////////////////////////////////////////////////////////////// + @Test(enabled = true) + public void testMarkedSitesClass() { + final SlidingWindow slidingWindow = new SlidingWindow("1", 0, globalStartPosition); + final SlidingWindow.MarkedSites markedSites = slidingWindow.new MarkedSites(); + markedSites.updateRegion(100, 100); + Assert.assertEquals(markedSites.getStartLocation(), 100); + Assert.assertEquals(markedSites.getVariantSiteBitSet().length, 100); + markedSites.updateRegion(300, 100); + Assert.assertEquals(markedSites.getStartLocation(), 300); + Assert.assertEquals(markedSites.getVariantSiteBitSet().length, 100); + markedSites.getVariantSiteBitSet()[10] = true; + markedSites.updateRegion(290, 100); + Assert.assertEquals(markedSites.getStartLocation(), 290); + Assert.assertEquals(markedSites.getVariantSiteBitSet().length, 100); + Assert.assertFalse(markedSites.getVariantSiteBitSet()[10]); + markedSites.getVariantSiteBitSet()[20] = true; + markedSites.updateRegion(290, 100); + Assert.assertEquals(markedSites.getStartLocation(), 290); + Assert.assertEquals(markedSites.getVariantSiteBitSet().length, 100); + Assert.assertTrue(markedSites.getVariantSiteBitSet()[20]); + markedSites.updateRegion(300, 100); + Assert.assertEquals(markedSites.getStartLocation(), 300); + Assert.assertEquals(markedSites.getVariantSiteBitSet().length, 100); + markedSites.getVariantSiteBitSet()[95] = true; + markedSites.updateRegion(390, 20); + Assert.assertEquals(markedSites.getStartLocation(), 390); + Assert.assertEquals(markedSites.getVariantSiteBitSet().length, 20); + Assert.assertTrue(markedSites.getVariantSiteBitSet()[5]); + markedSites.updateRegion(340, 60); + Assert.assertEquals(markedSites.getStartLocation(), 340); + Assert.assertEquals(markedSites.getVariantSiteBitSet().length, 60); + markedSites.getVariantSiteBitSet()[20] = true; + markedSites.updateRegion(350, 60); + Assert.assertEquals(markedSites.getStartLocation(), 350); + Assert.assertEquals(markedSites.getVariantSiteBitSet().length, 60); + Assert.assertTrue(markedSites.getVariantSiteBitSet()[10]); + } + @Test(enabled = true) + public void testMarkVariantRegion() { + final SlidingWindow slidingWindow = new SlidingWindow("1", 0, globalStartPosition); + SlidingWindow.MarkedSites markedSites = slidingWindow.new MarkedSites(); + markedSites.updateRegion(100, 100); - /* + slidingWindow.markVariantRegion(markedSites, 40); + Assert.assertEquals(countTrueBits(markedSites.getVariantSiteBitSet()), 21); - private static class DownsamplingReadsIteratorTest extends TestDataProvider { - private DownsamplingReadsIterator downsamplingIter; - private int targetCoverage; - private ArtificialSingleSampleReadStream stream; - private ArtificialSingleSampleReadStreamAnalyzer streamAnalyzer; + slidingWindow.markVariantRegion(markedSites, 5); + Assert.assertEquals(countTrueBits(markedSites.getVariantSiteBitSet()), 37); - public DownsamplingReadsIteratorTest( ArtificialSingleSampleReadStream stream, int targetCoverage ) { - super(DownsamplingReadsIteratorTest.class); + slidingWindow.markVariantRegion(markedSites, 95); + Assert.assertEquals(countTrueBits(markedSites.getVariantSiteBitSet()), 52); + } - this.stream = stream; - this.targetCoverage = targetCoverage; - - setName(String.format("%s: targetCoverage=%d numContigs=%d stacksPerContig=%d readsPerStack=%d-%d distanceBetweenStacks=%d-%d readLength=%d-%d unmappedReads=%d", - getClass().getSimpleName(), - targetCoverage, - stream.getNumContigs(), - stream.getNumStacksPerContig(), - stream.getMinReadsPerStack(), - stream.getMaxReadsPerStack(), - stream.getMinDistanceBetweenStacks(), - stream.getMaxDistanceBetweenStacks(), - stream.getMinReadLength(), - stream.getMaxReadLength(), - stream.getNumUnmappedReads())); + private static int countTrueBits(final boolean[] bitset) { + int count = 0; + for ( final boolean bit : bitset ) { + if ( bit ) + count++; } + return count; + } - public void run() { - streamAnalyzer = new PositionallyDownsampledArtificialSingleSampleReadStreamAnalyzer(stream, targetCoverage); - downsamplingIter = new DownsamplingReadsIterator(stream.getStingSAMIterator(), new SimplePositionalDownsampler(targetCoverage)); + ///////////////////////////////////////////////////////////////// + //// This section tests the consensus creation functionality //// + ///////////////////////////////////////////////////////////////// - streamAnalyzer.analyze(downsamplingIter); + private static final int readLength = 100; + private final List basicReads = new ArrayList(20); + private IndexedFastaSequenceFile seq; + private SAMFileHeader header; - // Check whether the observed properties of the downsampled stream are what they should be - streamAnalyzer.validate(); + @BeforeClass + public void setup() throws FileNotFoundException { + seq = new CachingIndexedFastaSequenceFile(new File(b37KGReference)); + header = ArtificialSAMUtils.createArtificialSamHeader(seq.getSequenceDictionary()); - // Allow memory used by this test to be reclaimed - stream = null; - streamAnalyzer = null; - downsamplingIter = null; + final int testRegionSize = 1000; + final int readFrequency = 20; + + basicReads.clear(); + for ( int i = 0; i < testRegionSize; i += readFrequency ) { + final GATKSAMRecord read = ArtificialSAMUtils.createArtificialRead(header, "basicRead" + i, 0, globalStartPosition + i, readLength); + read.setReadBases(Utils.dupBytes((byte) 'A', readLength)); + read.setBaseQualities(Utils.dupBytes((byte)30, readLength)); + read.setMappingQuality((byte)30); + basicReads.add(read); } } - @DataProvider(name = "DownsamplingReadsIteratorTestDataProvider") - public Object[][] createDownsamplingReadsIteratorTests() { - SAMFileHeader header = ArtificialSAMUtils.createArtificialSamHeader(5, 1, 10000); - String readGroupID = "testReadGroup"; - SAMReadGroupRecord readGroup = new SAMReadGroupRecord(readGroupID); - readGroup.setSample("testSample"); - header.addReadGroup(readGroup); + private class ConsensusCreationTest { + public final int expectedNumberOfReads; + public final List myReads = new ArrayList(20); - // Values that don't vary across tests - int targetCoverage = 10; - int minReadLength = 50; - int maxReadLength = 100; - int minDistanceBetweenStacks = 1; - int maxDistanceBetweenStacks = maxReadLength + 1; + private ConsensusCreationTest(final List locs, final boolean readsShouldBeLowQuality, final int expectedNumberOfReads) { + this.expectedNumberOfReads = expectedNumberOfReads; - GenomeAnalysisEngine.resetRandomGenerator(); + // first, add the basic reads to the collection + myReads.addAll(basicReads); - // brute force testing! - for ( int numContigs : Arrays.asList(1, 2, 5) ) { - for ( int stacksPerContig : Arrays.asList(1, 2, 10) ) { - for ( int minReadsPerStack : Arrays.asList(1, targetCoverage / 2, targetCoverage, targetCoverage - 1, targetCoverage + 1, targetCoverage * 2) ) { - for ( int maxReadsPerStack : Arrays.asList(1, targetCoverage / 2, targetCoverage, targetCoverage - 1, targetCoverage + 1, targetCoverage * 2) ) { - for ( int numUnmappedReads : Arrays.asList(0, 1, targetCoverage, targetCoverage * 2) ) { - // Only interested in sane read stream configurations here - if ( minReadsPerStack <= maxReadsPerStack ) { - new DownsamplingReadsIteratorTest(new ArtificialSingleSampleReadStream(header, - readGroupID, - numContigs, - stacksPerContig, - minReadsPerStack, - maxReadsPerStack, - minDistanceBetweenStacks, - maxDistanceBetweenStacks, - minReadLength, - maxReadLength, - numUnmappedReads), - targetCoverage); - } - } - } - } - } + // then add the permuted reads + for ( final GenomeLoc loc : locs ) + myReads.add(createVariantRead(loc, readsShouldBeLowQuality)); } - return DownsamplingReadsIteratorTest.getTests(DownsamplingReadsIteratorTest.class); + private GATKSAMRecord createVariantRead(final GenomeLoc loc, final boolean baseShouldBeLowQuality) { + + final int startPos = loc.getStart() - 50; + + final GATKSAMRecord read = ArtificialSAMUtils.createArtificialRead(header, "myRead" + startPos, 0, startPos, readLength); + final byte[] bases = Utils.dupBytes((byte) 'A', readLength); + // create a mismatch + bases[50] = 'C'; + read.setReadBases(bases); + final byte qual = baseShouldBeLowQuality ? (byte)10 : (byte)30; + read.setBaseQualities(Utils.dupBytes(qual, readLength)); + read.setMappingQuality((byte)30); + return read; + } } - @Test(dataProvider = "DownsamplingReadsIteratorTestDataProvider") - public void runDownsamplingReadsIteratorTest( DownsamplingReadsIteratorTest test ) { - logger.warn("Running test: " + test); + private static final GenomeLoc loc290 = new UnvalidatingGenomeLoc("1", 0, 1000290, 1000290); + private static final GenomeLoc loc295 = new UnvalidatingGenomeLoc("1", 0, 1000295, 1000295); + private static final GenomeLoc loc309 = new UnvalidatingGenomeLoc("1", 0, 1000309, 1000309); + private static final GenomeLoc loc310 = new UnvalidatingGenomeLoc("1", 0, 1000310, 1000310); - GenomeAnalysisEngine.resetRandomGenerator(); - test.run(); + @DataProvider(name = "ConsensusCreation") + public Object[][] createConsensusCreationTestData() { + List tests = new ArrayList(); + + tests.add(new Object[]{new ConsensusCreationTest(Arrays.asList(), false, 1)}); + tests.add(new Object[]{new ConsensusCreationTest(Arrays.asList(loc290), false, 9)}); + tests.add(new Object[]{new ConsensusCreationTest(Arrays.asList(loc290, loc295), false, 10)}); + tests.add(new Object[]{new ConsensusCreationTest(Arrays.asList(loc290, loc309), false, 10)}); + tests.add(new Object[]{new ConsensusCreationTest(Arrays.asList(loc290, loc310), false, 11)}); + + return tests.toArray(new Object[][]{}); } - */ + @Test(dataProvider = "ConsensusCreation", enabled = true) + public void testConsensusCreationTest(ConsensusCreationTest test) { + final SlidingWindow slidingWindow = new SlidingWindow("1", 0, 10, header, new GATKSAMReadGroupRecord("test"), 0, 0.05, 0.05, 20, 20, 100, ReduceReads.DownsampleStrategy.Normal, false, 1, false); + for ( final GATKSAMRecord read : test.myReads ) + slidingWindow.addRead(read); + final Pair, CompressionStash> result = slidingWindow.close(); + + Assert.assertEquals(result.getFirst().size(), test.expectedNumberOfReads); + } + + + + + }