From bf2d5efe4d51d42b1ed53b13adf5595fdb76f444 Mon Sep 17 00:00:00 2001 From: Ryan Poplin Date: Tue, 17 Jul 2012 14:51:26 -0400 Subject: [PATCH] Moving HaplotypeCaller integration and unit tests over to protected as well. --- .../GenotypingEngineUnitTest.java | 271 ++++++++++++++++++ .../HaplotypeCallerIntegrationTest.java | 36 +++ .../LikelihoodCalculationEngineUnitTest.java | 173 +++++++++++ .../SimpleDeBruijnAssemblerUnitTest.java | 257 +++++++++++++++++ 4 files changed, 737 insertions(+) create mode 100644 protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/GenotypingEngineUnitTest.java create mode 100644 protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java create mode 100644 protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java create mode 100644 protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssemblerUnitTest.java diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/GenotypingEngineUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/GenotypingEngineUnitTest.java new file mode 100644 index 000000000..4826bfb16 --- /dev/null +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/GenotypingEngineUnitTest.java @@ -0,0 +1,271 @@ +package org.broadinstitute.sting.gatk.walkers.haplotypecaller; + +/** + * Created by IntelliJ IDEA. + * User: rpoplin + * Date: 3/15/12 + */ + +import net.sf.picard.reference.ReferenceSequenceFile; +import org.broadinstitute.sting.BaseTest; +import org.broadinstitute.sting.utils.*; +import org.broadinstitute.sting.utils.fasta.CachingIndexedFastaSequenceFile; +import org.broadinstitute.sting.utils.variantcontext.Allele; +import org.broadinstitute.sting.utils.variantcontext.VariantContext; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.File; +import java.io.FileNotFoundException; +import java.util.*; + +/** + * Unit tests for GenotypingEngine + */ +public class GenotypingEngineUnitTest extends BaseTest { + + private static ReferenceSequenceFile seq; + private GenomeLocParser genomeLocParser; + + @BeforeClass + public void init() throws FileNotFoundException { + // sequence + seq = new CachingIndexedFastaSequenceFile(new File(b37KGReference)); + genomeLocParser = new GenomeLocParser(seq); + } + + @Test + public void testFindHomVarEventAllelesInSample() { + final List eventAlleles = new ArrayList(); + eventAlleles.add( Allele.create("A", true) ); + eventAlleles.add( Allele.create("C", false) ); + final List haplotypeAlleles = new ArrayList(); + haplotypeAlleles.add( Allele.create("AATA", true) ); + haplotypeAlleles.add( Allele.create("AACA", false) ); + haplotypeAlleles.add( Allele.create("CATA", false) ); + haplotypeAlleles.add( Allele.create("CACA", false) ); + final ArrayList haplotypes = new ArrayList(); + haplotypes.add(new Haplotype("AATA".getBytes())); + haplotypes.add(new Haplotype("AACA".getBytes())); + haplotypes.add(new Haplotype("CATA".getBytes())); + haplotypes.add(new Haplotype("CACA".getBytes())); + final List haplotypeAllelesForSample = new ArrayList(); + haplotypeAllelesForSample.add( Allele.create("CATA", false) ); + haplotypeAllelesForSample.add( Allele.create("CACA", false) ); + final ArrayList> alleleMapper = new ArrayList>(); + ArrayList Aallele = new ArrayList(); + Aallele.add(haplotypes.get(0)); + Aallele.add(haplotypes.get(1)); + ArrayList Callele = new ArrayList(); + Callele.add(haplotypes.get(2)); + Callele.add(haplotypes.get(3)); + alleleMapper.add(Aallele); + alleleMapper.add(Callele); + final List eventAllelesForSample = new ArrayList(); + eventAllelesForSample.add( Allele.create("C", false) ); + eventAllelesForSample.add( Allele.create("C", false) ); + + if(!compareAlleleLists(eventAllelesForSample, GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes))) { + logger.warn("calc alleles = " + GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes)); + logger.warn("expected alleles = " + eventAllelesForSample); + } + Assert.assertTrue(compareAlleleLists(eventAllelesForSample, GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes))); + } + + @Test + public void testFindHetEventAllelesInSample() { + final List eventAlleles = new ArrayList(); + eventAlleles.add( Allele.create("A", true) ); + eventAlleles.add( Allele.create("C", false) ); + eventAlleles.add( Allele.create("T", false) ); + final List haplotypeAlleles = new ArrayList(); + haplotypeAlleles.add( Allele.create("AATA", true) ); + haplotypeAlleles.add( Allele.create("AACA", false) ); + haplotypeAlleles.add( Allele.create("CATA", false) ); + haplotypeAlleles.add( Allele.create("CACA", false) ); + haplotypeAlleles.add( Allele.create("TACA", false) ); + haplotypeAlleles.add( Allele.create("TTCA", false) ); + haplotypeAlleles.add( Allele.create("TTTA", false) ); + final ArrayList haplotypes = new ArrayList(); + haplotypes.add(new Haplotype("AATA".getBytes())); + haplotypes.add(new Haplotype("AACA".getBytes())); + haplotypes.add(new Haplotype("CATA".getBytes())); + haplotypes.add(new Haplotype("CACA".getBytes())); + haplotypes.add(new Haplotype("TACA".getBytes())); + haplotypes.add(new Haplotype("TTCA".getBytes())); + haplotypes.add(new Haplotype("TTTA".getBytes())); + final List haplotypeAllelesForSample = new ArrayList(); + haplotypeAllelesForSample.add( Allele.create("TTTA", false) ); + haplotypeAllelesForSample.add( Allele.create("AATA", true) ); + final ArrayList> alleleMapper = new ArrayList>(); + ArrayList Aallele = new ArrayList(); + Aallele.add(haplotypes.get(0)); + Aallele.add(haplotypes.get(1)); + ArrayList Callele = new ArrayList(); + Callele.add(haplotypes.get(2)); + Callele.add(haplotypes.get(3)); + ArrayList Tallele = new ArrayList(); + Tallele.add(haplotypes.get(4)); + Tallele.add(haplotypes.get(5)); + Tallele.add(haplotypes.get(6)); + alleleMapper.add(Aallele); + alleleMapper.add(Callele); + alleleMapper.add(Tallele); + final List eventAllelesForSample = new ArrayList(); + eventAllelesForSample.add( Allele.create("A", true) ); + eventAllelesForSample.add( Allele.create("T", false) ); + + if(!compareAlleleLists(eventAllelesForSample, GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes))) { + logger.warn("calc alleles = " + GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes)); + logger.warn("expected alleles = " + eventAllelesForSample); + } + Assert.assertTrue(compareAlleleLists(eventAllelesForSample, GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes))); + } + + private boolean compareAlleleLists(List l1, List l2) { + if( l1.size() != l2.size() ) { + return false; // sanity check + } + + for( int i=0; i < l1.size(); i++ ){ + if ( !l2.contains(l1.get(i)) ) + return false; + } + return true; + } + + + private class BasicGenotypingTestProvider extends TestDataProvider { + byte[] ref; + byte[] hap; + HashMap expected; + GenotypingEngine ge = new GenotypingEngine(false, 0, false); + + public BasicGenotypingTestProvider(String refString, String hapString, HashMap expected) { + super(BasicGenotypingTestProvider.class, String.format("Haplotype to VCF test: ref = %s, alignment = %s", refString,hapString)); + ref = refString.getBytes(); + hap = hapString.getBytes(); + this.expected = expected; + } + + public HashMap calcAlignment() { + final SWPairwiseAlignment alignment = new SWPairwiseAlignment(ref, hap); + return ge.generateVCsFromAlignment( alignment.getAlignmentStart2wrt1(), alignment.getCigar(), ref, hap, genomeLocParser.createGenomeLoc("4",1,1+ref.length), "name", 0); + } + } + + @DataProvider(name = "BasicGenotypingTestProvider") + public Object[][] makeBasicGenotypingTests() { + + for( int contextSize : new int[]{0,1,5,9,24,36} ) { + HashMap map = new HashMap(); + map.put(1 + contextSize, (byte)'M'); + final String context = Utils.dupString('G', contextSize); + new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGACTAGCCGATAG", map); + } + + for( int contextSize : new int[]{0,1,5,9,24,36} ) { + HashMap map = new HashMap(); + map.put(2 + contextSize, (byte)'M'); + map.put(21 + contextSize, (byte)'M'); + final String context = Utils.dupString('G', contextSize); + new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG", "ATCTCGCATCGCGAGCATCGCCTAGCCGATAG", map); + } + + for( int contextSize : new int[]{0,1,5,9,24,36} ) { + HashMap map = new HashMap(); + map.put(1 + contextSize, (byte)'M'); + map.put(20 + contextSize, (byte)'I'); + final String context = Utils.dupString('G', contextSize); + new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGACACTAGCCGATAG", map); + } + + for( int contextSize : new int[]{0,1,5,9,24,36} ) { + HashMap map = new HashMap(); + map.put(1 + contextSize, (byte)'M'); + map.put(20 + contextSize, (byte)'D'); + final String context = Utils.dupString('G', contextSize); + new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGCTAGCCGATAG", map); + } + + for( int contextSize : new int[]{1,5,9,24,36} ) { + HashMap map = new HashMap(); + map.put(1, (byte)'M'); + map.put(20, (byte)'D'); + final String context = Utils.dupString('G', contextSize); + new BasicGenotypingTestProvider("AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGCTAGCCGATAG", map); + } + + for( int contextSize : new int[]{0,1,5,9,24,36} ) { + HashMap map = new HashMap(); + map.put(2 + contextSize, (byte)'M'); + map.put(20 + contextSize, (byte)'I'); + map.put(30 + contextSize, (byte)'D'); + final String context = Utils.dupString('G', contextSize); + new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "ACCTCGCATCGCGAGCATCGTTACTAGCCGATG", map); + } + + for( int contextSize : new int[]{0,1,5,9,24,36} ) { + HashMap map = new HashMap(); + map.put(1 + contextSize, (byte)'M'); + map.put(20 + contextSize, (byte)'D'); + map.put(28 + contextSize, (byte)'M'); + final String context = Utils.dupString('G', contextSize); + new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGCTAGCCCATAG", map); + } + + return BasicGenotypingTestProvider.getTests(BasicGenotypingTestProvider.class); + } + + @Test(dataProvider = "BasicGenotypingTestProvider", enabled = true) + public void testHaplotypeToVCF(BasicGenotypingTestProvider cfg) { + HashMap calculatedMap = cfg.calcAlignment(); + HashMap expectedMap = cfg.expected; + logger.warn(String.format("Test: %s", cfg.toString())); + if(!compareVCMaps(calculatedMap, expectedMap)) { + logger.warn("calc map = " + calculatedMap); + logger.warn("expected map = " + expectedMap); + } + Assert.assertTrue(compareVCMaps(calculatedMap, expectedMap)); + } + + /** + * Tests that we get the right values from the binomial distribution + */ + @Test + public void testCalculateR2LD() { + logger.warn("Executing testCalculateR2LD"); + + Assert.assertEquals(GenotypingEngine.calculateR2LD(1,1,1,1), 0.0, 0.00001); + Assert.assertEquals(GenotypingEngine.calculateR2LD(100,100,100,100), 0.0, 0.00001); + Assert.assertEquals(GenotypingEngine.calculateR2LD(1,0,0,1), 1.0, 0.00001); + Assert.assertEquals(GenotypingEngine.calculateR2LD(100,0,0,100), 1.0, 0.00001); + Assert.assertEquals(GenotypingEngine.calculateR2LD(1,2,3,4), (0.1 - 0.12) * (0.1 - 0.12) / (0.3 * 0.7 * 0.4 * 0.6), 0.00001); + } + + /** + * Private function to compare HashMap of VCs, it only checks the types and start locations of the VariantContext + */ + private boolean compareVCMaps(HashMap calc, HashMap expected) { + if( !calc.keySet().equals(expected.keySet()) ) { return false; } // sanity check + for( Integer loc : expected.keySet() ) { + Byte type = expected.get(loc); + switch( type ) { + case 'I': + if( !calc.get(loc).isSimpleInsertion() ) { return false; } + break; + case 'D': + if( !calc.get(loc).isSimpleDeletion() ) { return false; } + break; + case 'M': + if( !(calc.get(loc).isMNP() || calc.get(loc).isSNP()) ) { return false; } + break; + default: + return false; + } + } + return true; + } +} \ No newline at end of file diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java new file mode 100644 index 000000000..70b558054 --- /dev/null +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java @@ -0,0 +1,36 @@ +package org.broadinstitute.sting.gatk.walkers.haplotypecaller; + +import org.broadinstitute.sting.WalkerTest; +import org.testng.annotations.Test; + +import java.util.Arrays; + +public class HaplotypeCallerIntegrationTest extends WalkerTest { + final static String REF = b37KGReference; + final String NA12878_BAM = validationDataLocation + "NA12878.HiSeq.b37.chr20.10_11mb.bam"; + final String CEUTRIO_BAM = validationDataLocation + "CEUTrio.HiSeq.b37.chr20.10_11mb.bam"; + final String INTERVALS_FILE = validationDataLocation + "NA12878.HiSeq.b37.chr20.10_11mb.test.intervals"; + //final String RECAL_FILE = validationDataLocation + "NA12878.kmer.8.subset.recal_data.bqsr"; + + private void HCTest(String bam, String args, String md5) { + final String base = String.format("-T HaplotypeCaller -R %s -I %s -L %s", REF, bam, INTERVALS_FILE) + " --no_cmdline_in_header -o %s -minPruning 3"; + final WalkerTestSpec spec = new WalkerTestSpec(base + " " + args, Arrays.asList(md5)); + executeTest("testHaplotypeCaller: args=" + args, spec); + } + + @Test + public void testHaplotypeCallerMultiSample() { + HCTest(CEUTRIO_BAM, "", "7b4e76934e0c911220b4e7da8776ab2b"); + } + + @Test + public void testHaplotypeCallerSingleSample() { + HCTest(NA12878_BAM, "", "fcf0cea98a571d5e2d1dfa8b5edc599d"); + } + + @Test + public void testHaplotypeCallerMultiSampleGGA() { + HCTest(CEUTRIO_BAM, "-gt_mode GENOTYPE_GIVEN_ALLELES -alleles " + validationDataLocation + "combined.phase1.chr20.raw.indels.sites.vcf", "ff370c42c8b09a29f1aeff5ac57c7ea6"); + } +} + diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java new file mode 100644 index 000000000..185641140 --- /dev/null +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java @@ -0,0 +1,173 @@ +package org.broadinstitute.sting.gatk.walkers.haplotypecaller; + +/** + * Created by IntelliJ IDEA. + * User: rpoplin + * Date: 3/14/12 + */ + +import org.broadinstitute.sting.BaseTest; +import org.broadinstitute.sting.utils.Haplotype; +import org.broadinstitute.sting.utils.MathUtils; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.*; + +/** + * Unit tests for LikelihoodCalculationEngine + */ +public class LikelihoodCalculationEngineUnitTest extends BaseTest { + + @Test + public void testNormalizeDiploidLikelihoodMatrixFromLog10() { + double[][] likelihoodMatrix = { + {-90.2, 0, 0}, + {-190.1, -2.1, 0}, + {-7.0, -17.5, -35.9} + }; + double[][] normalizedMatrix = { + {-88.1, 0, 0}, + {-188.0, 0.0, 0}, + {-4.9, -15.4, -33.8} + }; + + + Assert.assertTrue(compareDoubleArrays(LikelihoodCalculationEngine.normalizeDiploidLikelihoodMatrixFromLog10(likelihoodMatrix), normalizedMatrix)); + + double[][] likelihoodMatrix2 = { + {-90.2, 0, 0, 0}, + {-190.1, -2.1, 0, 0}, + {-7.0, -17.5, -35.9, 0}, + {-7.0, -17.5, -35.9, -1000.0}, + }; + double[][] normalizedMatrix2 = { + {-88.1, 0, 0, 0}, + {-188.0, 0.0, 0, 0}, + {-4.9, -15.4, -33.8, 0}, + {-4.9, -15.4, -33.8, -997.9}, + }; + Assert.assertTrue(compareDoubleArrays(LikelihoodCalculationEngine.normalizeDiploidLikelihoodMatrixFromLog10(likelihoodMatrix2), normalizedMatrix2)); + } + + private class BasicLikelihoodTestProvider extends TestDataProvider { + public Double readLikelihoodForHaplotype1; + public Double readLikelihoodForHaplotype2; + public Double readLikelihoodForHaplotype3; + + public BasicLikelihoodTestProvider(double a, double b) { + super(BasicLikelihoodTestProvider.class, String.format("Diploid haplotype likelihoods for reads %f / %f",a,b)); + readLikelihoodForHaplotype1 = a; + readLikelihoodForHaplotype2 = b; + readLikelihoodForHaplotype3 = null; + } + + public BasicLikelihoodTestProvider(double a, double b, double c) { + super(BasicLikelihoodTestProvider.class, String.format("Diploid haplotype likelihoods for reads %f / %f / %f",a,b,c)); + readLikelihoodForHaplotype1 = a; + readLikelihoodForHaplotype2 = b; + readLikelihoodForHaplotype3 = c; + } + + public double[][] expectedDiploidHaplotypeMatrix() { + if( readLikelihoodForHaplotype3 == null ) { + double maxValue = Math.max(readLikelihoodForHaplotype1,readLikelihoodForHaplotype2); + double[][] normalizedMatrix = { + {readLikelihoodForHaplotype1 - maxValue, Double.NEGATIVE_INFINITY}, + {Math.log10(0.5*Math.pow(10,readLikelihoodForHaplotype1) + 0.5*Math.pow(10,readLikelihoodForHaplotype2)) - maxValue, readLikelihoodForHaplotype2 - maxValue} + }; + return normalizedMatrix; + } else { + double maxValue = MathUtils.max(readLikelihoodForHaplotype1,readLikelihoodForHaplotype2,readLikelihoodForHaplotype3); + double[][] normalizedMatrix = { + {readLikelihoodForHaplotype1 - maxValue, Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY}, + {Math.log10(0.5*Math.pow(10,readLikelihoodForHaplotype1) + 0.5*Math.pow(10,readLikelihoodForHaplotype2)) - maxValue, readLikelihoodForHaplotype2 - maxValue, Double.NEGATIVE_INFINITY}, + {Math.log10(0.5*Math.pow(10,readLikelihoodForHaplotype1) + 0.5*Math.pow(10,readLikelihoodForHaplotype3)) - maxValue, + Math.log10(0.5*Math.pow(10,readLikelihoodForHaplotype2) + 0.5*Math.pow(10,readLikelihoodForHaplotype3)) - maxValue, readLikelihoodForHaplotype3 - maxValue} + }; + return normalizedMatrix; + } + } + + public double[][] calcDiploidHaplotypeMatrix() { + ArrayList haplotypes = new ArrayList(); + for( int iii = 1; iii <= 3; iii++) { + Double readLikelihood = ( iii == 1 ? readLikelihoodForHaplotype1 : ( iii == 2 ? readLikelihoodForHaplotype2 : readLikelihoodForHaplotype3) ); + if( readLikelihood != null ) { + Haplotype haplotype = new Haplotype( (iii == 1 ? "AAAA" : (iii == 2 ? "CCCC" : "TTTT")).getBytes() ); + haplotype.addReadLikelihoods("myTestSample", new double[]{readLikelihood}); + haplotypes.add(haplotype); + } + } + return LikelihoodCalculationEngine.computeDiploidHaplotypeLikelihoods(haplotypes, "myTestSample"); + } + } + + @DataProvider(name = "BasicLikelihoodTestProvider") + public Object[][] makeBasicLikelihoodTests() { + new BasicLikelihoodTestProvider(-1.1, -2.2); + new BasicLikelihoodTestProvider(-2.2, -1.1); + new BasicLikelihoodTestProvider(-1.1, -1.1); + new BasicLikelihoodTestProvider(-9.7, -15.0); + new BasicLikelihoodTestProvider(-1.1, -2000.2); + new BasicLikelihoodTestProvider(-1000.1, -2.2); + new BasicLikelihoodTestProvider(0, 0); + new BasicLikelihoodTestProvider(-1.1, 0); + new BasicLikelihoodTestProvider(0, -2.2); + new BasicLikelihoodTestProvider(-100.1, -200.2); + + new BasicLikelihoodTestProvider(-1.1, -2.2, 0); + new BasicLikelihoodTestProvider(-2.2, -1.1, 0); + new BasicLikelihoodTestProvider(-1.1, -1.1, 0); + new BasicLikelihoodTestProvider(-9.7, -15.0, 0); + new BasicLikelihoodTestProvider(-1.1, -2000.2, 0); + new BasicLikelihoodTestProvider(-1000.1, -2.2, 0); + new BasicLikelihoodTestProvider(0, 0, 0); + new BasicLikelihoodTestProvider(-1.1, 0, 0); + new BasicLikelihoodTestProvider(0, -2.2, 0); + new BasicLikelihoodTestProvider(-100.1, -200.2, 0); + + new BasicLikelihoodTestProvider(-1.1, -2.2, -12.121); + new BasicLikelihoodTestProvider(-2.2, -1.1, -12.121); + new BasicLikelihoodTestProvider(-1.1, -1.1, -12.121); + new BasicLikelihoodTestProvider(-9.7, -15.0, -12.121); + new BasicLikelihoodTestProvider(-1.1, -2000.2, -12.121); + new BasicLikelihoodTestProvider(-1000.1, -2.2, -12.121); + new BasicLikelihoodTestProvider(0, 0, -12.121); + new BasicLikelihoodTestProvider(-1.1, 0, -12.121); + new BasicLikelihoodTestProvider(0, -2.2, -12.121); + new BasicLikelihoodTestProvider(-100.1, -200.2, -12.121); + + return BasicLikelihoodTestProvider.getTests(BasicLikelihoodTestProvider.class); + } + + @Test(dataProvider = "BasicLikelihoodTestProvider", enabled = true) + public void testOneReadWithTwoOrThreeHaplotypes(BasicLikelihoodTestProvider cfg) { + double[][] calculatedMatrix = cfg.calcDiploidHaplotypeMatrix(); + double[][] expectedMatrix = cfg.expectedDiploidHaplotypeMatrix(); + logger.warn(String.format("Test: %s", cfg.toString())); + Assert.assertTrue(compareDoubleArrays(calculatedMatrix, expectedMatrix)); + } + + /** + * Private function to compare 2d arrays + */ + private boolean compareDoubleArrays(double[][] b1, double[][] b2) { + if( b1.length != b2.length ) { + return false; // sanity check + } + + for( int i=0; i < b1.length; i++ ){ + if( b1[i].length != b2[i].length) { + return false; // sanity check + } + for( int j=0; j < b1.length; j++ ){ + if ( MathUtils.compareDoubles(b1[i][j], b2[i][j]) != 0 && !Double.isInfinite(b1[i][j]) && !Double.isInfinite(b2[i][j])) + return false; + } + } + return true; + } +} diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssemblerUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssemblerUnitTest.java new file mode 100644 index 000000000..4f42d5bc8 --- /dev/null +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssemblerUnitTest.java @@ -0,0 +1,257 @@ +package org.broadinstitute.sting.gatk.walkers.haplotypecaller; + +/** + * Created by IntelliJ IDEA. + * User: rpoplin + * Date: 3/27/12 + */ + +import org.broadinstitute.sting.BaseTest; +import org.broadinstitute.sting.utils.Haplotype; +import org.broadinstitute.sting.utils.MathUtils; +import org.broadinstitute.sting.utils.variantcontext.Allele; +import org.jgrapht.graph.DefaultDirectedGraph; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.File; +import java.io.FileNotFoundException; +import java.util.*; + +public class SimpleDeBruijnAssemblerUnitTest extends BaseTest { + + + private class MergeNodesWithNoVariationTestProvider extends TestDataProvider { + public byte[] sequence; + public int KMER_LENGTH; + + public MergeNodesWithNoVariationTestProvider(String seq, int kmer) { + super(MergeNodesWithNoVariationTestProvider.class, String.format("Merge nodes with no variation test. kmer = %d, seq = %s", kmer, seq)); + sequence = seq.getBytes(); + KMER_LENGTH = kmer; + } + + public DefaultDirectedGraph expectedGraph() { + DeBruijnVertex v = new DeBruijnVertex(sequence, 0); + DefaultDirectedGraph graph = new DefaultDirectedGraph(DeBruijnEdge.class); + graph.addVertex(v); + return graph; + } + + public DefaultDirectedGraph calcGraph() { + + DefaultDirectedGraph graph = new DefaultDirectedGraph(DeBruijnEdge.class); + final int kmersInSequence = sequence.length - KMER_LENGTH + 1; + for (int i = 0; i < kmersInSequence - 1; i++) { + // get the kmers + final byte[] kmer1 = new byte[KMER_LENGTH]; + System.arraycopy(sequence, i, kmer1, 0, KMER_LENGTH); + final byte[] kmer2 = new byte[KMER_LENGTH]; + System.arraycopy(sequence, i+1, kmer2, 0, KMER_LENGTH); + + SimpleDeBruijnAssembler.addKmersToGraph(graph, kmer1, kmer2, false); + } + SimpleDeBruijnAssembler.mergeNodes(graph); + return graph; + } + } + + @DataProvider(name = "MergeNodesWithNoVariationTestProvider") + public Object[][] makeMergeNodesWithNoVariationTests() { + new MergeNodesWithNoVariationTestProvider("GGTTAACC", 3); + new MergeNodesWithNoVariationTestProvider("GGTTAACC", 4); + new MergeNodesWithNoVariationTestProvider("GGTTAACC", 5); + new MergeNodesWithNoVariationTestProvider("GGTTAACC", 6); + new MergeNodesWithNoVariationTestProvider("GGTTAACC", 7); + new MergeNodesWithNoVariationTestProvider("GGTTAACCATGCAGACGGGAGGCTGAGCGAGAGTTTT", 6); + new MergeNodesWithNoVariationTestProvider("AATACCATTGGAGTTTTTTTCCAGGTTAAGATGGTGCATTGAATCCACCCATCTACTTTTGCTCCTCCCAAAACTCACTAAAACTATTATAAAGGGATTTTGTTTAAAGACACAAACTCATGAGGACAGAGAGAACAGAGTAGACAATAGTGGGGGAAAAATAAGTTGGAAGATAGAAAACAGATGGGTGAGTGGTAATCGACTCAGCAGCCCCAAGAAAGCTGAAACCCAGGGAAAGTTAAGAGTAGCCCTATTTTCATGGCAAAATCCAAGGGGGGGTGGGGAAAGAAAGAAAAACAGAAAAAAAAATGGGAATTGGCAGTCCTAGATATCTCTGGTACTGGGCAAGCCAAAGAATCAGGATAACTGGGTGAAAGGTGATTGGGAAGCAGTTAAAATCTTAGTTCCCCTCTTCCACTCTCCGAGCAGCAGGTTTCTCTCTCTCATCAGGCAGAGGGCTGGAGAT", 66); + new MergeNodesWithNoVariationTestProvider("AATACCATTGGAGTTTTTTTCCAGGTTAAGATGGTGCATTGAATCCACCCATCTACTTTTGCTCCTCCCAAAACTCACTAAAACTATTATAAAGGGATTTTGTTTAAAGACACAAACTCATGAGGACAGAGAGAACAGAGTAGACAATAGTGGGGGAAAAATAAGTTGGAAGATAGAAAACAGATGGGTGAGTGGTAATCGACTCAGCAGCCCCAAGAAAGCTGAAACCCAGGGAAAGTTAAGAGTAGCCCTATTTTCATGGCAAAATCCAAGGGGGGGTGGGGAAAGAAAGAAAAACAGAAAAAAAAATGGGAATTGGCAGTCCTAGATATCTCTGGTACTGGGCAAGCCAAAGAATCAGGATAACTGGGTGAAAGGTGATTGGGAAGCAGTTAAAATCTTAGTTCCCCTCTTCCACTCTCCGAGCAGCAGGTTTCTCTCTCTCATCAGGCAGAGGGCTGGAGAT", 76); + + return MergeNodesWithNoVariationTestProvider.getTests(MergeNodesWithNoVariationTestProvider.class); + } + + @Test(dataProvider = "MergeNodesWithNoVariationTestProvider", enabled = true) + public void testMergeNodesWithNoVariation(MergeNodesWithNoVariationTestProvider cfg) { + logger.warn(String.format("Test: %s", cfg.toString())); + Assert.assertTrue(graphEquals(cfg.calcGraph(), cfg.expectedGraph())); + } + + @Test(enabled = true) + public void testPruneGraph() { + DefaultDirectedGraph graph = new DefaultDirectedGraph(DeBruijnEdge.class); + DefaultDirectedGraph expectedGraph = new DefaultDirectedGraph(DeBruijnEdge.class); + + DeBruijnVertex v = new DeBruijnVertex("ATGG".getBytes(), 0); + DeBruijnVertex v2 = new DeBruijnVertex("ATGGA".getBytes(), 0); + DeBruijnVertex v3 = new DeBruijnVertex("ATGGT".getBytes(), 0); + DeBruijnVertex v4 = new DeBruijnVertex("ATGGG".getBytes(), 0); + DeBruijnVertex v5 = new DeBruijnVertex("ATGGC".getBytes(), 0); + DeBruijnVertex v6 = new DeBruijnVertex("ATGGCCCCCC".getBytes(), 0); + + graph.addVertex(v); + graph.addVertex(v2); + graph.addVertex(v3); + graph.addVertex(v4); + graph.addVertex(v5); + graph.addVertex(v6); + graph.addEdge(v, v2, new DeBruijnEdge(false, 1)); + graph.addEdge(v2, v3, new DeBruijnEdge(false, 3)); + graph.addEdge(v3, v4, new DeBruijnEdge(false, 5)); + graph.addEdge(v4, v5, new DeBruijnEdge(false, 3)); + graph.addEdge(v5, v6, new DeBruijnEdge(false, 2)); + + expectedGraph.addVertex(v2); + expectedGraph.addVertex(v3); + expectedGraph.addVertex(v4); + expectedGraph.addVertex(v5); + expectedGraph.addEdge(v2, v3, new DeBruijnEdge(false, 3)); + expectedGraph.addEdge(v3, v4, new DeBruijnEdge(false, 5)); + expectedGraph.addEdge(v4, v5, new DeBruijnEdge(false, 3)); + + SimpleDeBruijnAssembler.pruneGraph(graph, 2); + + Assert.assertTrue(graphEquals(graph, expectedGraph)); + + graph = new DefaultDirectedGraph(DeBruijnEdge.class); + expectedGraph = new DefaultDirectedGraph(DeBruijnEdge.class); + + graph.addVertex(v); + graph.addVertex(v2); + graph.addVertex(v3); + graph.addVertex(v4); + graph.addVertex(v5); + graph.addVertex(v6); + graph.addEdge(v, v2, new DeBruijnEdge(true, 1)); + graph.addEdge(v2, v3, new DeBruijnEdge(false, 3)); + graph.addEdge(v3, v4, new DeBruijnEdge(false, 5)); + graph.addEdge(v4, v5, new DeBruijnEdge(false, 3)); + + expectedGraph.addVertex(v); + expectedGraph.addVertex(v2); + expectedGraph.addVertex(v3); + expectedGraph.addVertex(v4); + expectedGraph.addVertex(v5); + expectedGraph.addEdge(v, v2, new DeBruijnEdge(true, 1)); + expectedGraph.addEdge(v2, v3, new DeBruijnEdge(false, 3)); + expectedGraph.addEdge(v3, v4, new DeBruijnEdge(false, 5)); + expectedGraph.addEdge(v4, v5, new DeBruijnEdge(false, 3)); + + SimpleDeBruijnAssembler.pruneGraph(graph, 2); + + Assert.assertTrue(graphEquals(graph, expectedGraph)); + } + + @Test(enabled = true) + public void testEliminateNonRefPaths() { + DefaultDirectedGraph graph = new DefaultDirectedGraph(DeBruijnEdge.class); + DefaultDirectedGraph expectedGraph = new DefaultDirectedGraph(DeBruijnEdge.class); + + DeBruijnVertex v = new DeBruijnVertex("ATGG".getBytes(), 0); + DeBruijnVertex v2 = new DeBruijnVertex("ATGGA".getBytes(), 0); + DeBruijnVertex v3 = new DeBruijnVertex("ATGGT".getBytes(), 0); + DeBruijnVertex v4 = new DeBruijnVertex("ATGGG".getBytes(), 0); + DeBruijnVertex v5 = new DeBruijnVertex("ATGGC".getBytes(), 0); + DeBruijnVertex v6 = new DeBruijnVertex("ATGGCCCCCC".getBytes(), 0); + + graph.addVertex(v); + graph.addVertex(v2); + graph.addVertex(v3); + graph.addVertex(v4); + graph.addVertex(v5); + graph.addVertex(v6); + graph.addEdge(v, v2, new DeBruijnEdge(false)); + graph.addEdge(v2, v3, new DeBruijnEdge(true)); + graph.addEdge(v3, v4, new DeBruijnEdge(true)); + graph.addEdge(v4, v5, new DeBruijnEdge(true)); + graph.addEdge(v5, v6, new DeBruijnEdge(false)); + + expectedGraph.addVertex(v2); + expectedGraph.addVertex(v3); + expectedGraph.addVertex(v4); + expectedGraph.addVertex(v5); + expectedGraph.addEdge(v2, v3, new DeBruijnEdge()); + expectedGraph.addEdge(v3, v4, new DeBruijnEdge()); + expectedGraph.addEdge(v4, v5, new DeBruijnEdge()); + + SimpleDeBruijnAssembler.eliminateNonRefPaths(graph); + + Assert.assertTrue(graphEquals(graph, expectedGraph)); + + + + + graph = new DefaultDirectedGraph(DeBruijnEdge.class); + expectedGraph = new DefaultDirectedGraph(DeBruijnEdge.class); + + graph.addVertex(v); + graph.addVertex(v2); + graph.addVertex(v3); + graph.addVertex(v4); + graph.addVertex(v5); + graph.addVertex(v6); + graph.addEdge(v, v2, new DeBruijnEdge(true)); + graph.addEdge(v2, v3, new DeBruijnEdge(true)); + graph.addEdge(v4, v5, new DeBruijnEdge(false)); + graph.addEdge(v5, v6, new DeBruijnEdge(false)); + + expectedGraph.addVertex(v); + expectedGraph.addVertex(v2); + expectedGraph.addVertex(v3); + expectedGraph.addEdge(v, v2, new DeBruijnEdge()); + expectedGraph.addEdge(v2, v3, new DeBruijnEdge()); + + SimpleDeBruijnAssembler.eliminateNonRefPaths(graph); + + Assert.assertTrue(graphEquals(graph, expectedGraph)); + + + + graph = new DefaultDirectedGraph(DeBruijnEdge.class); + expectedGraph = new DefaultDirectedGraph(DeBruijnEdge.class); + + graph.addVertex(v); + graph.addVertex(v2); + graph.addVertex(v3); + graph.addVertex(v4); + graph.addVertex(v5); + graph.addVertex(v6); + graph.addEdge(v, v2, new DeBruijnEdge(true)); + graph.addEdge(v2, v3, new DeBruijnEdge(true)); + graph.addEdge(v4, v5, new DeBruijnEdge(false)); + graph.addEdge(v5, v6, new DeBruijnEdge(false)); + graph.addEdge(v4, v2, new DeBruijnEdge(false)); + + expectedGraph.addVertex(v); + expectedGraph.addVertex(v2); + expectedGraph.addVertex(v3); + expectedGraph.addEdge(v, v2, new DeBruijnEdge()); + expectedGraph.addEdge(v2, v3, new DeBruijnEdge()); + + SimpleDeBruijnAssembler.eliminateNonRefPaths(graph); + + Assert.assertTrue(graphEquals(graph, expectedGraph)); + } + + private boolean graphEquals(DefaultDirectedGraph g1, DefaultDirectedGraph g2) { + if( !(g1.vertexSet().containsAll(g2.vertexSet()) && g2.vertexSet().containsAll(g1.vertexSet())) ) { + return false; + } + for( DeBruijnEdge e1 : g1.edgeSet() ) { + boolean found = false; + for( DeBruijnEdge e2 : g2.edgeSet() ) { + if( e1.equals(g1, e2, g2) ) { found = true; break; } + } + if( !found ) { return false; } + } + for( DeBruijnEdge e2 : g2.edgeSet() ) { + boolean found = false; + for( DeBruijnEdge e1 : g1.edgeSet() ) { + if( e2.equals(g2, e1, g1) ) { found = true; break; } + } + if( !found ) { return false; } + } + return true; + } +}