Initial check in of refactored Recalibrator. The new walkers are called CountCovariatesRefactored and TableRecalibrationRefactored. More work is needed to finish up the sequential calculation and to document the code sufficiently. These files are not ready to be used by other people quite yet.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1982 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
rpoplin 2009-11-06 22:33:55 +00:00
parent 6fdfc97db6
commit 66d4a995e6
12 changed files with 971 additions and 0 deletions

View File

@ -0,0 +1,13 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import net.sf.samtools.SAMRecord;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Oct 30, 2009
*/
public interface Covariate {
public Comparable<?> getValue(SAMRecord read, int offset, char[] refBases); // used to pick out the value from the read and etc
public Comparable<?> getValue(String str); // used to get value from input file
}

View File

@ -0,0 +1,202 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import org.broadinstitute.sting.gatk.walkers.*;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.refdata.rodDbSNP;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.utils.cmdLine.Argument;
import org.broadinstitute.sting.utils.PackageUtils;
import org.broadinstitute.sting.utils.StingException;
import org.broadinstitute.sting.utils.BaseUtils;
import java.io.PrintStream;
import java.io.FileNotFoundException;
import java.util.*;
import net.sf.samtools.SAMRecord;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Nov 3, 2009
*/
@WalkerName("CountCovariatesRefactored")
public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> {
@Argument(fullName="list", shortName="ls", doc="List the available covariates and exit", required=false)
protected Boolean LIST_ONLY = false;
@Argument(fullName="covariate", shortName="cov", doc="Covariates to be used in the recalibration. Each covariate is given as a separate cov parameter. ReadGroup and ReportedQuality are already added for you.", required=false)
protected String[] COVARIATES = null;
@Argument(fullName="min_mapping_quality", shortName="minmap", required=false, doc="Only use reads with at least this mapping quality score")
public int MIN_MAPPING_QUALITY = 1;
@Argument(fullName = "use_original_quals", shortName="OQ", doc="If provided, we will use use the quals from the original qualities OQ attribute field instead of the quals in the regular QUALS field", required=false)
public boolean USE_ORIGINAL_QUALS = false;
@Argument(fullName="recal_file", shortName="rf", required=false, doc="Filename for the outputted covariates table recalibration file")
public String RECAL_FILE = "output.recal_data.csv";
protected static RecalDataManager dataManager;
protected static ArrayList<Covariate> requestedCovariates;
public void initialize() {
dataManager = new RecalDataManager();
// Get a list of all available covariates
List<Class<? extends Covariate>> classes = PackageUtils.getClassesImplementingInterface(Covariate.class);
// Print and exit if that's what was requested
if ( LIST_ONLY ) {
out.println( "Available covariates:" );
for( Class<?> covClass : classes ) {
out.println( covClass.getSimpleName() );
}
out.println();
System.exit( 0 ); // early exit here because user requested it
}
// Initialize the requested covariates
requestedCovariates = new ArrayList<Covariate>();
requestedCovariates.add( new ReadGroupCovariate() ); // Read Group Covariate is a required covariate for the recalibration calculation
requestedCovariates.add( new QualityScoreCovariate() ); // Quality Score Covariate is a required covariate for the recalibration calculation
if( COVARIATES != null ) {
for( String requestedCovariateString : COVARIATES ) {
boolean foundClass = false;
for( Class<?> covClass : classes ) {
if( requestedCovariateString.equalsIgnoreCase( covClass.getSimpleName() ) ) {
foundClass = true;
try {
Covariate covariate = (Covariate)covClass.newInstance();
requestedCovariates.add( covariate );
if (covariate instanceof ReadGroupCovariate || covariate instanceof QualityScoreCovariate) {
throw new StingException( "ReadGroupCovariate and QualityScoreCovariate are required covariates and are therefore added for you. Please remove them from the -cov list" );
}
} catch ( InstantiationException e ) {
throw new StingException( String.format("Can not instantiate covariate class '%s': must be concrete class.", covClass.getSimpleName()) );
} catch ( IllegalAccessException e ) {
throw new StingException( String.format("Can not instantiate covariate class '%s': must have no-arg constructor.", covClass.getSimpleName()) );
}
}
}
if( !foundClass ) {
throw new StingException( "The requested covariate type (" + requestedCovariateString + ") isn't a valid covariate option. Use --list to see possible covariates." );
}
}
}
logger.info( "The covariates being used here: " );
logger.info( requestedCovariates );
}
public Integer map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
rodDbSNP dbsnp = rodDbSNP.getFirstRealSNP(tracker.getTrackData("dbsnp", null));
// Only use data from non-dbsnp sites
// Assume every mismatch at a non-dbsnp site is indicitive of poor quality
if( dbsnp == null ) {
List<SAMRecord> reads = context.getReads();
List<Integer> offsets = context.getOffsets();
SAMRecord read; // preallocate for use in for loop below
int offset; // preallocate for use in for loop below
for( int iii = 0; iii < reads.size(); iii++ ) {
read = reads.get(iii);
offset = offsets.get(iii);
// Only use data from reads with mapping quality above given quality value and base quality greater than zero
byte[] quals = read.getBaseQualities();
if( read.getMappingQuality() >= MIN_MAPPING_QUALITY && quals[offset] > 0)
{
if( offset > 0 && offset < (read.getReadLength() - 1) ) { // skip first and last bases because they don't have a dinuc count
updateDataFromRead(read, offset, ref);
}
}
}
}
return 1;
}
private void updateDataFromRead(SAMRecord read, int offset, ReferenceContext ref) {
ArrayList<Comparable<?>> key = new ArrayList<Comparable<?>>();
Comparable<?> keyElement; // preallocate for use in for loop below
boolean badKey = false;
for( Covariate covariate : requestedCovariates ) {
keyElement = covariate.getValue( read, offset, ref.getBases() );
if( keyElement != null ) {
key.add( keyElement );
} else {
badKey = true;
}
}
RecalDatum datum = null;
if( !badKey ) {
datum = dataManager.data.get( key );
if( datum == null ) { // key doesn't exist yet in the map so make a new bucket and add it
datum = new RecalDatum();
dataManager.data.put( key, datum );
}
}
byte[] bases = read.getReadBases();
char base = (char)bases[offset];
char refBase = ref.getBase();
if ( read.getReadNegativeStrandFlag() ) {
refBase = BaseUtils.simpleComplement( refBase );
base = BaseUtils.simpleComplement( base );
}
if( datum != null ) {
datum.increment( base, refBase );
}
}
public PrintStream reduceInit() {
try {
return new PrintStream( RECAL_FILE );
} catch ( FileNotFoundException e ) {
throw new RuntimeException( "Couldn't open output file: ", e );
}
}
public PrintStream reduce( Integer value, PrintStream recalTableStream ) {
return recalTableStream; // nothing to do here
}
public void onTraversalDone( PrintStream recalTableStream ) {
out.print( "Writing raw recalibration data..." );
for( Covariate cov : requestedCovariates ) {
recalTableStream.println( "@!" + cov.getClass().getSimpleName() ); // The "@!" is a code for TableRecalibrationWalker to recognize this line as a Covariate class name
}
outputToCSV( recalTableStream );
out.println( "...done!" );
recalTableStream.close();
}
private void outputToCSV( PrintStream recalTableStream ) {
for( Map.Entry<List<? extends Comparable<?>>, RecalDatum> entry : dataManager.data.entrySet() ) {
//for( int iii = 0; iii < entry.getKey().size(); iii++ ) {
//index = Integer.parseInt( (entry.getKey().get( iii )).toString());
//recalTableStream.print( requestedCovariates.get( iii ).getValue( index ) + "," );
// recalTableStream.print( entry.getKey().get(iii) )
//}
for( Comparable<?> comp : entry.getKey() ) {
recalTableStream.print( comp + "," );
}
recalTableStream.println( entry.getValue().outputToCSV() );
}
}
}

