BTTJ - Code refactoring (major) - passes integration test

VariantEvalWalker - whoops, wrote PooledGenotypeConcordance rather than PoolAnalysis in the numPeopleInPool check, now passes tests again

- PooledFrequencyAnalysis - don't bother initializing matrices if this isn't a pool




git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1895 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
chartl 2009-10-21 19:04:51 +00:00
parent 15a1849758
commit 8e3f72ced9
3 changed files with 152 additions and 194 deletions

View File

@ -1,8 +1,6 @@
package org.broadinstitute.sting.playground.gatk.walkers;
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.gatk.walkers.By;
import org.broadinstitute.sting.gatk.walkers.DataSource;
import org.broadinstitute.sting.gatk.walkers.*;
import org.broadinstitute.sting.gatk.walkers.genotyper.*;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
@ -25,7 +23,8 @@ import net.sf.samtools.SAMRecord;
* To change this template use File | Settings | File Templates.
*/
@By(DataSource.REFERENCE)
public class BaseTransitionTableCalculatorJavaWalker extends LocusWalker<ReferenceContextWindow,Integer>{
@Reference(window=@Window(start=-3,stop=3))
public class BaseTransitionTableCalculatorJavaWalker extends LocusWalker<Set<BaseTransitionTable>,Set<BaseTransitionTable>> {
@Argument(fullName="usePreviousBases", doc="Use previous bases of the reference as part of the calculation, uses the specified number, defaults to 0", required=false)
int nPreviousBases = 0;
@Argument(fullName="useSecondaryBase",doc="Use the secondary base of a read as part of the calculation", required=false)
@ -46,81 +45,121 @@ public class BaseTransitionTableCalculatorJavaWalker extends LocusWalker<Referen
boolean useReadGroup = false;
private UnifiedGenotyper ug;
private ReferenceContextWindow refWindow;
private Set<BaseTransitionTable> conditionalTables;
// private ReferenceContextWindow refWindow;
// private Set<BaseTransitionTable> conditionalTables;
private List<Boolean> usePreviousBases;
private List<GenomeLoc> previousBaseLoci;
public void initialize() {
if ( nPreviousBases > 3 ) {
throw new StingException("You have opted to use a number of previous bases in excess of 3. In order to do this you must change the reference window size in the walker itself.");
}
ug = new UnifiedGenotyper();
ug.initialize();
refWindow = new ReferenceContextWindow(nPreviousBases);
conditionalTables = new TreeSet<BaseTransitionTable>();
// refWindow = new ReferenceContextWindow(nPreviousBases);
usePreviousBases = new ArrayList<Boolean>();
previousBaseLoci = new ArrayList<GenomeLoc>();
}
public Integer reduceInit() {
return 0;
public Set<BaseTransitionTable> reduceInit() {
return new TreeSet<BaseTransitionTable>();
}
// todo -- emit table from map and reduce just sums
public ReferenceContextWindow map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
// todo -- change to use windowed reference itself
// todo -- move up calculations into map not reduce
public Set<BaseTransitionTable> map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
ReadBackedPileup pileup = new ReadBackedPileup(ref.getBase(),context);
refWindow.update(ref,pileup,baseIsUsable(tracker,ref,pileup,context));
Set<BaseTransitionTable> newCounts = null;
//System.out.println(pileup.getBases());
if ( baseIsUsable(tracker, ref, pileup, context) ) {
//System.out.println("Pileup will be used");
if ( previousLociCanBeUsed(usePreviousBases,previousBaseLoci,context.getLocation()) ) {
for ( int r = 0; r < pileup.getReads().size(); r ++ ) {
if ( useRead ( pileup.getReads().get(r), pileup.getOffsets().get(r), ref ) ) {
newCounts = updateTables( newCounts, pileup.getReads().get(r), pileup.getOffsets().get(r), ref, pileup );
}
}
} else {
updatePreviousBases(usePreviousBases,true,previousBaseLoci,context.getLocation() );
}
} else {
updatePreviousBases( usePreviousBases,false,previousBaseLoci,context.getLocation() );
}
return refWindow;
return newCounts;
}
public Integer reduce ( ReferenceContextWindow map, Integer prevReduce ) {
if ( map.isValidWindow() ) {
prevReduce++;
List<SAMRecord> reads = map.getPileup().getReads();
List<Integer> offsets = map.getPileup().getOffsets();
ReferenceContext ref = map.getMiddleReferenceContext();
// ReadBackedPileup pileup = splitPileupNonref(map.getPileup());
// List<SAMRecord> reads = pileup.getReads();
// List<Integer> offsets = pileup.getOffsets();
// System.out.println("Base and read are usable:");
// System.out.println("Num Mismatches: "+countMismatches(map.getPileup()));
// System.out.println("Ref: "+map.getMiddleReferenceContext().getBase()+" Pileup ref: "+map.getPileup().getRef());
// System.out.println("Pileup: "+map.getPileup().getBases());
for ( int r = 0; r < reads.size(); r ++ ) {
if ( Character.toUpperCase(reads.get(r).getReadBases()[offsets.get(r)]) != ref.getBase() ) {
// System.out.println("Examining read. Mapping quality is: " + reads.get(r).getMappingQuality());
// System.out.println("Base quality is: "+reads.get(r).getBaseQualities()[offsets.get(r)]);
// System.out.println("Read base is: "+ (char) reads.get(r).getReadBases()[offsets.get(r)]);
// System.out.println("Pileup is: " + map.getPileup().getBases());
// if ( ! useRead(reads.get(r),offsets.get(r),ref) ) {
// System.out.println("Read will not be used.");
//}
public Set<BaseTransitionTable> reduce ( Set<BaseTransitionTable> map, Set<BaseTransitionTable> reduce ) {
if ( map != null && ! map.isEmpty() ) {
for ( BaseTransitionTable t : map ) {
boolean add = true;
for ( BaseTransitionTable r : reduce ) {
if ( r.conditionsMatch(t) ) {
r.incorporateTable(t);
add = false;
break;
}
}
if ( useRead( reads.get(r), offsets.get(r), ref ) ) {
updateTables( reads.get(r), offsets.get(r), map );
// prevReduce++;
if ( add ) {
reduce.add(t);
}
}
}
return prevReduce;
// System.out.println("Reduce: size of TransitionTable set is " + reduce.size() + " -- size of Map: " + (map != null ? map.size() : "null"));
return reduce;
}
public void onTraversalDone( Integer numValidObservedMismatchingReads ) {
logger.info(numValidObservedMismatchingReads);
public void onTraversalDone( Set<BaseTransitionTable> conditionalTables ) {
out.print(createHeaderFromConditions());
for ( BaseTransitionTable t : conditionalTables )
t.print(out);
for ( BaseTransitionTable t : conditionalTables )
t.print(out);
}
public void updateTables ( SAMRecord read, int offset, ReferenceContextWindow map ) {
List<Comparable> readConditions = buildConditions(read,offset,map);
/**
 * Appends the current locus and its usability flag to the sliding history of recent loci.
 *
 * Does nothing when no previous bases are being conditioned on (nPreviousBases &lt; 1).
 * Once a list already holds more than nPreviousBases entries, its oldest entry is
 * dropped before the new one is appended.
 *
 * NOTE(review): the size test lets each list hold nPreviousBases + 1 entries; confirm
 * that this is the intended window length.
 *
 * @param usage  per-locus usability flags, oldest first; modified in place
 * @param canUse whether the current locus is usable
 * @param loci   locations matching the entries of usage, oldest first; modified in place
 * @param locus  location of the current locus
 */
public void updatePreviousBases(List<Boolean> usage, boolean canUse, List<GenomeLoc> loci, GenomeLoc locus) {
    if ( nPreviousBases < 1 ) {
        return;
    }

    // decided once, from the flag list, before either list is touched
    final boolean historyIsFull = usage.size() > nPreviousBases;

    if ( historyIsFull ) {
        usage.remove(0);
    }
    usage.add(canUse);

    if ( historyIsFull ) {
        loci.remove(0);
    }
    loci.add(locus);
}
/**
 * Decides whether the recorded history of previous loci permits using the current locus.
 *
 * When no previous bases are requested (nPreviousBases &lt; 1) the answer is always true.
 * Otherwise every recorded usability flag must be true, and the oldest recorded locus
 * must be at distance 1 from the current locus.
 *
 * An empty history yields false instead of failing on loci.get(0); the history is empty
 * until updatePreviousBases has recorded at least one locus.
 *
 * NOTE(review): only the oldest recorded locus is compared against the current one;
 * confirm that this matches the intended window when more than one locus is stored.
 *
 * @param canUse per-locus usability flags, oldest first
 * @param loci   locations matching the entries of canUse, oldest first
 * @param locus  location of the current locus
 * @return true if the current locus may be used given the recorded history
 */
public boolean previousLociCanBeUsed( List<Boolean> canUse, List<GenomeLoc> loci, GenomeLoc locus ) {
    if ( nPreviousBases < 1 ) {
        return true;
    }

    // nothing recorded yet: there is no previous base to condition on
    if ( loci.isEmpty() ) {
        return false;
    }

    for ( boolean b : canUse ) {
        if ( ! b ) {
            return false;
        }
    }

    return loci.get(0).distance(locus) == 1; // truly is PREVIOUS base
}
public Set<BaseTransitionTable> updateTables ( Set<BaseTransitionTable> tables, SAMRecord read, int offset, ReferenceContext ref, ReadBackedPileup pileup ) {
List<Comparable> readConditions = buildConditions(read,offset,ref, pileup);
if ( tables == null ) {
tables = new TreeSet<BaseTransitionTable>();
}
boolean createNewTable = true;
for ( BaseTransitionTable t : conditionalTables ) {
for ( BaseTransitionTable t : tables ) {
if ( t.conditionsMatch(readConditions) ) {
updateTable(t,read,offset,map);
updateTable(t,read,offset,ref);
createNewTable = false;
break;
}
@ -128,16 +167,19 @@ public class BaseTransitionTableCalculatorJavaWalker extends LocusWalker<Referen
if ( createNewTable ) {
BaseTransitionTable t = new BaseTransitionTable(readConditions);
updateTable(t,read,offset,map);
conditionalTables.add(t);
updateTable(t,read,offset,ref);
tables.add(t);
}
return tables;
}
public void updateTable(BaseTransitionTable t, SAMRecord r, int o, ReferenceContextWindow map) {
public void updateTable(BaseTransitionTable t, SAMRecord r, int o, ReferenceContext ref) {
// System.out.println("Update Table");
if ( r.getReadNegativeStrandFlag() ) {
t.update(BaseUtils.simpleComplement((char) r.getReadBases()[o]),BaseUtils.simpleComplement(map.getMiddleReferenceContext().getBase()));
t.update(BaseUtils.simpleComplement((char) r.getReadBases()[o]),BaseUtils.simpleComplement(ref.getBase()));
} else {
t.update(r.getReadBases()[o], map.getMiddleReferenceContext().getBase());
t.update(r.getReadBases()[o], ref.getBase());
}
}
@ -158,14 +200,12 @@ public class BaseTransitionTableCalculatorJavaWalker extends LocusWalker<Referen
}
}
public List<Comparable> buildConditions( SAMRecord read, int offset, ReferenceContextWindow map ) {
public List<Comparable> buildConditions( SAMRecord read, int offset, ReferenceContext ref, ReadBackedPileup pileup ) {
ArrayList<Comparable> conditions = new ArrayList<Comparable>();
if ( nPreviousBases > 0 ) {
if ( ! read.getReadNegativeStrandFlag() )
conditions.add(map.getForwardRefString());
else
conditions.add(map.getReverseRefString());
conditions.add(buildRefString(ref,nPreviousBases, ! read.getReadNegativeStrandFlag()));
}
if ( useSecondaryBase ) {
@ -177,7 +217,7 @@ public class BaseTransitionTableCalculatorJavaWalker extends LocusWalker<Referen
}
if ( usePileupMismatches ) {
conditions.add(countMismatches(map.getPileup()));
conditions.add(countMismatches(pileup));
}
if ( useReadGroup ) {
@ -187,6 +227,14 @@ public class BaseTransitionTableCalculatorJavaWalker extends LocusWalker<Referen
return conditions;
}
/**
 * Builds the reference-context string used as a table condition for a read.
 *
 * For a forward-strand read this is the leading part of the reference window, from
 * index 0 up to (but excluding) index nPreviousBases - 1. For a reverse-strand read it
 * is the reverse complement of the window from index nPreviousBases + 1 to the end.
 *
 * NOTE(review): the bases argument is never read; the nPreviousBases field is used
 * instead. The substring bounds also look off for a window centred on the current
 * base -- confirm against the layout of ref.getBases().
 *
 * @param ref         reference context supplying the window of bases
 * @param bases       number of previous bases requested (currently unused)
 * @param forwardRead true if the read is on the forward strand
 * @return the reference string to condition on
 */
public String buildRefString(ReferenceContext ref, int bases, boolean forwardRead) {
    final String window = new String(ref.getBases());

    if ( ! forwardRead ) {
        final String trailing = window.substring(nPreviousBases+1);
        return BaseUtils.simpleReverseComplement( trailing );
    }

    return window.substring(0,nPreviousBases-1);
}
public String createHeaderFromConditions() {
String header = "True_base\tObserved_base";
@ -264,8 +312,12 @@ public class BaseTransitionTableCalculatorJavaWalker extends LocusWalker<Referen
}
class BaseTransitionTable implements Comparable {
/*
* no direct manipulation of these objects ever
*/
private int[][] table;
private List<Comparable> conditions;
@ -283,7 +335,10 @@ class BaseTransitionTable implements Comparable {
public boolean conditionsMatch(Object obj) {
if ( obj == null ) {
return false;
} else if ( obj instanceof BaseTransitionTable ) {
return ((BaseTransitionTable) obj).conditionsMatch(conditions);
} else if ( ! (obj instanceof List) ) {
return false;
} else if ( this.numConditions() != ((List)obj).size() ){
return false;
@ -300,6 +355,7 @@ class BaseTransitionTable implements Comparable {
}
}
public int compareTo(Object obj) {
if ( ! ( obj instanceof BaseTransitionTable ) ) {
return -1;
@ -311,11 +367,13 @@ class BaseTransitionTable implements Comparable {
if ( this.numConditions() == t.numConditions() ) {
ListIterator<Comparable> thisIter = this.conditions.listIterator();
ListIterator<Comparable> thatIter = t.conditions.listIterator();
while ( thisIter.next() == thatIter.next() ) {
// todo -- compareTo
// do nothing
}
return thisIter.previous().compareTo(thatIter.previous());
int g = 0;
do {
g = thisIter.next().compareTo(thatIter.next());
} while ( g == 0 );
return g;
} else {
return (this.numConditions() > t.numConditions() ) ? 1 : -1;
}
@ -325,18 +383,18 @@ class BaseTransitionTable implements Comparable {
}
public void print( PrintStream out ) {
StringBuilder s = new StringBuilder();
for ( char observedBase : BaseUtils.BASES ) {
for ( char refBase : BaseUtils.BASES ) {
// todo -- String.format please
// todo -- in these situations use StringBuilder
String outString = observedBase+"\t"+refBase;
s.append(String.format("%s\t%s",observedBase,refBase));
for ( Comparable c : conditions ) {
outString = outString+"\t"+c.toString();
s.append(String.format("\t%s",c.toString()));
}
out.printf("%s\t%d%n",outString,table[BaseUtils.simpleBaseToBaseIndex(observedBase)][BaseUtils.simpleBaseToBaseIndex(refBase)]);
s.append(String.format("\t%d%n", table[BaseUtils.simpleBaseToBaseIndex(observedBase)][BaseUtils.simpleBaseToBaseIndex(refBase)]));
}
}
out.print(s.toString());
}
public void update(char observedBase, char refBase ) {
@ -362,118 +420,16 @@ class BaseTransitionTable implements Comparable {
return conditions.listIterator();
}
}
class ReferenceContextWindow {
protected int windowSize;
protected int nPrevBases;
protected LinkedList<ReadBackedPileup> prevAlignments;
protected LinkedList<ReferenceContext> prevRefs;
protected LinkedList<Boolean> usePrevious;
protected boolean initialized;
public ReferenceContextWindow( int nPrevBases ) {
windowSize = 2*nPrevBases + 1;
this.nPrevBases = nPrevBases;
prevAlignments = new LinkedList<ReadBackedPileup>();
prevRefs = new LinkedList<ReferenceContext>();
usePrevious = new LinkedList<Boolean>();
initialized = false;
}
public void update( ReferenceContext ref, ReadBackedPileup pileup, boolean useLocus ) {
if ( ! initialized ) {
prevAlignments.add(pileup);
prevRefs.add(ref);
usePrevious.add(useLocus);
if ( prevAlignments.size() == windowSize ) {
initialized = true;
}
} else {
prevAlignments.removeFirst();
prevRefs.removeFirst();
usePrevious.removeFirst();
prevAlignments.add(pileup);
prevRefs.add(ref);
usePrevious.add(useLocus);
}
}
public String getReferenceString() {
String ref = "";
for ( ReferenceContext c : prevRefs ) {
ref = ref + c.getBase();
}
return ref;
}
public String getForwardRefString() {
String ref = "";
for ( ReferenceContext c : prevRefs.subList(0,nPrevBases) ) {
ref = ref + c.getBase();
}
return ref;
}
public String getReverseRefString() { // todo -- make sure we want to flip this done (yes we do)
String ref = "";
for ( int base = prevRefs.size()-1; base > nPrevBases; base -- ) {
ref = ref + prevRefs.get(base).getBase();
}
return BaseUtils.simpleComplement(ref);
}
public ReadBackedPileup getPileup() {
// because lists are 0-indexed, this returns the alignments
// to the middle base in the window.
return prevAlignments.get(nPrevBases);
}
public ReferenceContext getMiddleReferenceContext() {
return prevRefs.get(nPrevBases);
}
public boolean isValidWindow() {
boolean valid;
if ( ! initialized ) {
valid = false;
} else {
valid = true;
// check if everything is confident ref
for ( Boolean b : usePrevious ) {
if ( !b ) {
valid = false;
break;
}
}
// if still valid, check distances
if ( valid ) {
ListIterator<ReferenceContext> iter = prevRefs.listIterator();
ReferenceContext prev = iter.next();
while ( iter.hasNext() ) {
ReferenceContext cur = iter.next();
if ( cur.getLocus().distance(prev.getLocus()) > 1 ) {
valid = false;
break;
}
prev = cur;
}
public void incorporateTable(BaseTransitionTable t) {
for ( int i = 0; i < BaseUtils.BASES.length; i ++ ) {
for ( int j = 0; j < BaseUtils.BASES.length; j ++ ) {
table[i][j] += t.observationsOf(i,j);
}
}
return valid;
}
public int getWindowSize() {
return windowSize;
public int observationsOf( int observedBaseIndex, int referenceBaseIndex ) {
return table[observedBaseIndex][referenceBaseIndex];
}
}
}

View File

@ -26,15 +26,17 @@ public class PooledFrequencyAnalysis extends BasicPoolVariantAnalysis implements
public PooledFrequencyAnalysis(int poolSize, String knownDBSNPName ) {
super("Pooled_Frequency_Analysis",poolSize);
coverageAnalysisByFrequency = new VariantDBCoverage[getNumberOfAllelesInPool()+1];
variantCounterByFrequency = new VariantCounter[getNumberOfAllelesInPool()+1];
transitionTransversionByFrequency = new TransitionTranversionAnalysis[getNumberOfAllelesInPool()+1];
for ( int j = 0; j < getNumberOfAllelesInPool()+1; j ++ ) {
coverageAnalysisByFrequency[j] = new VariantDBCoverage(knownDBSNPName);
variantCounterByFrequency[j] = new VariantCounter();
transitionTransversionByFrequency[j] = new TransitionTranversionAnalysis();
if ( poolSize > 0 ) {
coverageAnalysisByFrequency = new VariantDBCoverage[getNumberOfAllelesInPool()+1];
variantCounterByFrequency = new VariantCounter[getNumberOfAllelesInPool()+1];
transitionTransversionByFrequency = new TransitionTranversionAnalysis[getNumberOfAllelesInPool()+1];
for ( int j = 0; j < getNumberOfAllelesInPool()+1; j ++ ) {
coverageAnalysisByFrequency[j] = new VariantDBCoverage(knownDBSNPName);
variantCounterByFrequency[j] = new VariantCounter();
transitionTransversionByFrequency[j] = new TransitionTranversionAnalysis();
}
}
}
}
public void initialize(VariantEvalWalker master, PrintStream out1, PrintStream out2, String name) {
super.initialize(master,out1,out2,name);

View File

@ -149,7 +149,7 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> {
VariantAnalysis analysis = iter.next();
boolean disableForGenotyping = evalContainsGenotypes && !(analysis instanceof GenotypeAnalysis);
boolean disableForPopulation = !evalContainsGenotypes && !(analysis instanceof PopulationAnalysis);
boolean disableForPools = (pathToHapmapPoolFile == null && analysis instanceof PooledGenotypeConcordance) || (numPeopleInPool < 1 && analysis instanceof PooledGenotypeConcordance);
boolean disableForPools = (pathToHapmapPoolFile == null && analysis instanceof PooledGenotypeConcordance) || (numPeopleInPool < 1 && analysis instanceof PoolAnalysis);
boolean disable = disableForGenotyping | disableForPopulation | disableForPools;
String causeName = disableForGenotyping ? "population" : (disableForPopulation ? "genotype" : (disableForPools ? "pool" : null));
if (disable) {