Moving HaplotypeCaller integration and unit tests over to protected as well.

This commit is contained in:
Ryan Poplin 2012-07-17 14:51:26 -04:00
parent c55934043e
commit bf2d5efe4d
4 changed files with 737 additions and 0 deletions

View File

@ -0,0 +1,271 @@
package org.broadinstitute.sting.gatk.walkers.haplotypecaller;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: 3/15/12
*/
import net.sf.picard.reference.ReferenceSequenceFile;
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.*;
import org.broadinstitute.sting.utils.fasta.CachingIndexedFastaSequenceFile;
import org.broadinstitute.sting.utils.variantcontext.Allele;
import org.broadinstitute.sting.utils.variantcontext.VariantContext;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
/**
* Unit tests for GenotypingEngine
*/
public class GenotypingEngineUnitTest extends BaseTest {
private static ReferenceSequenceFile seq;
private GenomeLocParser genomeLocParser;
@BeforeClass
public void init() throws FileNotFoundException {
// sequence
seq = new CachingIndexedFastaSequenceFile(new File(b37KGReference));
genomeLocParser = new GenomeLocParser(seq);
}
@Test
public void testFindHomVarEventAllelesInSample() {
final List<Allele> eventAlleles = new ArrayList<Allele>();
eventAlleles.add( Allele.create("A", true) );
eventAlleles.add( Allele.create("C", false) );
final List<Allele> haplotypeAlleles = new ArrayList<Allele>();
haplotypeAlleles.add( Allele.create("AATA", true) );
haplotypeAlleles.add( Allele.create("AACA", false) );
haplotypeAlleles.add( Allele.create("CATA", false) );
haplotypeAlleles.add( Allele.create("CACA", false) );
final ArrayList<Haplotype> haplotypes = new ArrayList<Haplotype>();
haplotypes.add(new Haplotype("AATA".getBytes()));
haplotypes.add(new Haplotype("AACA".getBytes()));
haplotypes.add(new Haplotype("CATA".getBytes()));
haplotypes.add(new Haplotype("CACA".getBytes()));
final List<Allele> haplotypeAllelesForSample = new ArrayList<Allele>();
haplotypeAllelesForSample.add( Allele.create("CATA", false) );
haplotypeAllelesForSample.add( Allele.create("CACA", false) );
final ArrayList<ArrayList<Haplotype>> alleleMapper = new ArrayList<ArrayList<Haplotype>>();
ArrayList<Haplotype> Aallele = new ArrayList<Haplotype>();
Aallele.add(haplotypes.get(0));
Aallele.add(haplotypes.get(1));
ArrayList<Haplotype> Callele = new ArrayList<Haplotype>();
Callele.add(haplotypes.get(2));
Callele.add(haplotypes.get(3));
alleleMapper.add(Aallele);
alleleMapper.add(Callele);
final List<Allele> eventAllelesForSample = new ArrayList<Allele>();
eventAllelesForSample.add( Allele.create("C", false) );
eventAllelesForSample.add( Allele.create("C", false) );
if(!compareAlleleLists(eventAllelesForSample, GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes))) {
logger.warn("calc alleles = " + GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes));
logger.warn("expected alleles = " + eventAllelesForSample);
}
Assert.assertTrue(compareAlleleLists(eventAllelesForSample, GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes)));
}
@Test
public void testFindHetEventAllelesInSample() {
final List<Allele> eventAlleles = new ArrayList<Allele>();
eventAlleles.add( Allele.create("A", true) );
eventAlleles.add( Allele.create("C", false) );
eventAlleles.add( Allele.create("T", false) );
final List<Allele> haplotypeAlleles = new ArrayList<Allele>();
haplotypeAlleles.add( Allele.create("AATA", true) );
haplotypeAlleles.add( Allele.create("AACA", false) );
haplotypeAlleles.add( Allele.create("CATA", false) );
haplotypeAlleles.add( Allele.create("CACA", false) );
haplotypeAlleles.add( Allele.create("TACA", false) );
haplotypeAlleles.add( Allele.create("TTCA", false) );
haplotypeAlleles.add( Allele.create("TTTA", false) );
final ArrayList<Haplotype> haplotypes = new ArrayList<Haplotype>();
haplotypes.add(new Haplotype("AATA".getBytes()));
haplotypes.add(new Haplotype("AACA".getBytes()));
haplotypes.add(new Haplotype("CATA".getBytes()));
haplotypes.add(new Haplotype("CACA".getBytes()));
haplotypes.add(new Haplotype("TACA".getBytes()));
haplotypes.add(new Haplotype("TTCA".getBytes()));
haplotypes.add(new Haplotype("TTTA".getBytes()));
final List<Allele> haplotypeAllelesForSample = new ArrayList<Allele>();
haplotypeAllelesForSample.add( Allele.create("TTTA", false) );
haplotypeAllelesForSample.add( Allele.create("AATA", true) );
final ArrayList<ArrayList<Haplotype>> alleleMapper = new ArrayList<ArrayList<Haplotype>>();
ArrayList<Haplotype> Aallele = new ArrayList<Haplotype>();
Aallele.add(haplotypes.get(0));
Aallele.add(haplotypes.get(1));
ArrayList<Haplotype> Callele = new ArrayList<Haplotype>();
Callele.add(haplotypes.get(2));
Callele.add(haplotypes.get(3));
ArrayList<Haplotype> Tallele = new ArrayList<Haplotype>();
Tallele.add(haplotypes.get(4));
Tallele.add(haplotypes.get(5));
Tallele.add(haplotypes.get(6));
alleleMapper.add(Aallele);
alleleMapper.add(Callele);
alleleMapper.add(Tallele);
final List<Allele> eventAllelesForSample = new ArrayList<Allele>();
eventAllelesForSample.add( Allele.create("A", true) );
eventAllelesForSample.add( Allele.create("T", false) );
if(!compareAlleleLists(eventAllelesForSample, GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes))) {
logger.warn("calc alleles = " + GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes));
logger.warn("expected alleles = " + eventAllelesForSample);
}
Assert.assertTrue(compareAlleleLists(eventAllelesForSample, GenotypingEngine.findEventAllelesInSample(eventAlleles, haplotypeAlleles, haplotypeAllelesForSample, alleleMapper, haplotypes)));
}
private boolean compareAlleleLists(List<Allele> l1, List<Allele> l2) {
if( l1.size() != l2.size() ) {
return false; // sanity check
}
for( int i=0; i < l1.size(); i++ ){
if ( !l2.contains(l1.get(i)) )
return false;
}
return true;
}
private class BasicGenotypingTestProvider extends TestDataProvider {
byte[] ref;
byte[] hap;
HashMap<Integer,Byte> expected;
GenotypingEngine ge = new GenotypingEngine(false, 0, false);
public BasicGenotypingTestProvider(String refString, String hapString, HashMap<Integer, Byte> expected) {
super(BasicGenotypingTestProvider.class, String.format("Haplotype to VCF test: ref = %s, alignment = %s", refString,hapString));
ref = refString.getBytes();
hap = hapString.getBytes();
this.expected = expected;
}
public HashMap<Integer,VariantContext> calcAlignment() {
final SWPairwiseAlignment alignment = new SWPairwiseAlignment(ref, hap);
return ge.generateVCsFromAlignment( alignment.getAlignmentStart2wrt1(), alignment.getCigar(), ref, hap, genomeLocParser.createGenomeLoc("4",1,1+ref.length), "name", 0);
}
}
@DataProvider(name = "BasicGenotypingTestProvider")
public Object[][] makeBasicGenotypingTests() {
for( int contextSize : new int[]{0,1,5,9,24,36} ) {
HashMap<Integer, Byte> map = new HashMap<Integer, Byte>();
map.put(1 + contextSize, (byte)'M');
final String context = Utils.dupString('G', contextSize);
new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGACTAGCCGATAG", map);
}
for( int contextSize : new int[]{0,1,5,9,24,36} ) {
HashMap<Integer, Byte> map = new HashMap<Integer, Byte>();
map.put(2 + contextSize, (byte)'M');
map.put(21 + contextSize, (byte)'M');
final String context = Utils.dupString('G', contextSize);
new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG", "ATCTCGCATCGCGAGCATCGCCTAGCCGATAG", map);
}
for( int contextSize : new int[]{0,1,5,9,24,36} ) {
HashMap<Integer, Byte> map = new HashMap<Integer, Byte>();
map.put(1 + contextSize, (byte)'M');
map.put(20 + contextSize, (byte)'I');
final String context = Utils.dupString('G', contextSize);
new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGACACTAGCCGATAG", map);
}
for( int contextSize : new int[]{0,1,5,9,24,36} ) {
HashMap<Integer, Byte> map = new HashMap<Integer, Byte>();
map.put(1 + contextSize, (byte)'M');
map.put(20 + contextSize, (byte)'D');
final String context = Utils.dupString('G', contextSize);
new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGCTAGCCGATAG", map);
}
for( int contextSize : new int[]{1,5,9,24,36} ) {
HashMap<Integer, Byte> map = new HashMap<Integer, Byte>();
map.put(1, (byte)'M');
map.put(20, (byte)'D');
final String context = Utils.dupString('G', contextSize);
new BasicGenotypingTestProvider("AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGCTAGCCGATAG", map);
}
for( int contextSize : new int[]{0,1,5,9,24,36} ) {
HashMap<Integer, Byte> map = new HashMap<Integer, Byte>();
map.put(2 + contextSize, (byte)'M');
map.put(20 + contextSize, (byte)'I');
map.put(30 + contextSize, (byte)'D');
final String context = Utils.dupString('G', contextSize);
new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "ACCTCGCATCGCGAGCATCGTTACTAGCCGATG", map);
}
for( int contextSize : new int[]{0,1,5,9,24,36} ) {
HashMap<Integer, Byte> map = new HashMap<Integer, Byte>();
map.put(1 + contextSize, (byte)'M');
map.put(20 + contextSize, (byte)'D');
map.put(28 + contextSize, (byte)'M');
final String context = Utils.dupString('G', contextSize);
new BasicGenotypingTestProvider(context + "AGCTCGCATCGCGAGCATCGACTAGCCGATAG" + context, "CGCTCGCATCGCGAGCATCGCTAGCCCATAG", map);
}
return BasicGenotypingTestProvider.getTests(BasicGenotypingTestProvider.class);
}
@Test(dataProvider = "BasicGenotypingTestProvider", enabled = true)
public void testHaplotypeToVCF(BasicGenotypingTestProvider cfg) {
HashMap<Integer,VariantContext> calculatedMap = cfg.calcAlignment();
HashMap<Integer,Byte> expectedMap = cfg.expected;
logger.warn(String.format("Test: %s", cfg.toString()));
if(!compareVCMaps(calculatedMap, expectedMap)) {
logger.warn("calc map = " + calculatedMap);
logger.warn("expected map = " + expectedMap);
}
Assert.assertTrue(compareVCMaps(calculatedMap, expectedMap));
}
/**
* Tests that we get the right values from the binomial distribution
*/
@Test
public void testCalculateR2LD() {
logger.warn("Executing testCalculateR2LD");
Assert.assertEquals(GenotypingEngine.calculateR2LD(1,1,1,1), 0.0, 0.00001);
Assert.assertEquals(GenotypingEngine.calculateR2LD(100,100,100,100), 0.0, 0.00001);
Assert.assertEquals(GenotypingEngine.calculateR2LD(1,0,0,1), 1.0, 0.00001);
Assert.assertEquals(GenotypingEngine.calculateR2LD(100,0,0,100), 1.0, 0.00001);
Assert.assertEquals(GenotypingEngine.calculateR2LD(1,2,3,4), (0.1 - 0.12) * (0.1 - 0.12) / (0.3 * 0.7 * 0.4 * 0.6), 0.00001);
}
/**
* Private function to compare HashMap of VCs, it only checks the types and start locations of the VariantContext
*/
private boolean compareVCMaps(HashMap<Integer, VariantContext> calc, HashMap<Integer, Byte> expected) {
if( !calc.keySet().equals(expected.keySet()) ) { return false; } // sanity check
for( Integer loc : expected.keySet() ) {
Byte type = expected.get(loc);
switch( type ) {
case 'I':
if( !calc.get(loc).isSimpleInsertion() ) { return false; }
break;
case 'D':
if( !calc.get(loc).isSimpleDeletion() ) { return false; }
break;
case 'M':
if( !(calc.get(loc).isMNP() || calc.get(loc).isSNP()) ) { return false; }
break;
default:
return false;
}
}
return true;
}
}