View File

@ -0,0 +1,38 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import net.sf.samtools.SAMRecord;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Oct 30, 2009
*/
public class CycleCovariate implements Covariate {
public String platform;
public CycleCovariate() { // empty constructor is required by CovariateCounterWalker
platform = null;
}
public CycleCovariate(String _platform) {
platform = _platform;
}
public Comparable<?> getValue(SAMRecord read, int offset, char[] refBases) {
//BUGBUG: assumes Solexia platform
int cycle = offset;
if( read.getReadNegativeStrandFlag() ) {
cycle = read.getReadLength() - (offset + 1);
}
return cycle;
}
public Comparable<?> getValue(String str) {
return Integer.parseInt( str );
}
public String toString() {
return "Cycle";
}
}

View File

@ -0,0 +1,52 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import net.sf.samtools.SAMRecord;
import org.broadinstitute.sting.utils.BaseUtils;
import java.util.*;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Nov 3, 2009
*/
public class DinucCovariate implements Covariate {
public static ArrayList<String> BASES;
public DinucCovariate() { // empty constructor is required by CovariateCounterWalker
BASES = new ArrayList<String>();
BASES.add("A");
BASES.add("G");
BASES.add("C");
BASES.add("T");
BASES.add("a");
BASES.add("g");
BASES.add("c");
BASES.add("t");
}
public Comparable<?> getValue(SAMRecord read, int offset, char[] refBases) {
byte[] bases = read.getReadBases();
char base = (char)bases[offset];
char prevBase = (char)bases[offset - 1];
if( read.getReadNegativeStrandFlag() ) {
base = BaseUtils.simpleComplement(base);
prevBase = BaseUtils.simpleComplement( (char)bases[offset + 1] );
}
// Check if bad base, probably an 'N'
if( !BASES.contains( String.format( "%c", prevBase ) ) || !BASES.contains( String.format( "%c", base) ) ) {
return null; // CovariateCounterWalker and TableRecalibrationWalker will recognize that null means skip this particular location in the read
} else {
return String.format("%c%c", prevBase, base);
}
}
public Comparable<?> getValue(String str) {
return str;
}
public String toString() {
return "Dinucleotide";
}
}

