diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java index 1ea7354cf..df786bc20 100755 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java @@ -357,9 +357,9 @@ public class HaplotypeCaller extends ActiveRegionWalker implem } } } - genotypeLikelihoods[AA] += QualityUtils.qualToProbLog10(qual); - genotypeLikelihoods[AB] += MathUtils.approximateLog10SumLog10( QualityUtils.qualToProbLog10(qual) + LOG_ONE_HALF, QualityUtils.qualToErrorProbLog10(qual) + LOG_ONE_THIRD + LOG_ONE_HALF ); - genotypeLikelihoods[BB] += QualityUtils.qualToErrorProbLog10(qual) + LOG_ONE_THIRD; + genotypeLikelihoods[AA] += p.getRepresentativeCount() * QualityUtils.qualToProbLog10(qual); + genotypeLikelihoods[AB] += p.getRepresentativeCount() * MathUtils.approximateLog10SumLog10( QualityUtils.qualToProbLog10(qual) + LOG_ONE_HALF, QualityUtils.qualToErrorProbLog10(qual) + LOG_ONE_THIRD + LOG_ONE_HALF ); + genotypeLikelihoods[BB] += p.getRepresentativeCount() * ( QualityUtils.qualToErrorProbLog10(qual) + LOG_ONE_THIRD ); /* parenthesised so the whole per-read BB term, LOG_ONE_THIRD included, is scaled by the representative count, as the AA and AB terms are */ } } genotypes.add( new GenotypeBuilder(sample).alleles(noCall).PL(genotypeLikelihoods).make() ); diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java index f1d0a8a12..1a1348487 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java @@ -50,7 +50,6 @@ public class LikelihoodCalculationEngine { } public void computeReadLikelihoods( final ArrayList 
haplotypes, final HashMap> perSampleReadList ) { - final int numHaplotypes = haplotypes.size(); int X_METRIC_LENGTH = 0; for( final String sample : perSampleReadList.keySet() ) { @@ -60,8 +59,8 @@ public class LikelihoodCalculationEngine { } } int Y_METRIC_LENGTH = 0; - for( int jjj = 0; jjj < numHaplotypes; jjj++ ) { - final int haplotypeLength = haplotypes.get(jjj).getBases().length; + for( final Haplotype h : haplotypes ) { + final int haplotypeLength = h.getBases().length; if( haplotypeLength > Y_METRIC_LENGTH ) { Y_METRIC_LENGTH = haplotypeLength; } } @@ -90,8 +89,10 @@ public class LikelihoodCalculationEngine { final int numHaplotypes = haplotypes.size(); final int numReads = reads.size(); final double[][] readLikelihoods = new double[numHaplotypes][numReads]; + final int[][] readCounts = new int[numHaplotypes][numReads]; for( int iii = 0; iii < numReads; iii++ ) { final GATKSAMRecord read = reads.get(iii); + final int readCount = getRepresentativeReadCount(read); final byte[] overallGCP = new byte[read.getReadLength()]; Arrays.fill( overallGCP, constantGCP ); // Is there a way to derive empirical estimates for this from the data? 
@@ -114,13 +115,23 @@ public class LikelihoodCalculationEngine { readLikelihoods[jjj][iii] = pairHMM.computeReadLikelihoodGivenHaplotype(haplotype.getBases(), read.getReadBases(), readQuals, readInsQuals, readDelQuals, overallGCP, haplotypeStart, matchMetricArray, XMetricArray, YMetricArray); + readCounts[jjj][iii] = readCount; } } for( int jjj = 0; jjj < numHaplotypes; jjj++ ) { - haplotypes.get(jjj).addReadLikelihoods( sample, readLikelihoods[jjj] ); + haplotypes.get(jjj).addReadLikelihoods( sample, readLikelihoods[jjj], readCounts[jjj] ); } } + private static int getRepresentativeReadCount(GATKSAMRecord read) { + if (!read.isReducedRead()) + return 1; + + // mean of the per-base representative counts (integer division truncates), floored at 1 so a reduced read is never weighted by zero; a missing or empty count array is treated as a single read instead of dividing by zero + final byte[] counts = read.getReducedReadCounts(); + return ( counts == null || counts.length == 0 ) ? 1 : Math.max(1, MathUtils.sum(counts)/counts.length); + } + private static int computeFirstDifferingPosition( final byte[] b1, final byte[] b2 ) { for( int iii = 0; iii < b1.length && iii < b2.length; iii++ ){ if( b1[iii] != b2[iii] ) { @@ -142,10 +153,20 @@ public class LikelihoodCalculationEngine { } return computeDiploidHaplotypeLikelihoods( sample, haplotypeMapping ); } - + + // This function takes just a single sample and a haplotypeMapping @Requires({"haplotypeMapping.size() > 0"}) @Ensures({"result.length == result[0].length", "result.length == haplotypeMapping.size()"}) public static double[][] computeDiploidHaplotypeLikelihoods( final String sample, final ArrayList> haplotypeMapping ) { + final TreeSet sampleSet = new TreeSet(); + sampleSet.add(sample); + return computeDiploidHaplotypeLikelihoods(sampleSet, haplotypeMapping); + } + + // This function takes a set of samples to pool over and a haplotypeMapping + @Requires({"haplotypeMapping.size() > 0"}) + @Ensures({"result.length == result[0].length", "result.length == haplotypeMapping.size()"}) + public static double[][] computeDiploidHaplotypeLikelihoods( final Set samples, final ArrayList> haplotypeMapping ) { final int numHaplotypes = haplotypeMapping.size(); 
final double[][] haplotypeLikelihoodMatrix = new double[numHaplotypes][numHaplotypes]; @@ -154,17 +175,22 @@ public class LikelihoodCalculationEngine { } // compute the diploid haplotype likelihoods + // todo - needs to be generalized to arbitrary ploidy, cleaned and merged with PairHMMIndelErrorModel code for( int iii = 0; iii < numHaplotypes; iii++ ) { for( int jjj = 0; jjj <= iii; jjj++ ) { for( final Haplotype iii_mapped : haplotypeMapping.get(iii) ) { - final double[] readLikelihoods_iii = iii_mapped.getReadLikelihoods(sample); for( final Haplotype jjj_mapped : haplotypeMapping.get(jjj) ) { - final double[] readLikelihoods_jjj = jjj_mapped.getReadLikelihoods(sample); double haplotypeLikelihood = 0.0; - for( int kkk = 0; kkk < readLikelihoods_iii.length; kkk++ ) { - // Compute log10(10^x1/2 + 10^x2/2) = log10(10^x1+10^x2)-log10(2) - // First term is approximated by Jacobian log with table lookup. - haplotypeLikelihood += MathUtils.approximateLog10SumLog10(readLikelihoods_iii[kkk], readLikelihoods_jjj[kkk]) + LOG_ONE_HALF; + for( final String sample : samples ) { + final double[] readLikelihoods_iii = iii_mapped.getReadLikelihoods(sample); + final int[] readCounts_iii = iii_mapped.getReadCounts(sample); + final double[] readLikelihoods_jjj = jjj_mapped.getReadLikelihoods(sample); + for( int kkk = 0; kkk < readLikelihoods_iii.length; kkk++ ) { + // Compute log10(10^x1/2 + 10^x2/2) = log10(10^x1+10^x2)-log10(2) + // log10(10^(a*x1) + 10^(b*x2)) ??? + // First term is approximated by Jacobian log with table lookup. + haplotypeLikelihood += readCounts_iii[kkk] * ( MathUtils.approximateLog10SumLog10(readLikelihoods_iii[kkk], readLikelihoods_jjj[kkk]) + LOG_ONE_HALF ); + } } haplotypeLikelihoodMatrix[iii][jjj] = Math.max(haplotypeLikelihoodMatrix[iii][jjj], haplotypeLikelihood); // MathUtils.approximateLog10SumLog10(haplotypeLikelihoodMatrix[iii][jjj], haplotypeLikelihood); // BUGBUG: max or sum? 
} @@ -176,48 +202,6 @@ public class LikelihoodCalculationEngine { return normalizeDiploidLikelihoodMatrixFromLog10( haplotypeLikelihoodMatrix ); } - @Requires({"haplotypes.size() > 0"}) - @Ensures({"result.length == result[0].length", "result.length == haplotypes.size()"}) - public static double[][] computeDiploidHaplotypeLikelihoods( final ArrayList haplotypes, final Set samples ) { - // set up the default 1-to-1 haplotype mapping object, BUGBUG: target for future optimization? - final ArrayList> haplotypeMapping = new ArrayList>(); - for( final Haplotype h : haplotypes ) { - final ArrayList list = new ArrayList(); - list.add(h); - haplotypeMapping.add(list); - } - - final int numHaplotypes = haplotypeMapping.size(); - final double[][] haplotypeLikelihoodMatrix = new double[numHaplotypes][numHaplotypes]; - for( int iii = 0; iii < numHaplotypes; iii++ ) { - Arrays.fill(haplotypeLikelihoodMatrix[iii], Double.NEGATIVE_INFINITY); - } - - // compute the diploid haplotype likelihoods - for( int iii = 0; iii < numHaplotypes; iii++ ) { - for( int jjj = 0; jjj <= iii; jjj++ ) { - for( final Haplotype iii_mapped : haplotypeMapping.get(iii) ) { - for( final Haplotype jjj_mapped : haplotypeMapping.get(jjj) ) { - double haplotypeLikelihood = 0.0; - for( final String sample : samples ) { - final double[] readLikelihoods_iii = iii_mapped.getReadLikelihoods(sample); - final double[] readLikelihoods_jjj = jjj_mapped.getReadLikelihoods(sample); - for( int kkk = 0; kkk < readLikelihoods_iii.length; kkk++ ) { - // Compute log10(10^x1/2 + 10^x2/2) = log10(10^x1+10^x2)-log10(2) - // First term is approximated by Jacobian log with table lookup. 
- haplotypeLikelihood += MathUtils.approximateLog10SumLog10(readLikelihoods_iii[kkk], readLikelihoods_jjj[kkk]) + LOG_ONE_HALF; - } - } - haplotypeLikelihoodMatrix[iii][jjj] = Math.max(haplotypeLikelihoodMatrix[iii][jjj], haplotypeLikelihood); // MathUtils.approximateLog10SumLog10(haplotypeLikelihoodMatrix[iii][jjj], haplotypeLikelihood); // BUGBUG: max or sum? - } - } - } - } - - // normalize the diploid likelihoods matrix - return normalizeDiploidLikelihoodMatrixFromLog10( haplotypeLikelihoodMatrix ); - } - @Requires({"likelihoodMatrix.length == likelihoodMatrix[0].length"}) @Ensures({"result.length == result[0].length", "result.length == likelihoodMatrix.length"}) protected static double[][] normalizeDiploidLikelihoodMatrixFromLog10( final double[][] likelihoodMatrix ) { @@ -306,7 +290,14 @@ public class LikelihoodCalculationEngine { final Set sampleKeySet = haplotypes.get(0).getSampleKeySet(); // BUGBUG: assume all haplotypes saw the same samples final ArrayList bestHaplotypesIndexList = new ArrayList(); bestHaplotypesIndexList.add(0); // always start with the reference haplotype - final double[][] haplotypeLikelihoodMatrix = computeDiploidHaplotypeLikelihoods( haplotypes, sampleKeySet ); // all samples pooled together + // set up the default 1-to-1 haplotype mapping object + final ArrayList> haplotypeMapping = new ArrayList>(); + for( final Haplotype h : haplotypes ) { + final ArrayList list = new ArrayList(); + list.add(h); + haplotypeMapping.add(list); + } + final double[][] haplotypeLikelihoodMatrix = computeDiploidHaplotypeLikelihoods( sampleKeySet, haplotypeMapping ); // all samples pooled together int hap1 = 0; int hap2 = 0; diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssembler.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssembler.java index be6c4a51f..72c4e0c7e 100755 --- 
a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssembler.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssembler.java @@ -4,6 +4,7 @@ import com.google.java.contract.Ensures; import org.apache.commons.lang.ArrayUtils; import org.broadinstitute.sting.utils.GenomeLoc; import org.broadinstitute.sting.utils.Haplotype; +import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.SWPairwiseAlignment; import org.broadinstitute.sting.utils.activeregion.ActiveRegion; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; @@ -68,7 +69,7 @@ public class SimpleDeBruijnAssembler extends LocalAssemblyEngine { return findBestPaths( refHaplotype, fullReferenceWithPadding, refLoc, activeAllelesToGenotype, activeRegion.getExtendedLoc() ); } - private void createDeBruijnGraphs( final ArrayList reads, final Haplotype refHaplotype ) { + protected void createDeBruijnGraphs( final List reads, final Haplotype refHaplotype ) { graphs.clear(); // create the graph @@ -161,7 +162,7 @@ public class SimpleDeBruijnAssembler extends LocalAssemblyEngine { } } - private static boolean createGraphFromSequences( final DefaultDirectedGraph graph, final ArrayList reads, final int KMER_LENGTH, final Haplotype refHaplotype, final boolean DEBUG ) { + private static boolean createGraphFromSequences( final DefaultDirectedGraph graph, final Collection reads, final int KMER_LENGTH, final Haplotype refHaplotype, final boolean DEBUG ) { final byte[] refSequence = refHaplotype.getBases(); if( refSequence.length >= KMER_LENGTH + KMER_OVERLAP ) { final int kmersInSequence = refSequence.length - KMER_LENGTH + 1; @@ -183,6 +184,7 @@ public class SimpleDeBruijnAssembler extends LocalAssemblyEngine { for( final GATKSAMRecord read : reads ) { final byte[] sequence = read.getReadBases(); final byte[] qualities = read.getBaseQualities(); + final byte[] reducedReadCounts = 
read.getReducedReadCounts(); // will be null if read is not reduced if( sequence.length > KMER_LENGTH + KMER_OVERLAP ) { final int kmersInSequence = sequence.length - KMER_LENGTH + 1; for( int iii = 0; iii < kmersInSequence - 1; iii++ ) { @@ -194,6 +196,12 @@ public class SimpleDeBruijnAssembler extends LocalAssemblyEngine { break; } } + int countNumber = 1; + if (read.isReducedRead()) { + // smallest reduced-read count over positions iii .. iii+KMER_LENGTH; used below as the number of times this kmer pair is added (NOTE(review): min rather than mean is a tentative choice - confirm) + countNumber = MathUtils.arrayMin(Arrays.copyOfRange(reducedReadCounts,iii,iii+KMER_LENGTH+1)); + } + if( !badKmer ) { // get the kmers final byte[] kmer1 = new byte[KMER_LENGTH]; @@ -201,7 +209,8 @@ public class SimpleDeBruijnAssembler extends LocalAssemblyEngine { final byte[] kmer2 = new byte[KMER_LENGTH]; System.arraycopy(sequence, iii+1, kmer2, 0, KMER_LENGTH); - addKmersToGraph(graph, kmer1, kmer2, false); + for (int k=0; k < countNumber; k++) + addKmersToGraph(graph, kmer1, kmer2, false); } } } @@ -230,7 +239,7 @@ public class SimpleDeBruijnAssembler extends LocalAssemblyEngine { return true; } - private void printGraphs() { + protected void printGraphs() { int count = 0; for( final DefaultDirectedGraph graph : graphs ) { GRAPH_WRITER.println("digraph kmer" + count++ +" {"); diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java index 185641140..e82946690 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java @@ -95,9 +95,10 @@ public class LikelihoodCalculationEngineUnitTest extends BaseTest { ArrayList haplotypes = new ArrayList(); for( int iii = 1; iii <= 3; iii++) { Double readLikelihood = ( iii == 1 ? 
readLikelihoodForHaplotype1 : ( iii == 2 ? readLikelihoodForHaplotype2 : readLikelihoodForHaplotype3) ); + int readCount = 1; if( readLikelihood != null ) { Haplotype haplotype = new Haplotype( (iii == 1 ? "AAAA" : (iii == 2 ? "CCCC" : "TTTT")).getBytes() ); - haplotype.addReadLikelihoods("myTestSample", new double[]{readLikelihood}); + haplotype.addReadLikelihoods("myTestSample", new double[]{readLikelihood}, new int[]{readCount}); haplotypes.add(haplotype); } } diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssemblerUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssemblerUnitTest.java index 4f42d5bc8..5652b118d 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssemblerUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/SimpleDeBruijnAssemblerUnitTest.java @@ -7,6 +7,8 @@ package org.broadinstitute.sting.gatk.walkers.haplotypecaller; */ import org.broadinstitute.sting.BaseTest; +import org.broadinstitute.sting.gatk.contexts.AlignmentContext; +import org.broadinstitute.sting.gatk.walkers.genotyper.ArtificialReadPileupTestProvider; import org.broadinstitute.sting.utils.Haplotype; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.variantcontext.Allele; @@ -18,6 +20,7 @@ import org.testng.annotations.Test; import java.io.File; import java.io.FileNotFoundException; +import java.io.PrintStream; import java.util.*; public class SimpleDeBruijnAssemblerUnitTest extends BaseTest { @@ -143,6 +146,44 @@ public class SimpleDeBruijnAssemblerUnitTest extends BaseTest { Assert.assertTrue(graphEquals(graph, expectedGraph)); } + @Test(enabled=false) +// not ready yet + public void testBasicGraphCreation() { + final ArtificialReadPileupTestProvider refPileupTestProvider = new ArtificialReadPileupTestProvider(1,"ref"); + final byte refBase = 
refPileupTestProvider.getReferenceContext().getBase(); + final String altBase = (refBase==(byte)'A'?"C":"A"); + final int matches = 50; + final int mismatches = 50; + Map refContext = refPileupTestProvider.getAlignmentContextFromAlleles(0, altBase, new int[]{matches, mismatches}, false, 30); + PrintStream graphWriter = null; + + try{ + graphWriter = new PrintStream("du.txt"); + } catch (Exception e) {} + + + SimpleDeBruijnAssembler assembler = new SimpleDeBruijnAssembler(true,graphWriter); + final Haplotype refHaplotype = new Haplotype(refPileupTestProvider.getReferenceContext().getBases()); + refHaplotype.setIsReference(true); + assembler.createDeBruijnGraphs(refContext.get(refPileupTestProvider.getSampleNames().get(0)).getBasePileup().getReads(), refHaplotype); + +/* // clean up the graphs by pruning and merging + for( final DefaultDirectedGraph graph : graphs ) { + SimpleDeBruijnAssembler.pruneGraph( graph, PRUNE_FACTOR ); + //eliminateNonRefPaths( graph ); + SimpleDeBruijnAssembler.mergeNodes( graph ); + } + */ + if( graphWriter != null ) { + assembler.printGraphs(); + } + + int k=2; + + // find the best paths in the graphs + // return findBestPaths( refHaplotype, fullReferenceWithPadding, refLoc, activeAllelesToGenotype, activeRegion.getExtendedLoc() ); + + } @Test(enabled = true) public void testEliminateNonRefPaths() { DefaultDirectedGraph graph = new DefaultDirectedGraph(DeBruijnEdge.class); diff --git a/public/java/src/org/broadinstitute/sting/gatk/report/GATKReportTable.java b/public/java/src/org/broadinstitute/sting/gatk/report/GATKReportTable.java index 7a272e155..3b4bdd087 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/report/GATKReportTable.java +++ b/public/java/src/org/broadinstitute/sting/gatk/report/GATKReportTable.java @@ -208,11 +208,23 @@ public class GATKReportTable { } /** - * Verifies that a table or column name has only alphanumeric characters - no spaces or special characters allowed - * - * @param name the name of the table or 
column - * @return true if the name is valid, false if otherwise + * Create a new GATKReportTable with the same structure (name, description, column count, sort-by-row-ID flag and columns) as tableToCopy; no row data is copied + * @param tableToCopy the table whose structure is copied + * @param copyData must be false; copying row data is not supported and requesting it throws IllegalArgumentException */ + public GATKReportTable(final GATKReportTable tableToCopy, final boolean copyData) { + this(tableToCopy.getTableName(), tableToCopy.getTableDescription(), tableToCopy.getNumColumns(), tableToCopy.sortByRowID); + if ( copyData ) throw new IllegalArgumentException("sorry, copying data in GATKReportTable isn't supported"); /* reject the unsupported request before adding any columns */ + for ( final GATKReportColumn column : tableToCopy.getColumnInfo() ) + addColumn(column.getColumnName(), column.getFormat()); + } + + /** + * Verifies that a table or column name has only alphanumeric characters - no spaces or special characters allowed + * + * @param name the name of the table or column + * @return true if the name is valid, false if otherwise + */ private boolean isValidName(String name) { Pattern p = Pattern.compile(INVALID_TABLE_NAME_REGEX); Matcher m = p.matcher(name); @@ -490,6 +502,17 @@ public class GATKReportTable { return get(rowIdToIndex.get(rowID), columnNameToIndex.get(columnName)); } + /** + * Get a value from the given row position and named column in the table + * + * @param rowIndex the index of the row (a position, not a row ID; it is passed through without a rowIdToIndex lookup) + * @param columnName the name of the column + * @return the value stored at the specified position in the table + */ + public Object get(final int rowIndex, final String columnName) { + return get(rowIndex, columnNameToIndex.get(columnName)); + } + /** * Get a value from the given position in the table * diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedArgumentCollection.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedArgumentCollection.java index 020f7904d..69f1176cc 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedArgumentCollection.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedArgumentCollection.java @@ -114,11 +114,11 @@ public class UnifiedArgumentCollection 
{ * that you not play around with this parameter. */ @Advanced - @Argument(fullName = "max_alternate_alleles", shortName = "maxAlleles", doc = "Maximum number of alternate alleles to genotype", required = false) + @Argument(fullName = "max_alternate_alleles", shortName = "maxAltAlleles", doc = "Maximum number of alternate alleles to genotype", required = false) public int MAX_ALTERNATE_ALLELES = 3; @Hidden - @Argument(fullName = "cap_max_alternate_alleles_for_indels", shortName = "capMaxAllelesForIndels", doc = "Cap the maximum number of alternate alleles to genotype for indel calls at 2; overrides the --max_alternate_alleles argument; GSA production use only", required = false) + @Argument(fullName = "cap_max_alternate_alleles_for_indels", shortName = "capMaxAltAllelesForIndels", doc = "Cap the maximum number of alternate alleles to genotype for indel calls at 2; overrides the --max_alternate_alleles argument; GSA production use only", required = false) public boolean CAP_MAX_ALTERNATE_ALLELES_FOR_INDELS = false; // indel-related arguments diff --git a/public/java/src/org/broadinstitute/sting/utils/Haplotype.java b/public/java/src/org/broadinstitute/sting/utils/Haplotype.java index 188d01098..fcde1f419 100755 --- a/public/java/src/org/broadinstitute/sting/utils/Haplotype.java +++ b/public/java/src/org/broadinstitute/sting/utils/Haplotype.java @@ -41,6 +41,7 @@ public class Haplotype { protected final double[] quals; private GenomeLoc genomeLocation = null; private HashMap readLikelihoodsPerSample = null; + private HashMap readCountsPerSample = null; private HashMap eventMap = null; private boolean isRef = false; private Cigar cigar; @@ -84,18 +85,27 @@ public class Haplotype { return Arrays.hashCode(bases); } - public void addReadLikelihoods( final String sample, final double[] readLikelihoods ) { + public void addReadLikelihoods( final String sample, final double[] readLikelihoods, final int[] readCounts ) { if( readLikelihoodsPerSample == null ) { 
readLikelihoodsPerSample = new HashMap(); } readLikelihoodsPerSample.put(sample, readLikelihoods); + if( readCountsPerSample == null ) { + readCountsPerSample = new HashMap(); + } + readCountsPerSample.put(sample, readCounts); } @Ensures({"result != null"}) public double[] getReadLikelihoods( final String sample ) { return readLikelihoodsPerSample.get(sample); } - + + @Ensures({"result != null"}) + public int[] getReadCounts( final String sample ) { + return readCountsPerSample.get(sample); + } + public Set getSampleKeySet() { return readLikelihoodsPerSample.keySet(); } diff --git a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java index 1f5eaefee..96704f0b8 100644 --- a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java @@ -210,6 +210,13 @@ public class MathUtils { return total; } + public static int sum(byte[] x) { + int total = 0; + for (byte v : x) + total += (int)v; + return total; + } + /** * Calculates the log10 cumulative sum of an array with log10 probabilities * diff --git a/public/java/src/org/broadinstitute/sting/utils/codecs/vcf/AbstractVCFCodec.java b/public/java/src/org/broadinstitute/sting/utils/codecs/vcf/AbstractVCFCodec.java index 996cef8a4..043e5e185 100755 --- a/public/java/src/org/broadinstitute/sting/utils/codecs/vcf/AbstractVCFCodec.java +++ b/public/java/src/org/broadinstitute/sting/utils/codecs/vcf/AbstractVCFCodec.java @@ -237,7 +237,12 @@ public abstract class AbstractVCFCodec extends AsciiFeatureCodec // parse out the required fields final String chr = getCachedString(parts[0]); builder.chr(chr); - int pos = Integer.valueOf(parts[1]); + int pos = -1; + try { + pos = Integer.valueOf(parts[1]); + } catch (NumberFormatException e) { + generateException(parts[1] + " is not a valid start position in the VCF format"); + } builder.start(pos); if ( parts[2].length() == 0 ) diff --git 
a/public/java/src/org/broadinstitute/sting/utils/fragments/FragmentUtils.java b/public/java/src/org/broadinstitute/sting/utils/fragments/FragmentUtils.java index 851272673..2f31c154c 100644 --- a/public/java/src/org/broadinstitute/sting/utils/fragments/FragmentUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/fragments/FragmentUtils.java @@ -134,17 +134,36 @@ public class FragmentUtils { GATKSAMRecord firstRead = overlappingPair.get(0); GATKSAMRecord secondRead = overlappingPair.get(1); - if( !(secondRead.getUnclippedStart() <= firstRead.getUnclippedEnd() && secondRead.getUnclippedStart() >= firstRead.getUnclippedStart() && secondRead.getUnclippedEnd() >= firstRead.getUnclippedEnd()) ) { + /* + System.out.println("read 0 unclipped start:"+overlappingPair.get(0).getUnclippedStart()); + System.out.println("read 0 unclipped end:"+overlappingPair.get(0).getUnclippedEnd()); + System.out.println("read 1 unclipped start:"+overlappingPair.get(1).getUnclippedStart()); + System.out.println("read 1 unclipped end:"+overlappingPair.get(1).getUnclippedEnd()); + System.out.println("read 0 start:"+overlappingPair.get(0).getAlignmentStart()); + System.out.println("read 0 end:"+overlappingPair.get(0).getAlignmentEnd()); + System.out.println("read 1 start:"+overlappingPair.get(1).getAlignmentStart()); + System.out.println("read 1 end:"+overlappingPair.get(1).getAlignmentEnd()); + */ + if( !(secondRead.getSoftStart() <= firstRead.getSoftEnd() && secondRead.getSoftStart() >= firstRead.getSoftStart() && secondRead.getSoftEnd() >= firstRead.getSoftEnd()) ) { firstRead = overlappingPair.get(1); // swap them secondRead = overlappingPair.get(0); } - if( !(secondRead.getUnclippedStart() <= firstRead.getUnclippedEnd() && secondRead.getUnclippedStart() >= firstRead.getUnclippedStart() && secondRead.getUnclippedEnd() >= firstRead.getUnclippedEnd()) ) { + if( !(secondRead.getSoftStart() <= firstRead.getSoftEnd() && secondRead.getSoftStart() >= firstRead.getSoftStart() && 
secondRead.getSoftEnd() >= firstRead.getSoftEnd()) ) { return overlappingPair; // can't merge them, yet: AAAAAAAAAAA-BBBBBBBBBBB-AAAAAAAAAAAAAA, B is contained entirely inside A } if( firstRead.getCigarString().contains("I") || firstRead.getCigarString().contains("D") || secondRead.getCigarString().contains("I") || secondRead.getCigarString().contains("D") ) { return overlappingPair; // fragments contain indels so don't merge them } +/* // check for inconsistent start positions between uncliped/soft alignment starts + if (secondRead.getAlignmentStart() >= firstRead.getAlignmentStart() && secondRead.getUnclippedStart() < firstRead.getUnclippedStart()) + return overlappingPair; + if (secondRead.getAlignmentStart() <= firstRead.getAlignmentStart() && secondRead.getUnclippedStart() > firstRead.getUnclippedStart()) + return overlappingPair; + + if (secondRead.getUnclippedStart() < firstRead.getAlignmentEnd() && secondRead.getAlignmentStart() >= firstRead.getAlignmentEnd()) + return overlappingPair; + */ final Pair pair = ReadUtils.getReadCoordinateForReferenceCoordinate(firstRead, secondRead.getSoftStart()); final int firstReadStop = ( pair.getSecond() ? pair.getFirst() + 1 : pair.getFirst() ); diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/AdaptiveContext.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/AdaptiveContext.java new file mode 100644 index 000000000..083b8af64 --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/AdaptiveContext.java @@ -0,0 +1,154 @@ +package org.broadinstitute.sting.utils.recalibration; + +import java.util.*; + +/** + * Functions for working with AdaptiveContexts + * + * User: depristo + * Date: 8/3/12 + * Time: 12:21 PM + * To change this template use File | Settings | File Templates. 
+ */ +public class AdaptiveContext { + private AdaptiveContext() {} + + /** + * Return a tree filled in completely to fillDepth with + * all combinations of {A,C,G,T}^filldepth contexts. For nodes + * in the tree, they are simply copied. When the algorithm needs to + * generate new nodes (because they are missing) the subnodes inherit the + * observation and error counts of their parent. + * + * This algorithm produces data consistent with the standard output in a BQSR recal + * file for the Context covariate + * + * @param root the root of the tree to fill; must not be null + * @param fillDepth the depth to fill to; must be >= 0 + * @return the root of the filled tree; nodes above fillDepth are newly allocated, while a node already at fillDepth (root itself when fillDepth is 0) is returned as-is + */ + public static RecalDatumNode fillToDepth(final RecalDatumNode root, final int fillDepth) { + if ( root == null ) throw new IllegalArgumentException("root is null"); + if ( fillDepth < 0 ) throw new IllegalArgumentException("fillDepth is < 0"); + + return fillToDepthRec(root, fillDepth, 0); + } + + private static RecalDatumNode fillToDepthRec(final RecalDatumNode parent, + final int fillDepth, + final int currentDepth) { + // three cases: + // We are in the tree and so just recursively build + // We have reached our depth goal, so just return the parent since we are done + // We are outside of the tree, in which case we need a pointer to our parent node so we can + // inherit its info (N, M) and we need a running context + if ( currentDepth < fillDepth ) { + // we need to create subnodes for each base, and propagate N and M down + final RecalDatumNode newParent = new RecalDatumNode(parent.getRecalDatum()); + + for ( final String base : Arrays.asList("A", "C", "G", "T")) { + ContextDatum subContext; + Set> subContexts; + + final RecalDatumNode subNode = findSubcontext(parent.getRecalDatum().context + base, parent); + if ( subNode != null ) { + // we have a subnode corresponding to the expected one, just copy and recurse + subContext = subNode.getRecalDatum(); + subContexts = subNode.getSubnodes(); + } else { + // have to create a new one + subContext = new 
ContextDatum(parent.getRecalDatum().context + base, + parent.getRecalDatum().getNumObservations(), parent.getRecalDatum().getNumMismatches()); + subContexts = Collections.emptySet(); + } + + newParent.addSubnode( + fillToDepthRec(new RecalDatumNode(subContext, subContexts), + fillDepth, currentDepth + 1)); + } + return newParent; + } else { + return parent; + } + } + + /** + * Go from a flat list of contexts to the tree implied by the contexts + * + * Implicit nodes are created as needed, and their observation and error counts are the sum + * over all of their subnodes. + * + * Note this does not guarantee the tree is complete, as some contexts (e.g., AAT) may be missing + * from the tree because they are absent from the input list of contexts. + * + * For input AAG, AAT, AC, G would produce the following tree: + * + * - x [root] + * - A + * - A + * - T + * - G + * - C + * - G + * + * sets the fixed penalties in the resulting tree as well + * + * @param flatContexts list of flat contexts; must be neither null nor empty + * @return the root node of the resulting tree + */ + public static RecalDatumNode createTreeFromFlatContexts(final List flatContexts) { + if ( flatContexts == null || flatContexts.isEmpty() ) + throw new IllegalArgumentException("flatContexts cannot be empty or null"); + + final Queue> remaining = new LinkedList>(); + final Map> contextToNodes = new HashMap>(); + RecalDatumNode root = null; + + // initialize -- start with all of the contexts + for ( final ContextDatum cd : flatContexts ) + remaining.add(new RecalDatumNode(cd)); + + while ( remaining.peek() != null ) { + final RecalDatumNode add = remaining.poll(); + final ContextDatum cd = add.getRecalDatum(); + + final String parentContext = cd.getParentContext(); + RecalDatumNode parent = contextToNodes.get(parentContext); + if ( parent == null ) { + // haven't yet found parent, so make one, and enqueue it for processing (the root is compared by value, not by reference, and is never enqueued) + parent = new RecalDatumNode(new ContextDatum(parentContext, 0, 0)); + contextToNodes.put(parentContext, parent); + + if ( ! 
ContextDatum.ROOT_CONTEXT.equals(parentContext) ) + remaining.add(parent); + else + root = parent; + } + + parent.getRecalDatum().incrementNumObservations(cd.getNumObservations()); + parent.getRecalDatum().incrementNumMismatches(cd.getNumMismatches()); + parent.addSubnode(add); + } + + if ( root == null ) + throw new RuntimeException("root is unexpectedly null"); + + // set the fixed penalty everywhere in the tree, so that future modifications don't change the penalties + root.calcAndSetFixedPenalty(true); + + return root; + } + + /** + * Finds immediate subnode with contextToFind, or null if none exists + * @param contextToFind the context string an immediate subnode must equal + * @param tree whose subnodes should be searched + * @return the matching immediate subnode, or null if none exists + */ + public static RecalDatumNode findSubcontext(final String contextToFind, final RecalDatumNode tree) { + for ( final RecalDatumNode sub : tree.getSubnodes() ) + if ( sub.getRecalDatum().context.equals(contextToFind) ) + return sub; + return null; + } +} diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java index 3af91be16..1409af7d0 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java @@ -4,10 +4,12 @@ import com.google.java.contract.Ensures; import com.google.java.contract.Requires; import org.apache.commons.math.stat.inference.ChiSquareTestImpl; import org.apache.log4j.Logger; -import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.collections.Pair; +import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; +import java.util.Collection; import java.util.HashSet; +import java.util.LinkedList; import java.util.Set; /** @@ -17,13 +19,18 @@ import java.util.Set; * @since 07/27/12 */ public class RecalDatumNode { - private final static boolean USE_CHI2 = true; protected static Logger logger = 
Logger.getLogger(RecalDatumNode.class); + + /** + * fixedPenalty holds this value until a fixed penalty has been set + */ private final static double UNINITIALIZED = Double.NEGATIVE_INFINITY; + private final T recalDatum; private double fixedPenalty = UNINITIALIZED; private final Set> subnodes; + @Requires({"recalDatum != null"}) public RecalDatumNode(final T recalDatum) { this(recalDatum, new HashSet>()); } @@ -33,28 +40,45 @@ public class RecalDatumNode { return recalDatum.toString(); } + @Requires({"recalDatum != null", "subnodes != null"}) public RecalDatumNode(final T recalDatum, final Set> subnodes) { this(recalDatum, UNINITIALIZED, subnodes); } + @Requires({"recalDatum != null"}) protected RecalDatumNode(final T recalDatum, final double fixedPenalty) { this(recalDatum, fixedPenalty, new HashSet>()); } + @Requires({"recalDatum != null", "subnodes != null"}) protected RecalDatumNode(final T recalDatum, final double fixedPenalty, final Set> subnodes) { this.recalDatum = recalDatum; this.fixedPenalty = fixedPenalty; this.subnodes = new HashSet>(subnodes); } + /** + * Get the recal data associated with this node + * @return the recal datum of this node; never null + */ + @Ensures("result != null") public T getRecalDatum() { return recalDatum; } + /** + * The set of immediate subnodes of this node. This is the internal set, so modifying it modifies this node. + * @return the internal set of subnodes; never null + */ + @Ensures("result != null") public Set> getSubnodes() { return subnodes; } + /** + * Return the fixed penalty, if set, or else the calculated penalty for this node + * @return the penalty for this node + */ public double getPenalty() { if ( fixedPenalty != UNINITIALIZED ) return fixedPenalty; @@ -62,6 +86,17 @@ public class RecalDatumNode { return calcPenalty(); } + /** + * Set the fixed penalty for this node to a fresh calculation from calcPenalty + * + * This is important in the case where you want to compute the penalty from a full + * tree and then chop the tree up afterwards while considering the previous penalties. 
+ * If you don't call this function then manipulating the tree may result in the + * penalty functions changing with changes in the tree. + * + * @param doEntireTree recurse into all subnodes? + * @return the fixed penalty for this node + */ public double calcAndSetFixedPenalty(final boolean doEntireTree) { fixedPenalty = calcPenalty(); if ( doEntireTree ) @@ -70,15 +105,41 @@ public class RecalDatumNode { return fixedPenalty; } + /** + * Add node to the set of subnodes of this node + * @param sub the node to add; must not be null + */ + @Requires("sub != null") public void addSubnode(final RecalDatumNode sub) { subnodes.add(sub); } + /** + * Is this a leaf node (i.e., has no subnodes)? + * @return true if this node has no subnodes + */ public boolean isLeaf() { return subnodes.isEmpty(); } - public int getNumBranches() { + /** + * Is this node immediately above only leaf nodes? + * + * @return true if every immediate subnode is a leaf (also true when there are no subnodes) + */ + public boolean isAboveOnlyLeaves() { + for ( final RecalDatumNode sub : subnodes ) + if ( ! sub.isLeaf() ) + return false; + return true; + } + + /** + * What's the number of immediate subnodes of this node? + * @return the number of immediate subnodes + */ + @Ensures("result >= 0") + public int getNumSubnodes() { return subnodes.size(); } @@ -89,6 +150,8 @@ public class RecalDatumNode { * definition have 0 penalty unless they represent a pruned tree with underlying -- but now * pruned -- subtrees * + * TODO -- can we really just add together the chi2 values? + * * @return */ public double totalPenalty() { @@ -102,6 +165,10 @@ public class RecalDatumNode { } } + /** + * What's the longest branch from this node to any leaf? + * @return the depth, counting this node (a leaf has depth 1) + */ public int maxDepth() { int subMax = 0; for ( final RecalDatumNode sub : subnodes ) @@ -109,6 +176,11 @@ public class RecalDatumNode { return subMax + 1; } + /** + * What's the shortest branch from this node to any leaf? 
Includes this node + * @return the minimum depth (1 for a leaf) + */ + @Ensures("result > 0") public int minDepth() { if ( isLeaf() ) return 1; @@ -120,6 +192,11 @@ public class RecalDatumNode { } } + /** + * Return the number of nodes, including this one, reachable from this node + * @return the node count, which is at least 1 + */ + @Ensures("result > 0") public int size() { int size = 1; for ( final RecalDatumNode sub : subnodes ) @@ -127,6 +204,12 @@ public class RecalDatumNode { return size; } + /** + * Count the number of leaf nodes reachable from this node + * + * @return the number of leaves (1 if this node is itself a leaf) + */ + @Ensures("result >= 0") public int numLeaves() { if ( isLeaf() ) return 1; @@ -138,44 +221,37 @@ public class RecalDatumNode { } } + /** + * Calculate the chi^2 penalty among subnodes of this node. The chi^2 value + * indicates the degree of independence of the implied error rates among the + * immediate subnodes + * + * @return the chi2 penalty, or 0.0 for a leaf or a node with a single subnode + */ private double calcPenalty() { - if ( USE_CHI2 ) - return calcPenaltyChi2(); - else - return calcPenaltyLog10(getRecalDatum().getEmpiricalErrorRate()); - } - - private double calcPenaltyChi2() { if ( isLeaf() ) return 0.0; + else if ( subnodes.size() == 1 ) + // only one value, so it's free to merge away + return 0.0; else { final long[][] counts = new long[subnodes.size()][2]; int i = 0; - for ( RecalDatumNode subnode : subnodes ) { - counts[i][0] = subnode.getRecalDatum().getNumMismatches(); - counts[i][1] = subnode.getRecalDatum().getNumObservations(); + for ( final RecalDatumNode subnode : subnodes ) { + // add pseudocounts (+1 mismatch, +2 observations; originally described as the yates correction) to help avoid all zeros => NaN + counts[i][0] = subnode.getRecalDatum().getNumMismatches() + 1; + counts[i][1] = subnode.getRecalDatum().getNumObservations() + 2; i++; } final double chi2 = new ChiSquareTestImpl().chiSquare(counts); -// StringBuilder x = new StringBuilder(); -// StringBuilder y = new StringBuilder(); -// for ( int k = 0; k < counts.length; k++) { -// if ( k != 0 ) { -// x.append(", "); -// y.append(", "); -// } -// 
x.append(counts[k][0]); -// y.append(counts[k][1]); -// } -// logger.info("x = c(" + x.toString() + ")"); -// logger.info("y = c(" + y.toString() + ")"); -// logger.info("chi2 = " + chi2); + // make sure things are reasonable and fail early if not + if (Double.isInfinite(chi2) || Double.isNaN(chi2)) + throw new ReviewedStingException("chi2 value is " + chi2 + " at " + getRecalDatum()); return chi2; - //return Math.log10(chi2); } } @@ -216,11 +292,17 @@ public class RecalDatumNode { } } + /** + * Return a freshly allocated tree pruned to have no more than maxDepth from the root to any leaf + * + * @param maxDepth the maximum depth to keep; must be >= 1 + * @return the pruned tree + */ public RecalDatumNode pruneToDepth(final int maxDepth) { if ( maxDepth < 1 ) throw new IllegalArgumentException("maxDepth < 1"); else { - final Set> subPruned = new HashSet>(getNumBranches()); + final Set> subPruned = new HashSet>(getNumSubnodes()); if ( maxDepth > 1 ) for ( final RecalDatumNode sub : subnodes ) subPruned.add(sub.pruneToDepth(maxDepth - 1)); @@ -228,12 +310,21 @@ public class RecalDatumNode { } } + /** + * Return a freshly allocated tree pruned to no more than maxElements nodes, removing nodes in order of lowest penalty + * + * Note that nodes must have fixed penalties or this algorithm will fail. 
+ * + * @param maxElements the maximum number of nodes allowed in the result + * @return the pruned tree + */ public RecalDatumNode pruneByPenalty(final int maxElements) { RecalDatumNode root = this; while ( root.size() > maxElements ) { // remove the lowest penalty element, and continue root = root.removeLowestPenaltyNode(); + logger.debug("pruneByPenalty root size is now " + root.size() + " of max " + maxElements); } // our size is below the target, so we are good, return @@ -241,15 +332,15 @@ public class RecalDatumNode { } /** - * Find the lowest penalty node in the tree, and return a tree without it + * Find the lowest penalty node whose subnodes are all leaves, and return a tree without it * * Note this excludes the current (root) node * * @return */ private RecalDatumNode removeLowestPenaltyNode() { - final Pair, Double> nodeToRemove = getMinPenaltyNode(); - logger.info("Removing " + nodeToRemove.getFirst() + " with penalty " + nodeToRemove.getSecond()); + final Pair, Double> nodeToRemove = getMinPenaltyAboveLeafNode(); + //logger.info("Removing " + nodeToRemove.getFirst() + " with penalty " + nodeToRemove.getSecond()); final Pair, Boolean> result = removeNode(nodeToRemove.getFirst()); @@ -262,20 +353,37 @@ public class RecalDatumNode { return oneRemoved; } - private Pair, Double> getMinPenaltyNode() { - final double myValue = isLeaf() ? 
Double.MAX_VALUE : getPenalty(); - Pair, Double> maxNode = new Pair, Double>(this, myValue); - - for ( final RecalDatumNode sub : subnodes ) { - final Pair, Double> subFind = sub.getMinPenaltyNode(); - if ( subFind.getSecond() < maxNode.getSecond() ) { - maxNode = subFind; + /** + * Finds in the tree the node with the lowest penalty whose subnodes are all leaves + * + * @return the (node, penalty) pair found, or null if this node is a leaf + */ + private Pair, Double> getMinPenaltyAboveLeafNode() { + if ( isLeaf() ) + // not allowed to remove leaves directly + return null; + if ( isAboveOnlyLeaves() ) + // we only consider removing nodes above all leaves + return new Pair, Double>(this, getPenalty()); + else { + // just recurse, taking the result with the min penalty of all subnodes + Pair, Double> minNode = null; + for ( final RecalDatumNode sub : subnodes ) { + final Pair, Double> subFind = sub.getMinPenaltyAboveLeafNode(); + if ( subFind != null && (minNode == null || subFind.getSecond() < minNode.getSecond()) ) { + minNode = subFind; + } } + return minNode; } - - return maxNode; } + /** + * Return a freshly allocated tree without the node nodeToRemove + * + * @param nodeToRemove the node to remove, matched by reference + * @return a pair of the resulting tree and whether a node was removed + */ private Pair, Boolean> removeNode(final RecalDatumNode nodeToRemove) { if ( this == nodeToRemove ) { if ( isLeaf() ) @@ -288,7 +396,7 @@ public class RecalDatumNode { boolean removedSomething = false; // our sub nodes with the penalty node removed - final Set> sub = new HashSet>(getNumBranches()); + final Set> sub = new HashSet>(getNumSubnodes()); for ( final RecalDatumNode sub1 : subnodes ) { if ( removedSomething ) { @@ -306,4 +414,29 @@ public class RecalDatumNode { return new Pair, Boolean>(node, removedSomething); } } + + /** + * Return a collection of all of the data in the leaf nodes of this tree + * + * @return a new list holding the recal datum of each leaf + */ + public Collection getAllLeaves() { + final LinkedList list = new LinkedList(); + getAllLeavesRec(list); + return list; + } + + /** + * Helpful recursive function for getAllLeaves() + * + * @param list the 
destination for the list of leaves + */ + private void getAllLeavesRec(final LinkedList list) { + if ( isLeaf() ) + list.add(getRecalDatum()); + else { + for ( final RecalDatumNode sub : subnodes ) + sub.getAllLeavesRec(list); + } + } } diff --git a/public/java/test/org/broadinstitute/sting/utils/recalibration/AdaptiveContextUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/recalibration/AdaptiveContextUnitTest.java new file mode 100644 index 000000000..c07c084b8 --- /dev/null +++ b/public/java/test/org/broadinstitute/sting/utils/recalibration/AdaptiveContextUnitTest.java @@ -0,0 +1,64 @@ +package org.broadinstitute.sting.utils.recalibration; + +import org.broadinstitute.sting.BaseTest; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +/** + * User: depristo + * Date: 8/3/12 + * Time: 12:26 PM + * Placeholder unit tests for AdaptiveContext; see the TODOs below. + */ +public class AdaptiveContextUnitTest { + // TODO + // TODO actually need unit tests when we have validated the value of this approach + // TODO particularly before we attempt to optimize the algorithm + // TODO + + // -------------------------------------------------------------------------------- + // + // Provider + // + // -------------------------------------------------------------------------------- + + private class AdaptiveContextTestProvider extends BaseTest.TestDataProvider { + final RecalDatumNode pruned; + final RecalDatumNode full; + + private AdaptiveContextTestProvider(Class c, RecalDatumNode pruned, RecalDatumNode full) { + super(AdaptiveContextTestProvider.class); + this.pruned = pruned; + this.full = full; + } + } + + private RecalDatumNode makeTree(final String context, final int N, final int M, + final RecalDatumNode ... 
sub) { + final ContextDatum contextDatum = new ContextDatum(context, N, M); + final RecalDatumNode node = new RecalDatumNode(contextDatum); + for ( final RecalDatumNode sub1 : sub ) { + node.addSubnode(sub1); + } + return node; + } + + @DataProvider(name = "AdaptiveContextTestProvider") + public Object[][] makeRecalDatumTestProvider() { +// final RecalDatumNode prune1 = +// makeTree("A", 10, 1, +// makeTree("AA", 11, 2), +// makeTree("AC", 12, 3), +// makeTree("AG", 13, 4), +// makeTree("AT", 14, 5)); +// +// new AdaptiveContextTestProvider(pruned, full); + + return AdaptiveContextTestProvider.getTests(AdaptiveContextTestProvider.class); + } + + @Test(dataProvider = "AdaptiveContextTestProvider") + public void testAdaptiveContextFill(AdaptiveContextTestProvider cfg) { + + } +} diff --git a/public/java/test/org/broadinstitute/sting/utils/recalibration/QualQuantizerUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/recalibration/QualQuantizerUnitTest.java new file mode 100644 index 000000000..0ff2eaf03 --- /dev/null +++ b/public/java/test/org/broadinstitute/sting/utils/recalibration/QualQuantizerUnitTest.java @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2012, The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. 
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +// our package +package org.broadinstitute.sting.utils.recalibration; + + +// the imports for unit testing. + + +import org.broadinstitute.sting.BaseTest; +import org.broadinstitute.sting.utils.QualityUtils; +import org.broadinstitute.sting.utils.Utils; +import org.testng.Assert; +import org.testng.annotations.BeforeSuite; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + + +public class QualQuantizerUnitTest extends BaseTest { + @BeforeSuite + public void before() { + + } + + // -------------------------------------------------------------------------------- + // + // merge case Provider + // + // -------------------------------------------------------------------------------- + + private class QualIntervalTestProvider extends TestDataProvider { + final QualQuantizer.QualInterval left, right; + int exError, exTotal, exQual; + double exErrorRate; + + private QualIntervalTestProvider(int leftE, int leftN, int rightE, int rightN, int exError, int exTotal) { + super(QualIntervalTestProvider.class); + + QualQuantizer qq = new QualQuantizer(0); + left = qq.new QualInterval(10, 10, leftN, leftE, 0); + right = qq.new QualInterval(11, 11, rightN, rightE, 0); + + this.exError = exError; + this.exTotal = exTotal; + this.exErrorRate = (leftE + rightE + 1) / (1.0 * (leftN + rightN + 1)); + this.exQual = QualityUtils.probToQual(1-this.exErrorRate, 0); + } + } + + 
@DataProvider(name = "QualIntervalTestProvider") + public Object[][] makeQualIntervalTestProvider() { + new QualIntervalTestProvider(10, 100, 10, 1000, 20, 1100); + new QualIntervalTestProvider(0, 100, 10, 900, 10, 1000); + new QualIntervalTestProvider(10, 900, 0, 100, 10, 1000); + new QualIntervalTestProvider(0, 0, 10, 100, 10, 100); + new QualIntervalTestProvider(1, 10, 9, 90, 10, 100); + new QualIntervalTestProvider(1, 10, 9, 100000, 10, 100010); + new QualIntervalTestProvider(1, 10, 9, 1000000, 10,1000010); + + return QualIntervalTestProvider.getTests(QualIntervalTestProvider.class); + } + + @Test(dataProvider = "QualIntervalTestProvider") + public void testQualInterval(QualIntervalTestProvider cfg) { + QualQuantizer.QualInterval merged = cfg.left.merge(cfg.right); + Assert.assertEquals(merged.nErrors, cfg.exError); + Assert.assertEquals(merged.nObservations, cfg.exTotal); + Assert.assertEquals(merged.getErrorRate(), cfg.exErrorRate); + Assert.assertEquals(merged.getQual(), cfg.exQual); + } + + @Test + public void testMinInterestingQual() { + for ( int q = 0; q < 15; q++ ) { + for ( int minQual = 0; minQual <= 10; minQual ++ ) { + QualQuantizer qq = new QualQuantizer(minQual); + QualQuantizer.QualInterval left = qq.new QualInterval(q, q, 100, 10, 0); + QualQuantizer.QualInterval right = qq.new QualInterval(q+1, q+1, 1000, 100, 0); + + QualQuantizer.QualInterval merged = left.merge(right); + boolean shouldBeFree = q+1 <= minQual; + if ( shouldBeFree ) + Assert.assertEquals(merged.getPenalty(), 0.0); + else + Assert.assertTrue(merged.getPenalty() > 0.0); + } + } + } + + + // -------------------------------------------------------------------------------- + // + // High-level case Provider + // + // -------------------------------------------------------------------------------- + + private class QuantizerTestProvider extends TestDataProvider { + final List nObservationsPerQual = new ArrayList(); + final int nLevels; + final List expectedMap; + + private 
QuantizerTestProvider(final List nObservationsPerQual, final int nLevels, final List expectedMap) { + super(QuantizerTestProvider.class); + + for ( int x : nObservationsPerQual ) + this.nObservationsPerQual.add((long)x); + this.nLevels = nLevels; + this.expectedMap = expectedMap; + } + + @Override + public String toString() { + return String.format("QQTest nLevels=%d nObs=[%s] map=[%s]", + nLevels, Utils.join(",", nObservationsPerQual), Utils.join(",", expectedMap)); + } + } + + @DataProvider(name = "QuantizerTestProvider") + public Object[][] makeQuantizerTestProvider() { + List allQ2 = Arrays.asList(0, 0, 1000, 0, 0); + + new QuantizerTestProvider(allQ2, 5, Arrays.asList(0, 1, 2, 3, 4)); + new QuantizerTestProvider(allQ2, 1, Arrays.asList(2, 2, 2, 2, 2)); + + new QuantizerTestProvider(Arrays.asList(0, 0, 1000, 0, 1000), 2, Arrays.asList(2, 2, 2, 2, 4)); + new QuantizerTestProvider(Arrays.asList(0, 0, 1000, 1, 1000), 2, Arrays.asList(2, 2, 2, 4, 4)); + new QuantizerTestProvider(Arrays.asList(0, 0, 1000, 10, 1000), 2, Arrays.asList(2, 2, 2, 2, 4)); + + return QuantizerTestProvider.getTests(QuantizerTestProvider.class); + } + + @Test(dataProvider = "QuantizerTestProvider", enabled = true) + public void testQuantizer(QuantizerTestProvider cfg) { + QualQuantizer qq = new QualQuantizer(cfg.nObservationsPerQual, cfg.nLevels, 0); + logger.warn("cfg: " + cfg); + for ( int i = 0; i < cfg.expectedMap.size(); i++) { + int expected = cfg.expectedMap.get(i); + int observed = qq.originalToQuantizedMap.get(i); + //logger.warn(String.format(" qq map: %s : %d => %d", i, expected, observed)); + Assert.assertEquals(observed, expected); + } + } +} \ No newline at end of file diff --git a/public/java/test/org/broadinstitute/sting/utils/recalibration/RecalDatumUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/recalibration/RecalDatumUnitTest.java new file mode 100644 index 000000000..33985e0ac --- /dev/null +++ 
b/public/java/test/org/broadinstitute/sting/utils/recalibration/RecalDatumUnitTest.java @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2012, The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +// our package +package org.broadinstitute.sting.utils.recalibration; + + +// the imports for unit testing. 
+ + +import org.broadinstitute.sting.BaseTest; +import org.broadinstitute.sting.utils.QualityUtils; +import org.broadinstitute.sting.utils.Utils; +import org.testng.Assert; +import org.testng.annotations.BeforeSuite; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + + +public class RecalDatumUnitTest extends BaseTest { + + // -------------------------------------------------------------------------------- + // + // merge case Provider + // + // -------------------------------------------------------------------------------- + + private class RecalDatumTestProvider extends TestDataProvider { + int exError, exTotal, reportedQual; + + private RecalDatumTestProvider(int E, int N, int reportedQual) { + super(RecalDatumTestProvider.class); + + this.exError = E; + this.exTotal = N; + this.reportedQual = reportedQual; + } + + public double getErrorRate() { + return (exError + 1) / (1.0 * (exTotal + 2)); + } + + public double getErrorRatePhredScaled() { + return QualityUtils.phredScaleErrorRate(getErrorRate()); + } + + public int getReportedQual() { + return reportedQual; + } + + public RecalDatum makeRecalDatum() { + return new RecalDatum(exTotal, exError, (byte)getReportedQual()); + } + + @Override + public String toString() { + return String.format("exError=%d, exTotal=%d, reportedQual=%d", exError, exTotal, reportedQual); + } + } + + @DataProvider(name = "RecalDatumTestProvider") + public Object[][] makeRecalDatumTestProvider() { + for ( int E : Arrays.asList(1, 10, 100, 1000, 10000) ) + for ( int N : Arrays.asList(10, 100, 1000, 10000, 100000, 1000000) ) + for ( int reportedQual : Arrays.asList(10, 20) ) + if ( E <= N ) + new RecalDatumTestProvider(E, N, reportedQual); + return RecalDatumTestProvider.getTests(RecalDatumTestProvider.class); + } + + @Test(dataProvider = "RecalDatumTestProvider") + public void testRecalDatumBasics(RecalDatumTestProvider cfg) 
{ + final RecalDatum datum = cfg.makeRecalDatum(); + assertBasicFeaturesOfRecalDatum(datum, cfg); + } + + private static void assertBasicFeaturesOfRecalDatum(final RecalDatum datum, final RecalDatumTestProvider cfg) { + Assert.assertEquals(datum.getNumMismatches(), cfg.exError); + Assert.assertEquals(datum.getNumObservations(), cfg.exTotal); + if ( cfg.getReportedQual() != -1 ) + Assert.assertEquals(datum.getEstimatedQReportedAsByte(), cfg.getReportedQual()); + BaseTest.assertEqualsDoubleSmart(datum.getEmpiricalQuality(), cfg.getErrorRatePhredScaled()); + BaseTest.assertEqualsDoubleSmart(datum.getEmpiricalErrorRate(), cfg.getErrorRate()); + } + + @Test(dataProvider = "RecalDatumTestProvider") + public void testRecalDatumCopyAndCombine(RecalDatumTestProvider cfg) { + final RecalDatum datum = cfg.makeRecalDatum(); + final RecalDatum copy = new RecalDatum(datum); + assertBasicFeaturesOfRecalDatum(copy, cfg); + + RecalDatumTestProvider combinedCfg = new RecalDatumTestProvider(cfg.exError * 2, cfg.exTotal * 2, cfg.reportedQual); + copy.combine(datum); + assertBasicFeaturesOfRecalDatum(copy, combinedCfg); + } + + @Test(dataProvider = "RecalDatumTestProvider") + public void testRecalDatumModification(RecalDatumTestProvider cfg) { + RecalDatum datum = cfg.makeRecalDatum(); + datum.setEmpiricalQuality(10.1); + Assert.assertEquals(datum.getEmpiricalQuality(), 10.1); + + datum.setEstimatedQReported(10.1); + Assert.assertEquals(datum.getEstimatedQReported(), 10.1); + Assert.assertEquals(datum.getEstimatedQReportedAsByte(), 10); + + datum = cfg.makeRecalDatum(); + cfg.exTotal = 100000; + datum.setNumObservations(cfg.exTotal); + assertBasicFeaturesOfRecalDatum(datum, cfg); + + datum = cfg.makeRecalDatum(); + cfg.exError = 1000; + datum.setNumMismatches(cfg.exError); + assertBasicFeaturesOfRecalDatum(datum, cfg); + + datum = cfg.makeRecalDatum(); + datum.increment(true); + cfg.exError++; + cfg.exTotal++; + assertBasicFeaturesOfRecalDatum(datum, cfg); + + datum = 
cfg.makeRecalDatum(); + datum.increment(10, 5); + cfg.exError += 5; + cfg.exTotal += 10; + assertBasicFeaturesOfRecalDatum(datum, cfg); + } +} \ No newline at end of file