View File

@ -0,0 +1,36 @@
package org.broadinstitute.sting.gatk.walkers.haplotypecaller;
import org.broadinstitute.sting.WalkerTest;
import org.testng.annotations.Test;
import java.util.Arrays;
public class HaplotypeCallerIntegrationTest extends WalkerTest {
final static String REF = b37KGReference;
final String NA12878_BAM = validationDataLocation + "NA12878.HiSeq.b37.chr20.10_11mb.bam";
final String CEUTRIO_BAM = validationDataLocation + "CEUTrio.HiSeq.b37.chr20.10_11mb.bam";
final String INTERVALS_FILE = validationDataLocation + "NA12878.HiSeq.b37.chr20.10_11mb.test.intervals";
//final String RECAL_FILE = validationDataLocation + "NA12878.kmer.8.subset.recal_data.bqsr";
private void HCTest(String bam, String args, String md5) {
final String base = String.format("-T HaplotypeCaller -R %s -I %s -L %s", REF, bam, INTERVALS_FILE) + " --no_cmdline_in_header -o %s -minPruning 3";
final WalkerTestSpec spec = new WalkerTestSpec(base + " " + args, Arrays.asList(md5));
executeTest("testHaplotypeCaller: args=" + args, spec);
}
@Test
public void testHaplotypeCallerMultiSample() {
HCTest(CEUTRIO_BAM, "", "7b4e76934e0c911220b4e7da8776ab2b");
}
@Test
public void testHaplotypeCallerSingleSample() {
HCTest(NA12878_BAM, "", "fcf0cea98a571d5e2d1dfa8b5edc599d");
}
@Test
public void testHaplotypeCallerMultiSampleGGA() {
HCTest(CEUTRIO_BAM, "-gt_mode GENOTYPE_GIVEN_ALLELES -alleles " + validationDataLocation + "combined.phase1.chr20.raw.indels.sites.vcf", "ff370c42c8b09a29f1aeff5ac57c7ea6");
}
}