View File

@ -0,0 +1,26 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import net.sf.samtools.SAMRecord;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Nov 4, 2009
*/
public class MappingQualityCovariate implements Covariate {
public MappingQualityCovariate() { // empty constructor is required by CovariateCounterWalker
}
public Comparable<?> getValue(SAMRecord read, int offset, char[] refBases) {
return read.getMappingQuality();
}
public Comparable<?> getValue(String str) {
return Integer.parseInt( str );
}
public String toString() {
return "Mapping Quality Score";
}
}

View File

@ -0,0 +1,51 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import net.sf.samtools.SAMRecord;
import org.broadinstitute.sting.utils.QualityUtils;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Nov 4, 2009
*/
public class MinimumNQSCovariate implements Covariate {
public final static String ORIGINAL_QUAL_ATTRIBUTE_TAG = "OQ";
protected boolean USE_ORIGINAL_QUALS;
public MinimumNQSCovariate() { // empty constructor is required by CovariateCounterWalker
USE_ORIGINAL_QUALS = false;
}
public MinimumNQSCovariate(boolean originalQuals) {
USE_ORIGINAL_QUALS = originalQuals;
}
public Comparable<?> getValue(SAMRecord read, int offset, char[] refBases) {
byte[] quals = read.getBaseQualities();
if ( USE_ORIGINAL_QUALS && read.getAttribute(ORIGINAL_QUAL_ATTRIBUTE_TAG) != null ) {
Object obj = read.getAttribute(ORIGINAL_QUAL_ATTRIBUTE_TAG);
if ( obj instanceof String )
quals = QualityUtils.fastqToPhred((String)obj);
else {
throw new RuntimeException(String.format("Value encoded by %s in %s isn't a string!", ORIGINAL_QUAL_ATTRIBUTE_TAG, read.getReadName()));
}
}
int minQual = quals[0];
for ( int qual : quals ) {
if( qual < minQual ) {
minQual = qual;
}
}
return minQual;
}
public Comparable<?> getValue(String str) {
return Integer.parseInt( str );
}
public String toString() {
return "Minimum Neighborhood Quality Score";
}
}

View File

@ -0,0 +1,29 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import java.util.*;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Oct 30, 2009
*/
public class NHashMap<T> extends HashMap<List<? extends Comparable<?>>, T> {
private static final long serialVersionUID = 1L; //BUGBUG: what should I do here?
public static <T extends Comparable<?>> List<T> makeList(T... args) {
List<T> list = new ArrayList<T>();
for (T arg : args)
{
list.add(arg);
}
return list;
}
}

View File

