Serialized GMM no longer depends on command line annotation order (#1632)
Order annotations by the order in the model report.
This commit is contained in:
parent
09b4cf70f6
commit
a2f45944f3
|
|
@ -110,7 +110,15 @@ public class VariantDataManager {
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void normalizeData(final boolean calculateMeans) {
|
/**
|
||||||
|
* Normalize annotations to mean 0 and standard deviation 1.
|
||||||
|
* Order the variant annotations by the provided list {@code theOrder} or standard deviation.
|
||||||
|
*
|
||||||
|
* @param calculateMeans Boolean indicating whether or not to calculate the means
|
||||||
|
* @param theOrder a list of integers specifying the desired annotation order. If this is null
|
||||||
|
* annotations will get sorted in decreasing size of their standard deviations.
|
||||||
|
*/
|
||||||
|
public void normalizeData(final boolean calculateMeans, List<Integer> theOrder) {
|
||||||
boolean foundZeroVarianceAnnotation = false;
|
boolean foundZeroVarianceAnnotation = false;
|
||||||
for( int iii = 0; iii < meanVector.length; iii++ ) {
|
for( int iii = 0; iii < meanVector.length; iii++ ) {
|
||||||
final double theMean, theSTD;
|
final double theMean, theSTD;
|
||||||
|
|
@ -150,7 +158,10 @@ public class VariantDataManager {
|
||||||
|
|
||||||
// re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line
|
// re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line
|
||||||
// standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here
|
// standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here
|
||||||
final List<Integer> theOrder = calculateSortOrder(meanVector);
|
// or use the serialized report's annotation order via the argument theOrder
|
||||||
|
if (theOrder == null){
|
||||||
|
theOrder = calculateSortOrder(meanVector);
|
||||||
|
}
|
||||||
annotationKeys = reorderList(annotationKeys, theOrder);
|
annotationKeys = reorderList(annotationKeys, theOrder);
|
||||||
varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder));
|
varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder));
|
||||||
meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder));
|
meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder));
|
||||||
|
|
@ -158,7 +169,8 @@ public class VariantDataManager {
|
||||||
datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder));
|
datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder));
|
||||||
datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder));
|
datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder));
|
||||||
}
|
}
|
||||||
logger.info("Annotations are now ordered by their information content: " + annotationKeys.toString());
|
logger.info("Annotation order is: " + annotationKeys.toString());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[] getMeanVector() {
|
public double[] getMeanVector() {
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@
|
||||||
|
|
||||||
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
|
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
|
||||||
|
|
||||||
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import htsjdk.variant.variantcontext.Allele;
|
import htsjdk.variant.variantcontext.Allele;
|
||||||
import org.broadinstitute.gatk.utils.commandline.*;
|
import org.broadinstitute.gatk.utils.commandline.*;
|
||||||
import org.broadinstitute.gatk.engine.CommandLineGATK;
|
import org.broadinstitute.gatk.engine.CommandLineGATK;
|
||||||
|
|
@ -312,6 +313,9 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
|
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
|
||||||
protected Boolean TRUST_ALL_POLYMORPHIC = false;
|
protected Boolean TRUST_ALL_POLYMORPHIC = false;
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
protected List<Integer> annotationOrder = null;
|
||||||
|
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
// Private Member Variables
|
// Private Member Variables
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
@ -372,18 +376,15 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||||
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
|
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
|
||||||
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
|
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
|
||||||
final int numAnnotations = dataManager.annotationKeys.size();
|
|
||||||
|
|
||||||
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
|
orderAndValidateAnnotations(anMeansTable, dataManager.annotationKeys);
|
||||||
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
|
|
||||||
}
|
|
||||||
|
|
||||||
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
|
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
|
||||||
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
|
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
|
||||||
dataManager.setNormalization(anMeans, anStdDevs);
|
dataManager.setNormalization(anMeans, anStdDevs);
|
||||||
|
|
||||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, annotationOrder.size());
|
||||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, annotationOrder.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
||||||
|
|
@ -401,6 +402,32 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Order and validate annotations according to the annotations in the serialized model
|
||||||
|
* Annotations on the command line must be the same as those in the model report or this will throw an exception.
|
||||||
|
* Sets the {@code annotationOrder} list to map from command line order to the model report's order.
|
||||||
|
* n^2 because we typically use 7 or less annotations.
|
||||||
|
* @param annotationTable GATKReportTable of annotations read from the serialized model file
|
||||||
|
*/
|
||||||
|
protected void orderAndValidateAnnotations(final GATKReportTable annotationTable, final List<String> annotationKeys){
|
||||||
|
annotationOrder = new ArrayList<Integer>(annotationKeys.size());
|
||||||
|
|
||||||
|
for (int i = 0; i < annotationTable.getNumRows(); i++){
|
||||||
|
String serialAnno = (String)annotationTable.get(i, "Annotation");
|
||||||
|
for (int j = 0; j < annotationKeys.size(); j++) {
|
||||||
|
if (serialAnno.equals( annotationKeys.get(j) )){
|
||||||
|
annotationOrder.add(j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(annotationOrder.size() != annotationTable.getNumRows() || annotationOrder.size() != annotationKeys.size()) {
|
||||||
|
final String errorMsg = "Annotations specified on the command line:"+annotationKeys.toString() +" do not match annotations in the model report:"+inputModel;
|
||||||
|
throw new UserException.CommandLineException(errorMsg);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
//---------------------------------------------------------------------------------------------------------------
|
||||||
//
|
//
|
||||||
|
|
@ -518,7 +545,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
for (int i = 1; i <= max_attempts; i++) {
|
for (int i = 1; i <= max_attempts; i++) {
|
||||||
try {
|
try {
|
||||||
dataManager.setData(reduceSum);
|
dataManager.setData(reduceSum);
|
||||||
dataManager.normalizeData(inputModel.isEmpty()); // Each data point is now (x - mean) / standard deviation
|
dataManager.normalizeData(inputModel.isEmpty(), annotationOrder); // Each data point is now (x - mean) / standard deviation
|
||||||
|
|
||||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||||
final List<VariantDatum> negativeTrainingData;
|
final List<VariantDatum> negativeTrainingData;
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,8 @@
|
||||||
|
|
||||||
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
|
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
|
||||||
|
|
||||||
|
import org.broadinstitute.gatk.utils.exceptions.UserException;
|
||||||
|
import org.broadinstitute.gatk.utils.exceptions.UserException.CommandLineException;
|
||||||
import org.broadinstitute.gatk.utils.variant.VCIterable;
|
import org.broadinstitute.gatk.utils.variant.VCIterable;
|
||||||
import org.broadinstitute.gatk.engine.walkers.WalkerTest;
|
import org.broadinstitute.gatk.engine.walkers.WalkerTest;
|
||||||
import htsjdk.variant.variantcontext.VariantContext;
|
import htsjdk.variant.variantcontext.VariantContext;
|
||||||
|
|
@ -60,6 +62,7 @@ import org.testng.annotations.DataProvider;
|
||||||
import org.testng.annotations.Test;
|
import org.testng.annotations.Test;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
|
@ -390,5 +393,73 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
||||||
new File(outputFile.getAbsolutePath() + ".pdf").deleteOnExit();
|
new File(outputFile.getAbsolutePath() + ".pdf").deleteOnExit();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVQSRAnnotationOrder() throws IOException {
|
||||||
|
final String inputFile = privateTestDir + "oneSNP.vcf";
|
||||||
|
final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report";
|
||||||
|
final String annoOrderRecal = privateTestDir + "anno_order.recal";
|
||||||
|
final String annoOrderTranches = privateTestDir + "anno_order.tranches";
|
||||||
|
final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e";
|
||||||
|
final String base = "-R " + b37KGReference +
|
||||||
|
" -T VariantRecalibrator" +
|
||||||
|
" -input " + inputFile +
|
||||||
|
" -L 1:110201699" +
|
||||||
|
" -resource:truth=true,training=true,prior=15.0 " + inputFile +
|
||||||
|
" -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR"+
|
||||||
|
" --recal_file " + annoOrderRecal +
|
||||||
|
" -tranchesFile " + annoOrderTranches +
|
||||||
|
" --input_model " + exacModelReportFilename +
|
||||||
|
" -ignoreAllFilters -mode SNP" +
|
||||||
|
" --no_cmdline_in_header" ;
|
||||||
|
|
||||||
|
final WalkerTestSpec spec = new WalkerTestSpec(base, 1, Arrays.asList(goodMd5));
|
||||||
|
spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
|
||||||
|
|
||||||
|
List<File> outputFiles = executeTest("testVQSRAnnotationOrder", spec).getFirst();
|
||||||
|
setPDFsForDeletion(outputFiles);
|
||||||
|
|
||||||
|
|
||||||
|
final String base2 = "-R " + b37KGReference +
|
||||||
|
" -T VariantRecalibrator" +
|
||||||
|
" -input " + inputFile +
|
||||||
|
" -L 1:110201699" +
|
||||||
|
" -resource:truth=true,training=true,prior=15.0 " + inputFile +
|
||||||
|
" -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an FS "+
|
||||||
|
" --recal_file " + annoOrderRecal +
|
||||||
|
" -tranchesFile " + annoOrderTranches +
|
||||||
|
" --input_model " + exacModelReportFilename +
|
||||||
|
" -ignoreAllFilters -mode SNP" +
|
||||||
|
" --no_cmdline_in_header" ;
|
||||||
|
|
||||||
|
final WalkerTestSpec spec2 = new WalkerTestSpec(base2, 1, Arrays.asList(goodMd5));
|
||||||
|
spec2.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
|
||||||
|
outputFiles = executeTest("testVQSRAnnotationOrder2", spec2).getFirst();
|
||||||
|
setPDFsForDeletion(outputFiles);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expectedExceptions={RuntimeException.class, CommandLineException.class})
|
||||||
|
public void testVQSRAnnotationMismatch() throws IOException {
|
||||||
|
final String inputFile = privateTestDir + "oneSNP.vcf";
|
||||||
|
final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report";
|
||||||
|
final String annoOrderRecal = privateTestDir + "anno_order.recal";
|
||||||
|
final String annoOrderTranches = privateTestDir + "anno_order.tranches";
|
||||||
|
final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e";
|
||||||
|
final String base = "-R " + b37KGReference +
|
||||||
|
" -T VariantRecalibrator" +
|
||||||
|
" -input " + inputFile +
|
||||||
|
" -L 1:110201699" +
|
||||||
|
" -resource:truth=true,training=true,prior=15.0 " + inputFile +
|
||||||
|
" -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an BaseQRankSum"+
|
||||||
|
" --recal_file " + annoOrderRecal +
|
||||||
|
" -tranchesFile " + annoOrderTranches +
|
||||||
|
" --input_model " + exacModelReportFilename +
|
||||||
|
" -ignoreAllFilters -mode SNP" +
|
||||||
|
" --no_cmdline_in_header" ;
|
||||||
|
|
||||||
|
final WalkerTestSpec spec = new WalkerTestSpec(base, 1, Arrays.asList(goodMd5));
|
||||||
|
spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
|
||||||
|
executeTest("testVQSRAnnotationMismatch", spec).getFirst();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -276,4 +276,52 @@ public class VariantRecalibratorModelOutputUnitTest extends BaseTest {
|
||||||
return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
|
return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testAnnotationOrderAndValidate() {
|
||||||
|
final VariantRecalibrator vqsr = new VariantRecalibrator();
|
||||||
|
final List<String> annotationList = new ArrayList<>();
|
||||||
|
annotationList.add("QD");
|
||||||
|
annotationList.add("FS");
|
||||||
|
annotationList.add("ReadPosRankSum");
|
||||||
|
annotationList.add("MQ");
|
||||||
|
annotationList.add("MQRankSum");
|
||||||
|
annotationList.add("SOR");
|
||||||
|
|
||||||
|
double[] meanVector = {16.13, 2.45, 0.37, 59.08, 0.14, 0.91};
|
||||||
|
final String columnName = "Mean";
|
||||||
|
final String formatString = "%.3f";
|
||||||
|
GATKReportTable annotationTable = vqsr.makeVectorTable("AnnotationMeans", "Mean for each annotation, used to normalize data", annotationList, meanVector, columnName, formatString);
|
||||||
|
vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
|
||||||
|
|
||||||
|
for (int i = 0; i < vqsr.annotationOrder.size(); i++){
|
||||||
|
Assert.assertEquals(i, (int)vqsr.annotationOrder.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
annotationList.remove(0);
|
||||||
|
annotationList.add("QD");
|
||||||
|
vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
|
||||||
|
for (int i = 0; i < vqsr.annotationOrder.size(); i++) {
|
||||||
|
if (i == 0) {
|
||||||
|
Assert.assertEquals(annotationList.size()-1, (int)vqsr.annotationOrder.get(i));
|
||||||
|
} else {
|
||||||
|
Assert.assertEquals(i - 1, (int)vqsr.annotationOrder.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
final List<String> annotationList2 = new ArrayList<>();
|
||||||
|
annotationList2.add("ReadPosRankSum");
|
||||||
|
annotationList2.add("MQRankSum");
|
||||||
|
annotationList2.add("MQ");
|
||||||
|
annotationList2.add("SOR");
|
||||||
|
annotationList2.add("QD");
|
||||||
|
annotationList2.add("FS");
|
||||||
|
|
||||||
|
final VariantRecalibrator vqsr2 = new VariantRecalibrator();
|
||||||
|
vqsr2.orderAndValidateAnnotations(annotationTable, annotationList2);
|
||||||
|
for (int i = 0; i < vqsr2.annotationOrder.size(); i++){
|
||||||
|
Assert.assertEquals(annotationList.get(vqsr.annotationOrder.get(i)), annotationList2.get(vqsr2.annotationOrder.get(i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
Loading…
Reference in New Issue