View File

@ -0,0 +1,173 @@
package org.broadinstitute.sting.gatk.walkers.haplotypecaller;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: 3/14/12
*/
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.Haplotype;
import org.broadinstitute.sting.utils.MathUtils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.*;
/**
* Unit tests for LikelihoodCalculationEngine
*/
public class LikelihoodCalculationEngineUnitTest extends BaseTest {
@Test
public void testNormalizeDiploidLikelihoodMatrixFromLog10() {
double[][] likelihoodMatrix = {
{-90.2, 0, 0},
{-190.1, -2.1, 0},
{-7.0, -17.5, -35.9}
};
double[][] normalizedMatrix = {
{-88.1, 0, 0},
{-188.0, 0.0, 0},
{-4.9, -15.4, -33.8}
};
Assert.assertTrue(compareDoubleArrays(LikelihoodCalculationEngine.normalizeDiploidLikelihoodMatrixFromLog10(likelihoodMatrix), normalizedMatrix));
double[][] likelihoodMatrix2 = {
{-90.2, 0, 0, 0},
{-190.1, -2.1, 0, 0},
{-7.0, -17.5, -35.9, 0},
{-7.0, -17.5, -35.9, -1000.0},
};
double[][] normalizedMatrix2 = {
{-88.1, 0, 0, 0},
{-188.0, 0.0, 0, 0},
{-4.9, -15.4, -33.8, 0},
{-4.9, -15.4, -33.8, -997.9},
};
Assert.assertTrue(compareDoubleArrays(LikelihoodCalculationEngine.normalizeDiploidLikelihoodMatrixFromLog10(likelihoodMatrix2), normalizedMatrix2));
}
private class BasicLikelihoodTestProvider extends TestDataProvider {
public Double readLikelihoodForHaplotype1;
public Double readLikelihoodForHaplotype2;
public Double readLikelihoodForHaplotype3;
public BasicLikelihoodTestProvider(double a, double b) {
super(BasicLikelihoodTestProvider.class, String.format("Diploid haplotype likelihoods for reads %f / %f",a,b));
readLikelihoodForHaplotype1 = a;
readLikelihoodForHaplotype2 = b;
readLikelihoodForHaplotype3 = null;
}
public BasicLikelihoodTestProvider(double a, double b, double c) {
super(BasicLikelihoodTestProvider.class, String.format("Diploid haplotype likelihoods for reads %f / %f / %f",a,b,c));
readLikelihoodForHaplotype1 = a;
readLikelihoodForHaplotype2 = b;
readLikelihoodForHaplotype3 = c;
}
public double[][] expectedDiploidHaplotypeMatrix() {
if( readLikelihoodForHaplotype3 == null ) {
double maxValue = Math.max(readLikelihoodForHaplotype1,readLikelihoodForHaplotype2);
double[][] normalizedMatrix = {
{readLikelihoodForHaplotype1 - maxValue, Double.NEGATIVE_INFINITY},
{Math.log10(0.5*Math.pow(10,readLikelihoodForHaplotype1) + 0.5*Math.pow(10,readLikelihoodForHaplotype2)) - maxValue, readLikelihoodForHaplotype2 - maxValue}
};
return normalizedMatrix;
} else {
double maxValue = MathUtils.max(readLikelihoodForHaplotype1,readLikelihoodForHaplotype2,readLikelihoodForHaplotype3);
double[][] normalizedMatrix = {
{readLikelihoodForHaplotype1 - maxValue, Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY},
{Math.log10(0.5*Math.pow(10,readLikelihoodForHaplotype1) + 0.5*Math.pow(10,readLikelihoodForHaplotype2)) - maxValue, readLikelihoodForHaplotype2 - maxValue, Double.NEGATIVE_INFINITY},
{Math.log10(0.5*Math.pow(10,readLikelihoodForHaplotype1) + 0.5*Math.pow(10,readLikelihoodForHaplotype3)) - maxValue,
Math.log10(0.5*Math.pow(10,readLikelihoodForHaplotype2) + 0.5*Math.pow(10,readLikelihoodForHaplotype3)) - maxValue, readLikelihoodForHaplotype3 - maxValue}
};
return normalizedMatrix;
}
}
public double[][] calcDiploidHaplotypeMatrix() {
ArrayList<Haplotype> haplotypes = new ArrayList<Haplotype>();
for( int iii = 1; iii <= 3; iii++) {
Double readLikelihood = ( iii == 1 ? readLikelihoodForHaplotype1 : ( iii == 2 ? readLikelihoodForHaplotype2 : readLikelihoodForHaplotype3) );
if( readLikelihood != null ) {
Haplotype haplotype = new Haplotype( (iii == 1 ? "AAAA" : (iii == 2 ? "CCCC" : "TTTT")).getBytes() );
haplotype.addReadLikelihoods("myTestSample", new double[]{readLikelihood});
haplotypes.add(haplotype);
}
}
return LikelihoodCalculationEngine.computeDiploidHaplotypeLikelihoods(haplotypes, "myTestSample");
}
}
@DataProvider(name = "BasicLikelihoodTestProvider")
public Object[][] makeBasicLikelihoodTests() {
new BasicLikelihoodTestProvider(-1.1, -2.2);
new BasicLikelihoodTestProvider(-2.2, -1.1);
new BasicLikelihoodTestProvider(-1.1, -1.1);
new BasicLikelihoodTestProvider(-9.7, -15.0);
new BasicLikelihoodTestProvider(-1.1, -2000.2);
new BasicLikelihoodTestProvider(-1000.1, -2.2);
new BasicLikelihoodTestProvider(0, 0);
new BasicLikelihoodTestProvider(-1.1, 0);
new BasicLikelihoodTestProvider(0, -2.2);
new BasicLikelihoodTestProvider(-100.1, -200.2);
new BasicLikelihoodTestProvider(-1.1, -2.2, 0);
new BasicLikelihoodTestProvider(-2.2, -1.1, 0);
new BasicLikelihoodTestProvider(-1.1, -1.1, 0);
new BasicLikelihoodTestProvider(-9.7, -15.0, 0);
new BasicLikelihoodTestProvider(-1.1, -2000.2, 0);
new BasicLikelihoodTestProvider(-1000.1, -2.2, 0);
new BasicLikelihoodTestProvider(0, 0, 0);
new BasicLikelihoodTestProvider(-1.1, 0, 0);
new BasicLikelihoodTestProvider(0, -2.2, 0);
new BasicLikelihoodTestProvider(-100.1, -200.2, 0);
new BasicLikelihoodTestProvider(-1.1, -2.2, -12.121);
new BasicLikelihoodTestProvider(-2.2, -1.1, -12.121);
new BasicLikelihoodTestProvider(-1.1, -1.1, -12.121);
new BasicLikelihoodTestProvider(-9.7, -15.0, -12.121);
new BasicLikelihoodTestProvider(-1.1, -2000.2, -12.121);
new BasicLikelihoodTestProvider(-1000.1, -2.2, -12.121);
new BasicLikelihoodTestProvider(0, 0, -12.121);
new BasicLikelihoodTestProvider(-1.1, 0, -12.121);
new BasicLikelihoodTestProvider(0, -2.2, -12.121);
new BasicLikelihoodTestProvider(-100.1, -200.2, -12.121);
return BasicLikelihoodTestProvider.getTests(BasicLikelihoodTestProvider.class);
}
@Test(dataProvider = "BasicLikelihoodTestProvider", enabled = true)
public void testOneReadWithTwoOrThreeHaplotypes(BasicLikelihoodTestProvider cfg) {
double[][] calculatedMatrix = cfg.calcDiploidHaplotypeMatrix();
double[][] expectedMatrix = cfg.expectedDiploidHaplotypeMatrix();
logger.warn(String.format("Test: %s", cfg.toString()));
Assert.assertTrue(compareDoubleArrays(calculatedMatrix, expectedMatrix));
}
/**
* Private function to compare 2d arrays
*/
private boolean compareDoubleArrays(double[][] b1, double[][] b2) {
if( b1.length != b2.length ) {
return false; // sanity check
}
for( int i=0; i < b1.length; i++ ){
if( b1[i].length != b2[i].length) {
return false; // sanity check
}
for( int j=0; j < b1.length; j++ ){
if ( MathUtils.compareDoubles(b1[i][j], b2[i][j]) != 0 && !Double.isInfinite(b1[i][j]) && !Double.isInfinite(b2[i][j]))
return false;
}
}
return true;
}
}