@ -0,0 +1,45 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import net.sf.samtools.SAMRecord;
import org.broadinstitute.sting.utils.QualityUtils;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Nov 3, 2009
*/
public class QualityScoreCovariate implements Covariate {
public final static String ORIGINAL_QUAL_ATTRIBUTE_TAG = "OQ";
protected boolean USE_ORIGINAL_QUALS;
public QualityScoreCovariate() { // empty constructor is required by CovariateCounterWalker
USE_ORIGINAL_QUALS = false;
}
public QualityScoreCovariate(boolean originalQuals) {
USE_ORIGINAL_QUALS = originalQuals;
}
public Comparable<?> getValue(SAMRecord read, int offset, char[] refBases) {
byte[] quals = read.getBaseQualities();
if ( USE_ORIGINAL_QUALS && read.getAttribute(ORIGINAL_QUAL_ATTRIBUTE_TAG) != null ) {
Object obj = read.getAttribute(ORIGINAL_QUAL_ATTRIBUTE_TAG);
if ( obj instanceof String )
quals = QualityUtils.fastqToPhred((String)obj);
else {
throw new RuntimeException(String.format("Value encoded by %s in %s isn't a string!", ORIGINAL_QUAL_ATTRIBUTE_TAG, read.getReadName()));
}
}
return quals[offset];
}
public Comparable<?> getValue(String str) {
return Integer.parseInt( str );
}
public String toString() {
return "Reported Quality Score";
}
}

View File

@ -0,0 +1,28 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import net.sf.samtools.SAMRecord;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Oct 30, 2009
*/
public class ReadGroupCovariate implements Covariate{
public ReadGroupCovariate() { // empty constructor is required by CovariateCounterWalker
}
public Comparable<?> getValue(SAMRecord read, int offset, char[] refBases) {
return read.getReadGroup().getReadGroupId();
}
public Comparable<?> getValue(String str) {
return str;
}
public String toString() {
return "Read Group";
}
}

View File

@ -0,0 +1,114 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import org.broadinstitute.sting.utils.StingException;
import org.broadinstitute.sting.utils.QualityUtils;
import java.util.*;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Nov 6, 2009
*/
public class RecalDataManager {
public NHashMap<RecalDatum> data; // the full dataset
public NHashMap<RecalDatum> dataCollapsedReadGroup; // table where everything except read group has been collapsed
public NHashMap<RecalDatum> dataCollapsedQualityScore; // table where everything except read group and quality score has been collapsed
public ArrayList<NHashMap<RecalDatum>> dataCollapsedByCovariate; // tables where everything except read group, quality score, and given covariate has been collapsed
public boolean collapsedTablesCreated;
public NHashMap<Double> dataSumExpectedErrors;
RecalDataManager() {
data = new NHashMap<RecalDatum>();
collapsedTablesCreated = false;
}
// BUGBUG: A lot going on in this method, doing a lot of pre-calculations for use in the sequential mode calculation later
public void createCollapsedTables( int numCovariates ) {
dataCollapsedReadGroup = new NHashMap<RecalDatum>();
dataCollapsedQualityScore = new NHashMap<RecalDatum>();
dataCollapsedByCovariate = new ArrayList<NHashMap<RecalDatum>>();
for( int iii = 0; iii < numCovariates - 2; iii++ ) { // readGroup and QualityScore aren't counted
dataCollapsedByCovariate.add( new NHashMap<RecalDatum>() );
}
dataSumExpectedErrors = new NHashMap<Double>();
// preallocate for use in for loops below
RecalDatum thisDatum;
RecalDatum collapsedDatum;
List<? extends Comparable<?>> key;
ArrayList<Comparable<?>> newKey;
Double sumExpectedErrors;
// for every data point in the map
for( Map.Entry<List<? extends Comparable<?>>,RecalDatum> entry : data.entrySet() ) {
thisDatum = entry.getValue();
key = entry.getKey();
// create dataCollapsedReadGroup, the table where everything except read group has been collapsed
newKey = new ArrayList<Comparable<?>>();
newKey.add( key.get(0) ); // make a new key with just the read group
collapsedDatum = dataCollapsedReadGroup.get( newKey );
if( collapsedDatum == null ) {
dataCollapsedReadGroup.put( newKey, new RecalDatum( thisDatum ) );
//System.out.println("Added key: " + newKey + " to the dataCollapsedReadGroup");
} else {
collapsedDatum.increment( thisDatum );
}
newKey = new ArrayList<Comparable<?>>();
newKey.add( key.get(0) ); // make a new key with just the read group
sumExpectedErrors = dataSumExpectedErrors.get( newKey );
if( sumExpectedErrors == null ) {
dataSumExpectedErrors.put( newKey, 0.0 );
} else {
//System.out.println("updated += " + QualityUtils.qualToErrorProb(Byte.parseByte(key.get(1).toString())) * thisDatum.getNumObservations());
dataSumExpectedErrors.remove( newKey );
sumExpectedErrors += QualityUtils.qualToErrorProb(Byte.parseByte(key.get(1).toString())) * thisDatum.getNumObservations();
dataSumExpectedErrors.put( newKey, sumExpectedErrors );
}
newKey = new ArrayList<Comparable<?>>();
// create dataCollapsedQuality, the table where everything except read group and quality score has been collapsed
newKey.add( key.get(0) ); // make a new key with the read group ...
newKey.add( key.get(1) ); // and quality score
collapsedDatum = dataCollapsedQualityScore.get( newKey );
if( collapsedDatum == null ) {
dataCollapsedQualityScore.put( newKey, new RecalDatum( thisDatum ) );
} else {
collapsedDatum.increment( thisDatum );
}
// create dataCollapsedByCovariate's, the tables where everything except read group, quality score, and given covariate has been collapsed
for( int iii = 0; iii < numCovariates - 2; iii++ ) { // readGroup and QualityScore aren't counted
newKey = new ArrayList<Comparable<?>>();
newKey.add( key.get(0) ); // make a new key with the read group ...
newKey.add( key.get(1) ); // and quality score ...
newKey.add( key.get(iii) ); // and the given covariate
collapsedDatum = dataCollapsedByCovariate.get(iii).get( newKey );
if( collapsedDatum == null ) {
dataCollapsedByCovariate.get(iii).put( newKey, new RecalDatum( thisDatum ) );
} else {
collapsedDatum.increment( thisDatum );
}
}
}
collapsedTablesCreated = true;
}
public NHashMap<RecalDatum> getCollapsedTable( int covariate ) {
if( !collapsedTablesCreated ) {
throw new StingException("Trying to get collapsed tables before they have been populated.");
}
if( covariate == 0) {
return dataCollapsedReadGroup; // table where everything except read group has been collapsed
} else if( covariate == 1 ) {
return dataCollapsedQualityScore; // table where everything except read group and quality score has been collapsed
} else {
return dataCollapsedByCovariate.get( covariate - 2 ); // table where everything except read group, quality score, and given covariate has been collapsed
}
}
}

