Serialized GMM no longer depends on command line annotation order (#1632)
Order annotations by the order in the model report.
This commit is contained in:
parent
09b4cf70f6
commit
a2f45944f3
|
|
@ -110,7 +110,15 @@ public class VariantDataManager {
|
|||
return data;
|
||||
}
|
||||
|
||||
public void normalizeData(final boolean calculateMeans) {
|
||||
/**
|
||||
* Normalize annotations to mean 0 and standard deviation 1.
|
||||
* Order the variant annotations by the provided list {@code theOrder}, or by standard deviation when no list is provided.
|
||||
*
|
||||
* @param calculateMeans Boolean indicating whether or not to calculate the means
|
||||
* @param theOrder a list of integers specifying the desired annotation order. If this is null
|
||||
* annotations will get sorted in decreasing size of their standard deviations.
|
||||
*/
|
||||
public void normalizeData(final boolean calculateMeans, List<Integer> theOrder) {
|
||||
boolean foundZeroVarianceAnnotation = false;
|
||||
for( int iii = 0; iii < meanVector.length; iii++ ) {
|
||||
final double theMean, theSTD;
|
||||
|
|
@ -150,7 +158,10 @@ public class VariantDataManager {
|
|||
|
||||
// re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line
|
||||
// standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here
|
||||
final List<Integer> theOrder = calculateSortOrder(meanVector);
|
||||
// or use the serialized report's annotation order via the argument theOrder
|
||||
if (theOrder == null){
|
||||
theOrder = calculateSortOrder(meanVector);
|
||||
}
|
||||
annotationKeys = reorderList(annotationKeys, theOrder);
|
||||
varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder));
|
||||
meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder));
|
||||
|
|
@ -158,7 +169,8 @@ public class VariantDataManager {
|
|||
datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder));
|
||||
datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder));
|
||||
}
|
||||
logger.info("Annotations are now ordered by their information content: " + annotationKeys.toString());
|
||||
logger.info("Annotation order is: " + annotationKeys.toString());
|
||||
|
||||
}
|
||||
|
||||
public double[] getMeanVector() {
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@
|
|||
|
||||
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import htsjdk.variant.variantcontext.Allele;
|
||||
import org.broadinstitute.gatk.utils.commandline.*;
|
||||
import org.broadinstitute.gatk.engine.CommandLineGATK;
|
||||
|
|
@ -312,6 +313,9 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
|
||||
protected Boolean TRUST_ALL_POLYMORPHIC = false;
|
||||
|
||||
@VisibleForTesting
|
||||
protected List<Integer> annotationOrder = null;
|
||||
|
||||
/////////////////////////////
|
||||
// Private Member Variables
|
||||
/////////////////////////////
|
||||
|
|
@ -372,18 +376,15 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
|
||||
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
|
||||
final int numAnnotations = dataManager.annotationKeys.size();
|
||||
|
||||
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
|
||||
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
|
||||
}
|
||||
orderAndValidateAnnotations(anMeansTable, dataManager.annotationKeys);
|
||||
|
||||
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
|
||||
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
|
||||
dataManager.setNormalization(anMeans, anStdDevs);
|
||||
|
||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, annotationOrder.size());
|
||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, annotationOrder.size());
|
||||
}
|
||||
|
||||
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
||||
|
|
@ -401,6 +402,32 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Order and validate annotations according to the annotations in the serialized model
|
||||
* Annotations on the command line must be the same as those in the model report or this will throw an exception.
|
||||
* Sets the {@code annotationOrder} list to map from command line order to the model report's order.
|
||||
* n^2 because we typically use 7 or less annotations.
|
||||
* @param annotationTable GATKReportTable of annotations read from the serialized model file
|
||||
*/
|
||||
protected void orderAndValidateAnnotations(final GATKReportTable annotationTable, final List<String> annotationKeys){
|
||||
annotationOrder = new ArrayList<Integer>(annotationKeys.size());
|
||||
|
||||
for (int i = 0; i < annotationTable.getNumRows(); i++){
|
||||
String serialAnno = (String)annotationTable.get(i, "Annotation");
|
||||
for (int j = 0; j < annotationKeys.size(); j++) {
|
||||
if (serialAnno.equals( annotationKeys.get(j) )){
|
||||
annotationOrder.add(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(annotationOrder.size() != annotationTable.getNumRows() || annotationOrder.size() != annotationKeys.size()) {
|
||||
final String errorMsg = "Annotations specified on the command line:"+annotationKeys.toString() +" do not match annotations in the model report:"+inputModel;
|
||||
throw new UserException.CommandLineException(errorMsg);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
//---------------------------------------------------------------------------------------------------------------
|
||||
//
|
||||
|
|
@ -518,7 +545,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
for (int i = 1; i <= max_attempts; i++) {
|
||||
try {
|
||||
dataManager.setData(reduceSum);
|
||||
dataManager.normalizeData(inputModel.isEmpty()); // Each data point is now (x - mean) / standard deviation
|
||||
dataManager.normalizeData(inputModel.isEmpty(), annotationOrder); // Each data point is now (x - mean) / standard deviation
|
||||
|
||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||
final List<VariantDatum> negativeTrainingData;
|
||||
|
|
|
|||
|
|
@ -51,6 +51,8 @@
|
|||
|
||||
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
|
||||
|
||||
import org.broadinstitute.gatk.utils.exceptions.UserException;
|
||||
import org.broadinstitute.gatk.utils.exceptions.UserException.CommandLineException;
|
||||
import org.broadinstitute.gatk.utils.variant.VCIterable;
|
||||
import org.broadinstitute.gatk.engine.walkers.WalkerTest;
|
||||
import htsjdk.variant.variantcontext.VariantContext;
|
||||
|
|
@ -60,6 +62,7 @@ import org.testng.annotations.DataProvider;
|
|||
import org.testng.annotations.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
|
|
@ -390,5 +393,73 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
new File(outputFile.getAbsolutePath() + ".pdf").deleteOnExit();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVQSRAnnotationOrder() throws IOException {
|
||||
final String inputFile = privateTestDir + "oneSNP.vcf";
|
||||
final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report";
|
||||
final String annoOrderRecal = privateTestDir + "anno_order.recal";
|
||||
final String annoOrderTranches = privateTestDir + "anno_order.tranches";
|
||||
final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e";
|
||||
final String base = "-R " + b37KGReference +
|
||||
" -T VariantRecalibrator" +
|
||||
" -input " + inputFile +
|
||||
" -L 1:110201699" +
|
||||
" -resource:truth=true,training=true,prior=15.0 " + inputFile +
|
||||
" -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR"+
|
||||
" --recal_file " + annoOrderRecal +
|
||||
" -tranchesFile " + annoOrderTranches +
|
||||
" --input_model " + exacModelReportFilename +
|
||||
" -ignoreAllFilters -mode SNP" +
|
||||
" --no_cmdline_in_header" ;
|
||||
|
||||
final WalkerTestSpec spec = new WalkerTestSpec(base, 1, Arrays.asList(goodMd5));
|
||||
spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
|
||||
|
||||
List<File> outputFiles = executeTest("testVQSRAnnotationOrder", spec).getFirst();
|
||||
setPDFsForDeletion(outputFiles);
|
||||
|
||||
|
||||
final String base2 = "-R " + b37KGReference +
|
||||
" -T VariantRecalibrator" +
|
||||
" -input " + inputFile +
|
||||
" -L 1:110201699" +
|
||||
" -resource:truth=true,training=true,prior=15.0 " + inputFile +
|
||||
" -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an FS "+
|
||||
" --recal_file " + annoOrderRecal +
|
||||
" -tranchesFile " + annoOrderTranches +
|
||||
" --input_model " + exacModelReportFilename +
|
||||
" -ignoreAllFilters -mode SNP" +
|
||||
" --no_cmdline_in_header" ;
|
||||
|
||||
final WalkerTestSpec spec2 = new WalkerTestSpec(base2, 1, Arrays.asList(goodMd5));
|
||||
spec2.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
|
||||
outputFiles = executeTest("testVQSRAnnotationOrder2", spec2).getFirst();
|
||||
setPDFsForDeletion(outputFiles);
|
||||
}
|
||||
|
||||
@Test(expectedExceptions={RuntimeException.class, CommandLineException.class})
|
||||
public void testVQSRAnnotationMismatch() throws IOException {
|
||||
final String inputFile = privateTestDir + "oneSNP.vcf";
|
||||
final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report";
|
||||
final String annoOrderRecal = privateTestDir + "anno_order.recal";
|
||||
final String annoOrderTranches = privateTestDir + "anno_order.tranches";
|
||||
final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e";
|
||||
final String base = "-R " + b37KGReference +
|
||||
" -T VariantRecalibrator" +
|
||||
" -input " + inputFile +
|
||||
" -L 1:110201699" +
|
||||
" -resource:truth=true,training=true,prior=15.0 " + inputFile +
|
||||
" -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an BaseQRankSum"+
|
||||
" --recal_file " + annoOrderRecal +
|
||||
" -tranchesFile " + annoOrderTranches +
|
||||
" --input_model " + exacModelReportFilename +
|
||||
" -ignoreAllFilters -mode SNP" +
|
||||
" --no_cmdline_in_header" ;
|
||||
|
||||
final WalkerTestSpec spec = new WalkerTestSpec(base, 1, Arrays.asList(goodMd5));
|
||||
spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
|
||||
executeTest("testVQSRAnnotationMismatch", spec).getFirst();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -276,4 +276,52 @@ public class VariantRecalibratorModelOutputUnitTest extends BaseTest {
|
|||
return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAnnotationOrderAndValidate() {
|
||||
final VariantRecalibrator vqsr = new VariantRecalibrator();
|
||||
final List<String> annotationList = new ArrayList<>();
|
||||
annotationList.add("QD");
|
||||
annotationList.add("FS");
|
||||
annotationList.add("ReadPosRankSum");
|
||||
annotationList.add("MQ");
|
||||
annotationList.add("MQRankSum");
|
||||
annotationList.add("SOR");
|
||||
|
||||
double[] meanVector = {16.13, 2.45, 0.37, 59.08, 0.14, 0.91};
|
||||
final String columnName = "Mean";
|
||||
final String formatString = "%.3f";
|
||||
GATKReportTable annotationTable = vqsr.makeVectorTable("AnnotationMeans", "Mean for each annotation, used to normalize data", annotationList, meanVector, columnName, formatString);
|
||||
vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
|
||||
|
||||
for (int i = 0; i < vqsr.annotationOrder.size(); i++){
|
||||
Assert.assertEquals(i, (int)vqsr.annotationOrder.get(i));
|
||||
}
|
||||
|
||||
annotationList.remove(0);
|
||||
annotationList.add("QD");
|
||||
vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
|
||||
for (int i = 0; i < vqsr.annotationOrder.size(); i++) {
|
||||
if (i == 0) {
|
||||
Assert.assertEquals(annotationList.size()-1, (int)vqsr.annotationOrder.get(i));
|
||||
} else {
|
||||
Assert.assertEquals(i - 1, (int)vqsr.annotationOrder.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
final List<String> annotationList2 = new ArrayList<>();
|
||||
annotationList2.add("ReadPosRankSum");
|
||||
annotationList2.add("MQRankSum");
|
||||
annotationList2.add("MQ");
|
||||
annotationList2.add("SOR");
|
||||
annotationList2.add("QD");
|
||||
annotationList2.add("FS");
|
||||
|
||||
final VariantRecalibrator vqsr2 = new VariantRecalibrator();
|
||||
vqsr2.orderAndValidateAnnotations(annotationTable, annotationList2);
|
||||
for (int i = 0; i < vqsr2.annotationOrder.size(); i++){
|
||||
Assert.assertEquals(annotationList.get(vqsr.annotationOrder.get(i)), annotationList2.get(vqsr2.annotationOrder.get(i)));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Reference in New Issue