From 7e77875c810ac1bcb1a2ad1d6526333490498166 Mon Sep 17 00:00:00 2001 From: Phillip Dexheimer Date: Wed, 6 Aug 2014 22:17:22 -0400 Subject: [PATCH] Improvements to read-group filtering in PrintReads - Read groups that are excluded by sample_name, platform, or read_group arguments no longer appear in the header - The performance penalty associated with filtering by read group has been essentially eliminated - Partial fulfillment of PT 73075482 --- .../tools/walkers/readutils/PrintReads.java | 117 ++++++++++++------ .../readutils/PrintReadsIntegrationTest.java | 81 +++++++++--- 2 files changed, 145 insertions(+), 53 deletions(-) diff --git a/public/gatk-tools-public/src/main/java/org/broadinstitute/gatk/tools/walkers/readutils/PrintReads.java b/public/gatk-tools-public/src/main/java/org/broadinstitute/gatk/tools/walkers/readutils/PrintReads.java index f8d8b529b..f271fe900 100644 --- a/public/gatk-tools-public/src/main/java/org/broadinstitute/gatk/tools/walkers/readutils/PrintReads.java +++ b/public/gatk-tools-public/src/main/java/org/broadinstitute/gatk/tools/walkers/readutils/PrintReads.java @@ -25,6 +25,7 @@ package org.broadinstitute.gatk.tools.walkers.readutils; +import htsjdk.samtools.SAMFileHeader; import htsjdk.samtools.SAMFileWriter; import htsjdk.samtools.SAMReadGroupRecord; import org.broadinstitute.gatk.engine.walkers.*; @@ -41,6 +42,7 @@ import org.broadinstitute.gatk.engine.refdata.RefMetaDataTracker; import org.broadinstitute.gatk.utils.SampleUtils; import org.broadinstitute.gatk.utils.Utils; import org.broadinstitute.gatk.utils.baq.BAQ; +import org.broadinstitute.gatk.utils.exceptions.UserException; import org.broadinstitute.gatk.utils.help.DocumentedGATKFeature; import org.broadinstitute.gatk.utils.help.HelpConstants; import org.broadinstitute.gatk.utils.sam.GATKSAMRecord; @@ -128,13 +130,13 @@ public class PrintReads extends ReadWalker impleme * Only reads from samples listed in the provided file(s) will be included in the output. */ @Argument(fullName="sample_file", shortName="sf", doc="File containing a list of samples (one per line). Can be specified multiple times", required=false) - public Set sampleFile = new TreeSet(); + public Set sampleFile = new TreeSet<>(); /** * Only reads from the sample(s) will be included in the output. */ @Argument(fullName="sample_name", shortName="sn", doc="Sample name to be included in the analysis. Can be specified multiple times.", required=false) - public Set sampleNames = new TreeSet(); + public Set sampleNames = new TreeSet<>(); /** * Erase all extra attributes in the read but keep the read group information @@ -147,8 +149,7 @@ public class PrintReads extends ReadWalker impleme public boolean NO_PG_TAG = false; List readTransformers = Collections.emptyList(); - private TreeSet samplesToChoose = new TreeSet(); - private boolean SAMPLES_SPECIFIED = false; + private Set readGroupsToKeep = Collections.emptySet(); public static final String PROGRAM_RECORD_NAME = "GATK PrintReads"; // The name that will go in the @PG tag @@ -161,12 +162,16 @@ public class PrintReads extends ReadWalker impleme public void initialize() { final GenomeAnalysisEngine toolkit = getToolkit(); - if ( platform != null ) - platform = platform.toUpperCase(); - - if ( getToolkit() != null ) - readTransformers = getToolkit().getReadTransformers(); + if ( toolkit != null ) + readTransformers = toolkit.getReadTransformers(); + //Sample names are case-insensitive + final TreeSet samplesToChoose = new TreeSet<>(new Comparator() { + @Override + public int compare(String a, String b) { + return a.compareToIgnoreCase(b); + } + }); Collection samplesFromFile; if (!sampleFile.isEmpty()) { samplesFromFile = SampleUtils.getSamplesFromFiles(sampleFile); @@ -176,15 +181,24 @@ public class PrintReads extends ReadWalker impleme if (!sampleNames.isEmpty()) samplesToChoose.addAll(sampleNames); - if(!samplesToChoose.isEmpty()) { - SAMPLES_SPECIFIED = true; - } - random = GenomeAnalysisEngine.getRandomGenerator(); - final boolean preSorted = true; - if (getToolkit() != null && getToolkit().getArguments().BQSR_RECAL_FILE != null && !NO_PG_TAG ) { - Utils.setupWriter(out, toolkit, toolkit.getSAMFileHeader(), preSorted, this, PROGRAM_RECORD_NAME); + if (toolkit != null) { + final SAMFileHeader outputHeader = toolkit.getSAMFileHeader().clone(); + readGroupsToKeep = determineReadGroupsOfInterest(outputHeader, samplesToChoose); + + //If some read groups are to be excluded, remove them from the output header + pruneReadGroups(outputHeader); + + //Add the program record (if appropriate) and set up the writer + final boolean preSorted = true; + if (toolkit.getArguments().BQSR_RECAL_FILE != null && !NO_PG_TAG ) { + Utils.setupWriter(out, toolkit, outputHeader, preSorted, this, PROGRAM_RECORD_NAME); + } else { + out.writeHeader(outputHeader); + out.setPresorted(preSorted); + } + } } @@ -197,32 +211,13 @@ public class PrintReads extends ReadWalker impleme * @return true if the read passes the filter, false if it doesn't */ public boolean filter(ReferenceContext ref, GATKSAMRecord read) { - // check the read group - if ( readGroup != null ) { - SAMReadGroupRecord myReadGroup = read.getReadGroup(); - if ( myReadGroup == null || !readGroup.equals(myReadGroup.getReadGroupId()) ) + // check that the read belongs to an RG that we need to keep + if (!readGroupsToKeep.isEmpty()) { + final SAMReadGroupRecord readGroup = read.getReadGroup(); + if (!readGroupsToKeep.contains(readGroup.getReadGroupId())) return false; } - // check the platform - if ( platform != null ) { - SAMReadGroupRecord readGroup = read.getReadGroup(); - if ( readGroup == null ) - return false; - - Object readPlatformAttr = readGroup.getAttribute("PL"); - if ( readPlatformAttr == null || !readPlatformAttr.toString().toUpperCase().contains(platform)) - return false; - } - if (SAMPLES_SPECIFIED ) { - // user specified samples to select - // todo - should be case-agnostic but for simplicity and speed this is ignored. - // todo - can check at initialization intersection of requested samples and samples in BAM header to further speedup. - if (!samplesToChoose.contains(read.getReadGroup().getSample())) - return false; - } - - // check if we've reached the output limit if ( nReadsToPrint == 0 ) { return false; // n == 0 means we've printed all we needed. @@ -274,4 +269,48 @@ public class PrintReads extends ReadWalker impleme output.addAlignment(read); return output; } + + /** + * Determines the list of read groups that meet the user's criteria for inclusion (based on id, platform, or sample) + * @param header the merged header for all input files + * @param samplesToKeep the list of specific samples specified by the user + * @return a Set of read group IDs that meet the user's criteria, empty if all RGs should be included + */ + private Set determineReadGroupsOfInterest(final SAMFileHeader header, final Set samplesToKeep) { + //If no filter options that use read group information have been supplied, exit early + if (platform == null && readGroup == null && samplesToKeep.isEmpty()) + return Collections.emptySet(); + + if ( platform != null ) + platform = platform.toUpperCase(); + + final Set result = new HashSet<>(); + for (final SAMReadGroupRecord rg : header.getReadGroups()) { + // To be eligible for output, a read group must: + // NOT have an id that is blacklisted on the command line (note that String.equals(null) is false) + // AND NOT have a platform that contains the blacklisted platform from the command line + // AND have a sample that is whitelisted on the command line + if (!rg.getReadGroupId().equals(readGroup) && + (platform == null || !rg.getPlatform().toUpperCase().contains(platform)) && + (samplesToKeep.isEmpty() || samplesToKeep.contains(rg.getSample()))) + result.add(rg.getReadGroupId()); + } + + if (result.isEmpty()) + throw new UserException.BadArgumentValue("-sn/-sf/-platform/-readGroup", "No read groups remain after pruning based on the supplied parameters"); + + return result; + } + + private void pruneReadGroups(final SAMFileHeader header) { + if (readGroupsToKeep.isEmpty()) + return; + + final List readGroups = new ArrayList<>(); + for (final SAMReadGroupRecord rg : header.getReadGroups()) { + if (readGroupsToKeep.contains(rg.getReadGroupId())) + readGroups.add(rg); + } + header.setReadGroups(readGroups); + } } diff --git a/public/gatk-tools-public/src/test/java/org/broadinstitute/gatk/tools/walkers/readutils/PrintReadsIntegrationTest.java b/public/gatk-tools-public/src/test/java/org/broadinstitute/gatk/tools/walkers/readutils/PrintReadsIntegrationTest.java index 5f8be81a1..021394033 100644 --- a/public/gatk-tools-public/src/test/java/org/broadinstitute/gatk/tools/walkers/readutils/PrintReadsIntegrationTest.java +++ b/public/gatk-tools-public/src/test/java/org/broadinstitute/gatk/tools/walkers/readutils/PrintReadsIntegrationTest.java @@ -26,23 +26,27 @@ package org.broadinstitute.gatk.tools.walkers.readutils; import org.broadinstitute.gatk.engine.walkers.WalkerTest; +import org.broadinstitute.gatk.utils.exceptions.UserException; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; public class PrintReadsIntegrationTest extends WalkerTest { private static class PRTest { final String reference; - final String bam; + final List bam; final String args; final String md5; - private PRTest(String reference, String bam, String args, String md5) { + private PRTest(String reference, String[] bam, String args, String md5) { this.reference = reference; - this.bam = bam; + this.bam = new ArrayList<>(); this.args = args; this.md5 = md5; + this.bam.addAll(Arrays.asList(bam)); } @Override @@ -54,30 +58,79 @@ public class PrintReadsIntegrationTest extends WalkerTest { @DataProvider(name = "PRTest") public Object[][] createPrintReadsTestData() { return new Object[][]{ - {new PRTest(hg18Reference, "HiSeq.1mb.bam", "", "fa9c66f66299fe5405512ac36ec9d0f2")}, - {new PRTest(hg18Reference, "HiSeq.1mb.bam", " -compress 0", "488eb22abc31c6af7cbb1a3d41da1507")}, - {new PRTest(hg18Reference, "HiSeq.1mb.bam", " -simplifyBAM", "1510dc4429f3ed49caf96da41e8ed396")}, - {new PRTest(hg18Reference, "HiSeq.1mb.bam", " -n 10", "0e3d1748ad1cb523e3295cab9d09d8fc")}, + {new PRTest(hg18Reference, new String[]{"HiSeq.1mb.bam"}, "", "fa9c66f66299fe5405512ac36ec9d0f2")}, + {new PRTest(hg18Reference, new String[]{"HiSeq.1mb.bam"}, " -compress 0", "488eb22abc31c6af7cbb1a3d41da1507")}, + {new PRTest(hg18Reference, new String[]{"HiSeq.1mb.bam"}, " -simplifyBAM", "1510dc4429f3ed49caf96da41e8ed396")}, + {new PRTest(hg18Reference, new String[]{"HiSeq.1mb.bam"}, " -n 10", "0e3d1748ad1cb523e3295cab9d09d8fc")}, // See: GATKBAMIndex.getStartOfLastLinearBin(), BAMScheduler.advance(), IntervalOverlapFilteringIterator.advance() - {new PRTest(b37KGReference, "unmappedFlagReadsInLastLinearBin.bam", "", "d7f23fd77d7dc7cb50d3397f644c6d8a")}, - {new PRTest(b37KGReference, "unmappedFlagReadsInLastLinearBin.bam", " -L 1", "c601db95b20248d012b0085347fcb6d1")}, - {new PRTest(b37KGReference, "unmappedFlagReadsInLastLinearBin.bam", " -L unmapped", "2d32440e47e8d9d329902fe573ad94ce")}, - {new PRTest(b37KGReference, "unmappedFlagReadsInLastLinearBin.bam", " -L 1 -L unmapped", "c601db95b20248d012b0085347fcb6d1")}, - {new PRTest(b37KGReference, "oneReadAllInsertion.bam", "", "349650b6aa9e574b48a2a62627f37c7d")}, - {new PRTest(b37KGReference, "NA12878.1_10mb_2_10mb.bam", "", "0c1cbe67296637a85e80e7a182f828ab")} + {new PRTest(b37KGReference, new String[]{"unmappedFlagReadsInLastLinearBin.bam"}, "", "d7f23fd77d7dc7cb50d3397f644c6d8a")}, + {new PRTest(b37KGReference, new String[]{"unmappedFlagReadsInLastLinearBin.bam"}, " -L 1", "c601db95b20248d012b0085347fcb6d1")}, + {new PRTest(b37KGReference, new String[]{"unmappedFlagReadsInLastLinearBin.bam"}, " -L unmapped", "2d32440e47e8d9d329902fe573ad94ce")}, + {new PRTest(b37KGReference, new String[]{"unmappedFlagReadsInLastLinearBin.bam"}, " -L 1 -L unmapped", "c601db95b20248d012b0085347fcb6d1")}, + {new PRTest(b37KGReference, new String[]{"oneReadAllInsertion.bam"}, "", "349650b6aa9e574b48a2a62627f37c7d")}, + {new PRTest(b37KGReference, new String[]{"NA12878.1_10mb_2_10mb.bam"}, "", "0c1cbe67296637a85e80e7a182f828ab")}, + // Tests for filtering options + {new PRTest(b37KGReference, new String[]{"NA12878.1_10mb_2_10mb.bam", "NA20313.highCoverageRegion.bam"}, + "", "b3ae15c8af33fd5badc1a29e089bdaac")}, + {new PRTest(b37KGReference, new String[]{"NA12878.1_10mb_2_10mb.bam", "NA20313.highCoverageRegion.bam"}, + " -readGroup SRR359098", "8bd867b30539524daa7181efd9835a8f")}, + {new PRTest(b37KGReference, new String[]{"NA12878.1_10mb_2_10mb.bam", "NA20313.highCoverageRegion.bam"}, + " -readGroup 20FUK.3 -sn NA12878", "93a7bc1b2b1cd27815ed1666cbb4d0cb")}, + {new PRTest(b37KGReference, new String[]{"NA12878.1_10mb_2_10mb.bam", "NA20313.highCoverageRegion.bam"}, + " -sn na12878", "52e99cfcf03ff46285d1ba302f8df964")}, }; } @Test(dataProvider = "PRTest") public void testPrintReads(PRTest params) { + + StringBuilder inputs = new StringBuilder(); + for (String bam : params.bam) { + inputs.append(" -I "); + inputs.append(privateTestDir); + inputs.append(bam); + } + WalkerTestSpec spec = new WalkerTestSpec( "-T PrintReads" + " -R " + params.reference + - " -I " + privateTestDir + params.bam + + inputs.toString() + params.args + " --no_pg_tag" + " -o %s", Arrays.asList(params.md5)); executeTest("testPrintReads-"+params.args, spec).getFirst(); } + + @DataProvider(name = "PRExceptionTest") + public Object[][] createPrintReadsExceptionTestData() { + return new Object[][]{ + {new PRTest(b37KGReference, new String[]{"NA12878.1_10mb_2_10mb.bam", "NA20313.highCoverageRegion.bam"}, + "-platform illum", "")}, + {new PRTest(b37KGReference, new String[]{"NA12878.1_10mb_2_10mb.bam", "NA20313.highCoverageRegion.bam"}, + " -sn NotASample", "")}, + }; + } + + @Test(dataProvider = "PRExceptionTest") + public void testPrintReadsException(PRTest params) { + + StringBuilder inputs = new StringBuilder(); + for (String bam : params.bam) { + inputs.append(" -I "); + inputs.append(privateTestDir); + inputs.append(bam); + } + + WalkerTestSpec spec = new WalkerTestSpec( + "-T PrintReads" + + " -R " + params.reference + + inputs.toString() + + params.args + + " --no_pg_tag" + + " -o %s", + 1, UserException.class); + executeTest("testPrintReadsException-"+params.args, spec); + } + }