View File

@ -0,0 +1,107 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import org.broadinstitute.sting.utils.BaseUtils;
import org.broadinstitute.sting.utils.QualityUtils;
import java.util.*;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Nov 3, 2009
*/
public class RecalDatum {
long numObservations; // number of bases seen in total
long numMismatches; // number of bases seen that didn't match the reference
//---------------------------------------------------------------------------------------------------------------
//
// constructors
//
//---------------------------------------------------------------------------------------------------------------
public RecalDatum() {
numObservations = 0L;
numMismatches = 0L;
}
public RecalDatum( long _numObservations, long _numMismatches ) {
numObservations = _numObservations;
numMismatches = _numMismatches;
}
public RecalDatum( RecalDatum copy ) {
this.numObservations = copy.numObservations;
this.numMismatches = copy.numMismatches;
}
//---------------------------------------------------------------------------------------------------------------
//
// increment methods
//
//---------------------------------------------------------------------------------------------------------------
public RecalDatum increment( long incObservations, long incMismatches ) {
numObservations += incObservations;
numMismatches += incMismatches;
return this;
}
public RecalDatum increment( RecalDatum other ) {
return increment( other.numObservations, other.numMismatches );
}
public RecalDatum increment( List<RecalDatum> data ) {
for ( RecalDatum other : data ) {
this.increment( other );
}
return this;
}
public RecalDatum increment( char curBase, char ref ) {
return increment( 1, BaseUtils.simpleBaseToBaseIndex(curBase) == BaseUtils.simpleBaseToBaseIndex(ref) ? 0 : 1 ); // inc takes num observations, then num mismatches
}
//---------------------------------------------------------------------------------------------------------------
//
// methods to derive empirical quality score
//
//---------------------------------------------------------------------------------------------------------------
public double empiricalQualDouble( int smoothing ) {
double doubleMismatches = (double) ( numMismatches + smoothing );
double doubleObservations = (double) ( numObservations + smoothing );
double empiricalQual = -10 * Math.log10(doubleMismatches / doubleObservations);
if (empiricalQual > QualityUtils.MAX_REASONABLE_Q_SCORE) empiricalQual = QualityUtils.MAX_REASONABLE_Q_SCORE;
return empiricalQual;
}
public double empiricalQualDouble() { return empiricalQualDouble( 0 ); } // 'default' behavior is to use smoothing value of zero
public byte empiricalQualByte( int smoothing ) {
double doubleMismatches = (double) ( numMismatches + smoothing );
double doubleObservations = (double) ( numObservations + smoothing );
return QualityUtils.probToQual( 1.0 - doubleMismatches / doubleObservations );
}
public byte empiricalQualByte() { return empiricalQualByte( 0 ); } // 'default' behavior is to use smoothing value of zero
//---------------------------------------------------------------------------------------------------------------
//
// misc. methods
//
//---------------------------------------------------------------------------------------------------------------
public String outputToCSV( ) {
return String.format( "%d,%d,%d", numObservations, numMismatches, (int)empiricalQualByte() );
}
public String outputToCSV( int smoothing ) {
return String.format( "%d,%d,%d", numObservations, numMismatches, (int)empiricalQualByte( smoothing ) );
}
public Long getNumObservations() {
return numObservations;
}
}