View File

@ -0,0 +1,257 @@
package org.broadinstitute.sting.gatk.walkers.haplotypecaller;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: 3/27/12
*/
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.Haplotype;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.variantcontext.Allele;
import org.jgrapht.graph.DefaultDirectedGraph;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
public class SimpleDeBruijnAssemblerUnitTest extends BaseTest {
private class MergeNodesWithNoVariationTestProvider extends TestDataProvider {
public byte[] sequence;
public int KMER_LENGTH;
public MergeNodesWithNoVariationTestProvider(String seq, int kmer) {
super(MergeNodesWithNoVariationTestProvider.class, String.format("Merge nodes with no variation test. kmer = %d, seq = %s", kmer, seq));
sequence = seq.getBytes();
KMER_LENGTH = kmer;
}
public DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> expectedGraph() {
DeBruijnVertex v = new DeBruijnVertex(sequence, 0);
DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> graph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
graph.addVertex(v);
return graph;
}
public DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> calcGraph() {
DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> graph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
final int kmersInSequence = sequence.length - KMER_LENGTH + 1;
for (int i = 0; i < kmersInSequence - 1; i++) {
// get the kmers
final byte[] kmer1 = new byte[KMER_LENGTH];
System.arraycopy(sequence, i, kmer1, 0, KMER_LENGTH);
final byte[] kmer2 = new byte[KMER_LENGTH];
System.arraycopy(sequence, i+1, kmer2, 0, KMER_LENGTH);
SimpleDeBruijnAssembler.addKmersToGraph(graph, kmer1, kmer2, false);
}
SimpleDeBruijnAssembler.mergeNodes(graph);
return graph;
}
}
@DataProvider(name = "MergeNodesWithNoVariationTestProvider")
public Object[][] makeMergeNodesWithNoVariationTests() {
new MergeNodesWithNoVariationTestProvider("GGTTAACC", 3);
new MergeNodesWithNoVariationTestProvider("GGTTAACC", 4);
new MergeNodesWithNoVariationTestProvider("GGTTAACC", 5);
new MergeNodesWithNoVariationTestProvider("GGTTAACC", 6);
new MergeNodesWithNoVariationTestProvider("GGTTAACC", 7);
new MergeNodesWithNoVariationTestProvider("GGTTAACCATGCAGACGGGAGGCTGAGCGAGAGTTTT", 6);
new MergeNodesWithNoVariationTestProvider("AATACCATTGGAGTTTTTTTCCAGGTTAAGATGGTGCATTGAATCCACCCATCTACTTTTGCTCCTCCCAAAACTCACTAAAACTATTATAAAGGGATTTTGTTTAAAGACACAAACTCATGAGGACAGAGAGAACAGAGTAGACAATAGTGGGGGAAAAATAAGTTGGAAGATAGAAAACAGATGGGTGAGTGGTAATCGACTCAGCAGCCCCAAGAAAGCTGAAACCCAGGGAAAGTTAAGAGTAGCCCTATTTTCATGGCAAAATCCAAGGGGGGGTGGGGAAAGAAAGAAAAACAGAAAAAAAAATGGGAATTGGCAGTCCTAGATATCTCTGGTACTGGGCAAGCCAAAGAATCAGGATAACTGGGTGAAAGGTGATTGGGAAGCAGTTAAAATCTTAGTTCCCCTCTTCCACTCTCCGAGCAGCAGGTTTCTCTCTCTCATCAGGCAGAGGGCTGGAGAT", 66);
new MergeNodesWithNoVariationTestProvider("AATACCATTGGAGTTTTTTTCCAGGTTAAGATGGTGCATTGAATCCACCCATCTACTTTTGCTCCTCCCAAAACTCACTAAAACTATTATAAAGGGATTTTGTTTAAAGACACAAACTCATGAGGACAGAGAGAACAGAGTAGACAATAGTGGGGGAAAAATAAGTTGGAAGATAGAAAACAGATGGGTGAGTGGTAATCGACTCAGCAGCCCCAAGAAAGCTGAAACCCAGGGAAAGTTAAGAGTAGCCCTATTTTCATGGCAAAATCCAAGGGGGGGTGGGGAAAGAAAGAAAAACAGAAAAAAAAATGGGAATTGGCAGTCCTAGATATCTCTGGTACTGGGCAAGCCAAAGAATCAGGATAACTGGGTGAAAGGTGATTGGGAAGCAGTTAAAATCTTAGTTCCCCTCTTCCACTCTCCGAGCAGCAGGTTTCTCTCTCTCATCAGGCAGAGGGCTGGAGAT", 76);
return MergeNodesWithNoVariationTestProvider.getTests(MergeNodesWithNoVariationTestProvider.class);
}
@Test(dataProvider = "MergeNodesWithNoVariationTestProvider", enabled = true)
public void testMergeNodesWithNoVariation(MergeNodesWithNoVariationTestProvider cfg) {
logger.warn(String.format("Test: %s", cfg.toString()));
Assert.assertTrue(graphEquals(cfg.calcGraph(), cfg.expectedGraph()));
}
@Test(enabled = true)
public void testPruneGraph() {
DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> graph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> expectedGraph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
DeBruijnVertex v = new DeBruijnVertex("ATGG".getBytes(), 0);
DeBruijnVertex v2 = new DeBruijnVertex("ATGGA".getBytes(), 0);
DeBruijnVertex v3 = new DeBruijnVertex("ATGGT".getBytes(), 0);
DeBruijnVertex v4 = new DeBruijnVertex("ATGGG".getBytes(), 0);
DeBruijnVertex v5 = new DeBruijnVertex("ATGGC".getBytes(), 0);
DeBruijnVertex v6 = new DeBruijnVertex("ATGGCCCCCC".getBytes(), 0);
graph.addVertex(v);
graph.addVertex(v2);
graph.addVertex(v3);
graph.addVertex(v4);
graph.addVertex(v5);
graph.addVertex(v6);
graph.addEdge(v, v2, new DeBruijnEdge(false, 1));
graph.addEdge(v2, v3, new DeBruijnEdge(false, 3));
graph.addEdge(v3, v4, new DeBruijnEdge(false, 5));
graph.addEdge(v4, v5, new DeBruijnEdge(false, 3));
graph.addEdge(v5, v6, new DeBruijnEdge(false, 2));
expectedGraph.addVertex(v2);
expectedGraph.addVertex(v3);
expectedGraph.addVertex(v4);
expectedGraph.addVertex(v5);
expectedGraph.addEdge(v2, v3, new DeBruijnEdge(false, 3));
expectedGraph.addEdge(v3, v4, new DeBruijnEdge(false, 5));
expectedGraph.addEdge(v4, v5, new DeBruijnEdge(false, 3));
SimpleDeBruijnAssembler.pruneGraph(graph, 2);
Assert.assertTrue(graphEquals(graph, expectedGraph));
graph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
expectedGraph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
graph.addVertex(v);
graph.addVertex(v2);
graph.addVertex(v3);
graph.addVertex(v4);
graph.addVertex(v5);
graph.addVertex(v6);
graph.addEdge(v, v2, new DeBruijnEdge(true, 1));
graph.addEdge(v2, v3, new DeBruijnEdge(false, 3));
graph.addEdge(v3, v4, new DeBruijnEdge(false, 5));
graph.addEdge(v4, v5, new DeBruijnEdge(false, 3));
expectedGraph.addVertex(v);
expectedGraph.addVertex(v2);
expectedGraph.addVertex(v3);
expectedGraph.addVertex(v4);
expectedGraph.addVertex(v5);
expectedGraph.addEdge(v, v2, new DeBruijnEdge(true, 1));
expectedGraph.addEdge(v2, v3, new DeBruijnEdge(false, 3));
expectedGraph.addEdge(v3, v4, new DeBruijnEdge(false, 5));
expectedGraph.addEdge(v4, v5, new DeBruijnEdge(false, 3));
SimpleDeBruijnAssembler.pruneGraph(graph, 2);
Assert.assertTrue(graphEquals(graph, expectedGraph));
}
@Test(enabled = true)
public void testEliminateNonRefPaths() {
DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> graph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> expectedGraph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
DeBruijnVertex v = new DeBruijnVertex("ATGG".getBytes(), 0);
DeBruijnVertex v2 = new DeBruijnVertex("ATGGA".getBytes(), 0);
DeBruijnVertex v3 = new DeBruijnVertex("ATGGT".getBytes(), 0);
DeBruijnVertex v4 = new DeBruijnVertex("ATGGG".getBytes(), 0);
DeBruijnVertex v5 = new DeBruijnVertex("ATGGC".getBytes(), 0);
DeBruijnVertex v6 = new DeBruijnVertex("ATGGCCCCCC".getBytes(), 0);
graph.addVertex(v);
graph.addVertex(v2);
graph.addVertex(v3);
graph.addVertex(v4);
graph.addVertex(v5);
graph.addVertex(v6);
graph.addEdge(v, v2, new DeBruijnEdge(false));
graph.addEdge(v2, v3, new DeBruijnEdge(true));
graph.addEdge(v3, v4, new DeBruijnEdge(true));
graph.addEdge(v4, v5, new DeBruijnEdge(true));
graph.addEdge(v5, v6, new DeBruijnEdge(false));
expectedGraph.addVertex(v2);
expectedGraph.addVertex(v3);
expectedGraph.addVertex(v4);
expectedGraph.addVertex(v5);
expectedGraph.addEdge(v2, v3, new DeBruijnEdge());
expectedGraph.addEdge(v3, v4, new DeBruijnEdge());
expectedGraph.addEdge(v4, v5, new DeBruijnEdge());
SimpleDeBruijnAssembler.eliminateNonRefPaths(graph);
Assert.assertTrue(graphEquals(graph, expectedGraph));
graph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
expectedGraph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
graph.addVertex(v);
graph.addVertex(v2);
graph.addVertex(v3);
graph.addVertex(v4);
graph.addVertex(v5);
graph.addVertex(v6);
graph.addEdge(v, v2, new DeBruijnEdge(true));
graph.addEdge(v2, v3, new DeBruijnEdge(true));
graph.addEdge(v4, v5, new DeBruijnEdge(false));
graph.addEdge(v5, v6, new DeBruijnEdge(false));
expectedGraph.addVertex(v);
expectedGraph.addVertex(v2);
expectedGraph.addVertex(v3);
expectedGraph.addEdge(v, v2, new DeBruijnEdge());
expectedGraph.addEdge(v2, v3, new DeBruijnEdge());
SimpleDeBruijnAssembler.eliminateNonRefPaths(graph);
Assert.assertTrue(graphEquals(graph, expectedGraph));
graph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
expectedGraph = new DefaultDirectedGraph<DeBruijnVertex, DeBruijnEdge>(DeBruijnEdge.class);
graph.addVertex(v);
graph.addVertex(v2);
graph.addVertex(v3);
graph.addVertex(v4);
graph.addVertex(v5);
graph.addVertex(v6);
graph.addEdge(v, v2, new DeBruijnEdge(true));
graph.addEdge(v2, v3, new DeBruijnEdge(true));
graph.addEdge(v4, v5, new DeBruijnEdge(false));
graph.addEdge(v5, v6, new DeBruijnEdge(false));
graph.addEdge(v4, v2, new DeBruijnEdge(false));
expectedGraph.addVertex(v);
expectedGraph.addVertex(v2);
expectedGraph.addVertex(v3);
expectedGraph.addEdge(v, v2, new DeBruijnEdge());
expectedGraph.addEdge(v2, v3, new DeBruijnEdge());
SimpleDeBruijnAssembler.eliminateNonRefPaths(graph);
Assert.assertTrue(graphEquals(graph, expectedGraph));
}
private boolean graphEquals(DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> g1, DefaultDirectedGraph<DeBruijnVertex,DeBruijnEdge> g2) {
if( !(g1.vertexSet().containsAll(g2.vertexSet()) && g2.vertexSet().containsAll(g1.vertexSet())) ) {
return false;
}
for( DeBruijnEdge e1 : g1.edgeSet() ) {
boolean found = false;
for( DeBruijnEdge e2 : g2.edgeSet() ) {
if( e1.equals(g1, e2, g2) ) { found = true; break; }
}
if( !found ) { return false; }
}
for( DeBruijnEdge e2 : g2.edgeSet() ) {
boolean found = false;
for( DeBruijnEdge e1 : g1.edgeSet() ) {
if( e2.equals(g2, e1, g1) ) { found = true; break; }
}
if( !found ) { return false; }
}
return true;
}
}