View File

@ -0,0 +1,266 @@
package org.broadinstitute.sting.playground.gatk.walkers.Recalibration;
import net.sf.samtools.SAMRecord;
import net.sf.samtools.SAMFileWriter;
import org.broadinstitute.sting.gatk.walkers.ReadWalker;
import org.broadinstitute.sting.gatk.walkers.WalkerName;
import org.broadinstitute.sting.utils.cmdLine.Argument;
import org.broadinstitute.sting.utils.*;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;
import java.io.File;
import java.io.FileNotFoundException;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
* Date: Nov 3, 2009
*/
@WalkerName("TableRecalibrationRefactored")
public class TableRecalibrationWalker extends ReadWalker<SAMRecord, SAMFileWriter> {
@Argument(fullName="recal_file", shortName="rf", doc="Input recalibration table file generated by CountCovariates", required=true)
public String RECAL_FILE;
@Argument(fullName="outputBam", shortName="outputBam", doc="output BAM file", required=false)
public SAMFileWriter OUTPUT_BAM = null;
@Argument(fullName="preserve_qscores_less_than", shortName="pQ", doc="If provided, bases with quality scores less than this threshold won't be recalibrated. In general its unsafe to change qualities scores below < 5, since base callers use these values to indicate random or bad bases", required=false)
public int PRESERVE_QSCORES_LESS_THAN = 5;
@Argument(fullName = "use_original_quals", shortName="OQ", doc="If provided, we will use use the quals from the original qualities OQ attribute field instead of the quals in the regular QUALS field", required=false)
public boolean USE_ORIGINAL_QUALS = false;
@Argument(fullName="smoothing", shortName="sm", required = false, doc="Number of imaginary counts to add to each bin in order to smooth out bins with few data points")
public int SMOOTHING = 1;
public enum RecalibrationMode {
COMBINATORIAL,
SEQUENTIAL,
ERROR
}
@Argument(fullName="RecalibrationMode", shortName="mode", doc="which type of calculation to use when recalibrating, default is SEQUENTIAL", required=false)
public String MODE_STRING = RecalibrationMode.SEQUENTIAL.toString();
public RecalibrationMode MODE = RecalibrationMode.SEQUENTIAL; //BUGBUG: need some code here to set this properly
protected static RecalDataManager dataManager;
protected static ArrayList<Covariate> requestedCovariates;
private static Pattern COVARIATE_PATTERN = Pattern.compile("^@!.*");
public final static String ORIGINAL_QUAL_ATTRIBUTE_TAG = "OQ";
public void initialize() {
// Get a list of all available covariates
List<Class<? extends Covariate>> classes = PackageUtils.getClassesImplementingInterface(Covariate.class);
int lineNumber = 0;
boolean foundAllCovariates = false;
dataManager = new RecalDataManager();
// Read in the covariates that were used from the input file
requestedCovariates = new ArrayList<Covariate>();
// Read in the data from the csv file and populate the map
out.print( "Reading in the data from input file..." );
try {
for ( String line : new xReadLines(new File( RECAL_FILE )) ) {
lineNumber++;
if ( COVARIATE_PATTERN.matcher(line).matches() ) { // the line string is either specifying a covariate or is giving csv data
if( foundAllCovariates ) {
throw new StingException( "Malformed input recalibration file. Found covariate names intermingled with data. " + RECAL_FILE );
} else { // found another covariate in input file
boolean foundClass = false;
for( Class<?> covClass : classes ) {
if( line.equalsIgnoreCase( "@!" + covClass.getSimpleName() ) ) { // the "@!" was added by CovariateCounterWalker as a code to recognize covariate class names
foundClass = true;
try {
Covariate covariate = (Covariate)covClass.newInstance();
requestedCovariates.add( covariate );
} catch ( InstantiationException e ) {
throw new StingException( String.format("Can not instantiate covariate class '%s': must be concrete class.", covClass.getSimpleName()) );
} catch ( IllegalAccessException e ) {
throw new StingException( String.format("Can not instantiate covariate class '%s': must have no-arg constructor.", covClass.getSimpleName()) );
}
}
}
if( !foundClass ) {
throw new StingException( "Malformed input recalibration file. The requested covariate type (" + line + ") isn't a valid covariate option." );
}
}
}
else { // found some data
if( !foundAllCovariates ) {
foundAllCovariates = true;
logger.info( "The covariates being used here: " );
logger.info( requestedCovariates );
}
addCSVData(line);
}
}
} catch ( FileNotFoundException e ) {
Utils.scareUser("Can not find input file: " + RECAL_FILE);
} catch ( NumberFormatException e ) {
throw new RuntimeException("Error parsing recalibration data at line " + lineNumber, e);
}
out.println( "...done!" );
if( MODE == RecalibrationMode.SEQUENTIAL ) {
out.print( "Creating collapsed tables for use in sequential calculation..." );
dataManager.createCollapsedTables( requestedCovariates.size() );
out.println( "...done!" );
}
}
private void addCSVData(String line) {
String[] vals = line.split(",");
List<Comparable<?>> key = new ArrayList<Comparable<?>>();
Covariate cov; // preallocated for use in for loop below
int iii;
for( iii = 0; iii < requestedCovariates.size(); iii++ ) {
cov = requestedCovariates.get( iii );
key.add( cov.getValue( vals[iii] ) );
}
RecalDatum datum = new RecalDatum( Long.parseLong( vals[iii] ), Long.parseLong( vals[iii + 1] ) );
dataManager.data.put( key, datum );
}
public SAMRecord map( char[] refBases, SAMRecord read ) {
byte[] originalQuals = read.getBaseQualities();
// Check if we need to use the original quality scores instead
if ( USE_ORIGINAL_QUALS && read.getAttribute(ORIGINAL_QUAL_ATTRIBUTE_TAG) != null ) {
Object obj = read.getAttribute(ORIGINAL_QUAL_ATTRIBUTE_TAG);
if ( obj instanceof String )
originalQuals = QualityUtils.fastqToPhred((String)obj);
else {
throw new RuntimeException(String.format("Value encoded by %s in %s isn't a string!", ORIGINAL_QUAL_ATTRIBUTE_TAG, read.getReadName()));
}
}
byte[] recalQuals = originalQuals.clone();
// For each base in the read
Comparable<?> keyElement; // preallocate for use in for loops below
for( int iii = 1; iii < read.getReadLength() - 1; iii++ ) { // skip first and last base because there is no dinuc
ArrayList<Comparable<?>> key = new ArrayList<Comparable<?>>();
boolean badKey = false;
for( Covariate covariate : requestedCovariates ) {
keyElement = covariate.getValue( read, iii, refBases );
if ( keyElement != null ) {
key.add( keyElement );
} else {
badKey = true;
}
}
if( !badKey ) {
if( MODE == RecalibrationMode.COMBINATORIAL ) {
RecalDatum datum = dataManager.data.get( key );
if( datum != null ) { // if we have data for this combination of covariates then recalibrate the quality score otherwise do nothing
recalQuals[iii] = datum.empiricalQualByte( SMOOTHING );
}
} else if( MODE == RecalibrationMode.SEQUENTIAL ) {
recalQuals[iii] = performSequentialQualityCalculation( key );
} else {
throw new StingException( "Specified RecalibrationMode is not supported: " + MODE );
}
// Do some error checking on the new quality score
if ( recalQuals[iii] <= 0 || recalQuals[iii] > QualityUtils.MAX_REASONABLE_Q_SCORE ) {
throw new StingException( "Assigning bad quality score " + key + " => " + recalQuals[iii] );
}
}
}
preserveQScores( originalQuals, recalQuals ); // overwrite the work done if original quality score is too low
read.setBaseQualities(recalQuals); // overwrite old qualities with new recalibrated qualities
if ( read.getAttribute(ORIGINAL_QUAL_ATTRIBUTE_TAG) == null ) { // save the old qualities if there is room in the read
read.setAttribute(ORIGINAL_QUAL_ATTRIBUTE_TAG, QualityUtils.phredToFastq(originalQuals));
}
return read;
}
private byte performSequentialQualityCalculation( ArrayList<? extends Comparable<?>> key ) {
byte qualFromRead = Byte.parseByte(key.get(1).toString());
ArrayList<Comparable<?>> newKey;
newKey = new ArrayList<Comparable<?>>();
newKey.add( key.get(0) ); // read group
RecalDatum globalDeltaQDatum = dataManager.getCollapsedTable(0).get( newKey );
double globalDeltaQ = 0.0;
if( globalDeltaQDatum != null ) {
globalDeltaQ = globalDeltaQDatum.empiricalQualDouble( SMOOTHING ) - ( dataManager.dataSumExpectedErrors.get( newKey ) / ((double) globalDeltaQDatum.getNumObservations()) );
}
//System.out.printf("Global quality score shift is %.2f - %.2f = %.2f%n",
// globalDeltaQDatum.empiricalQualDouble( SMOOTHING ), ( dataManager.dataSumExpectedErrors.get( newKey ) / ((double) globalDeltaQDatum.getNumObservations())), globalDeltaQ);
newKey = new ArrayList<Comparable<?>>();
newKey.add( key.get(0) ); // read group
newKey.add( key.get(1) ); // quality score
RecalDatum deltaQReportedDatum = dataManager.getCollapsedTable(1).get( newKey );
double deltaQReported = 0.0;
if( deltaQReportedDatum != null ) {
deltaQReported = deltaQReportedDatum.empiricalQualDouble( SMOOTHING ) - qualFromRead - globalDeltaQ;
}
double deltaQCovariates = 0.0;
RecalDatum deltaQCovariateDatum;
for( int iii = 2; iii < key.size(); iii++ ) {
newKey = new ArrayList<Comparable<?>>();
newKey.add( key.get(0) ); // read group
newKey.add( key.get(1) ); // quality score
newKey.add( key.get(iii) ); // given covariate
deltaQCovariateDatum = dataManager.getCollapsedTable(iii).get( newKey );
if( deltaQCovariateDatum != null ) {
deltaQCovariates += ( deltaQCovariateDatum.empiricalQualDouble( SMOOTHING ) - qualFromRead - (globalDeltaQ + deltaQReported) );
}
}
double newQuality = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates;
byte newQualityByte = QualityUtils.boundQual( (int)Math.round(newQuality), QualityUtils.MAX_REASONABLE_Q_SCORE );
//System.out.println( "base quality score calculated: " + key +
// String.format( " => %d + %.2f + %.2f + %.2f = %d", qualFromRead, globalDeltaQ, deltaQReported, deltaQCovariates, newQualityByte ) );
if( newQualityByte <= 0 && newQualityByte >= QualityUtils.MAX_REASONABLE_Q_SCORE ) {
throw new StingException( "Illegal base quality score calculated: " + key +
String.format( " => %d + %.2f + %.2f + %.2f = %d", qualFromRead, globalDeltaQ, deltaQReported, deltaQCovariates, newQualityByte ) );
}
return newQualityByte;
}
private void preserveQScores( byte[] originalQuals, byte[] recalQuals ) {
for( int iii = 0; iii < recalQuals.length; iii++ ) {
if ( originalQuals[iii] < PRESERVE_QSCORES_LESS_THAN ) {
recalQuals[iii] = originalQuals[iii];
}
}
}
public SAMFileWriter reduceInit() {
return OUTPUT_BAM;
}
public SAMFileWriter reduce( SAMRecord read, SAMFileWriter output ) {
if ( output != null ) {
output.addAlignment(read);
} else {
out.println(read.format());
}
return output;
}
public void onTraversalDone( SAMFileWriter reduceResult ) {
}
}