diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/CovariateCounterWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/CovariateCounterWalker.java index 93d4b2afa..841eaa7b7 100644 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/CovariateCounterWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/CovariateCounterWalker.java @@ -7,9 +7,12 @@ import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.refdata.rodDbSNP; import org.broadinstitute.sting.gatk.walkers.LocusWalker; import org.broadinstitute.sting.gatk.walkers.WalkerName; +import org.broadinstitute.sting.playground.gatk.walkers.RecalData; + import org.broadinstitute.sting.utils.cmdLine.Argument; import org.broadinstitute.sting.utils.QualityUtils; import org.broadinstitute.sting.utils.Utils; +import org.broadinstitute.sting.utils.BaseUtils; import java.util.ArrayList; import java.util.List; @@ -21,11 +24,11 @@ import java.io.FileNotFoundException; @WalkerName("CountCovariates") public class CovariateCounterWalker extends LocusWalker { - @Argument(fullName="MAX_READ_LENGTH", shortName="mrl", doc="max read length", required=false) - public int MAX_READ_LENGTH = 101; + @Argument(fullName="maxReadLen", shortName="mrl", doc="max read length", required=false) + public int maxReadLen = 101; - @Argument(fullName="MAX_QUAL_SCORE", shortName="mqs", doc="max quality score", required=false) - public int MAX_QUAL_SCORE = 63; + //@Argument(fullName="MAX_QUAL_SCORE", shortName="mqs", doc="max quality score", required=false) + //public int MAX_QUAL_SCORE = 63; @Argument(fullName="OUTPUT_FILEROOT", shortName="outroot", required=false, doc="Filename root for the outputted logistic regression training files") public String OUTPUT_FILEROOT = "output"; @@ -48,56 +51,21 @@ public class CovariateCounterWalker extends LocusWalker { @Argument(fullName="PLATFORM", shortName="pl", required=false, doc="Only calibrate read groups generated from the given platform (default = Illumina)") public List platforms = Collections.singletonList("ILLUMINA"); - int NDINUCS = 16; - ArrayList flattenData = new ArrayList(); - HashMap data = new HashMap(); - //RecalData[][][] data; + @Argument(fullName="rawData", shortName="raw", required=false, doc="If true, raw mismatch observations will be output to a file") + public boolean outputRawData = true; - static int nuc2num[]; - static char num2nuc[]; + int NDINUCS = 16; + //ArrayList flattenData = new ArrayList(); + //HashMap data = new HashMap(); + HashMap data = new HashMap(); + //RecalData[][][] data; + boolean trackPos = true; + boolean trackDinuc = true; long counted_sites = 0; // number of sites used to count covariates + long counted_bases = 0; // number of bases used to count covariates long skipped_sites = 0; // number of sites skipped because of a dbSNP entry - private class RecalData { - long N; - long B; - int pos; - int qual; - String readGroup; - String dinuc; - - public RecalData(int pos, int qual, String readGroup, String dinuc ) { - this.pos = pos; - this.qual = qual; - this.readGroup = readGroup; - this.dinuc = dinuc; - } - - public void inc(long incN, long incB) { - N += incN; - B += incB; - } - - - public void inc(char curBase, char ref) { - inc(1, nuc2num[curBase] == nuc2num[ref] ? 0 : 1); - //out.printf("%s %s\n", curBase, ref); - } - - public String headerString() { - return ("pos, rg, dinuc, qual, emp_qual, qual_diff, n, b"); - } - - public String toString() { - double empiricalQual = -10 * Math.log10((double)B / N); - - if (empiricalQual > MAX_QUAL_SCORE) empiricalQual = MAX_QUAL_SCORE; - return String.format("%3d,%s,%s,%3d,%5.1f,%5.1f,%6d,%6d", pos, readGroup, dinuc, qual, empiricalQual, qual-empiricalQual, N, B); - //return String.format("%d\t%s\t%d\t%.1f\t%d\t%6d", pos, dinuc, qual, empiricalQual, N, B); - } - } - public void initialize() { if( getToolkit().getEngine().getSAMHeader().getReadGroups().size() > MAX_READ_GROUPS ) Utils.scareUser("Number of read groups in the specified file exceeds the number that can be processed in a reasonable amount of memory." + @@ -108,69 +76,90 @@ public class CovariateCounterWalker extends LocusWalker { Utils.warnUser(String.format("PL attribute for read group %s is unset; assuming all reads are supported",readGroup.getReadGroupId())); if( !isSupportedReadGroup(readGroup) ) continue; - data.put(readGroup.getReadGroupId(), new RecalData[MAX_READ_LENGTH+1][MAX_QUAL_SCORE+1][NDINUCS]); - for ( int i = 0; i < MAX_READ_LENGTH+1; i++) { - for ( int j = 0; j < MAX_QUAL_SCORE+1; j++) { - for ( int k = 0; k < NDINUCS; k++) { - String dinuc = dinucIndex2bases(k); - RecalData datum = new RecalData(i, j, readGroup.getReadGroupId(), dinuc); - data.get(readGroup.getReadGroupId())[i][j][k] = datum; - flattenData.add(datum); - } - } - } + String rg = readGroup.getReadGroupId(); + RecalDataManager manager = new RecalDataManager(rg, maxReadLen, QualityUtils.MAX_QUAL_SCORE+1, NDINUCS, trackPos, trackDinuc ); + //data.put(rg, new RecalData[maxReadLen+1][QualityUtils.MAX_QUAL_SCORE+1][NDINUCS]); + data.put(rg, manager); } } + private RecalData getRecalData(String readGroup, int pos, int qual, int dinuc_index) { + return data.get(readGroup).expandingGetRecalData(pos, qual, dinuc_index, true); + } + + private List getRecalData(String readGroup) { + return data.get(readGroup).getAll(); + } + public Integer map(RefMetaDataTracker tracker, char ref, LocusContext context) { + //System.out.printf("%s %c%n", context.getLocation(), ref); rodDbSNP dbsnp = (rodDbSNP)tracker.lookup("dbSNP", null); if ( dbsnp == null || !dbsnp.isSNP() ) { List reads = context.getReads(); List offsets = context.getOffsets(); for (int i =0; i < reads.size(); i++ ) { SAMRecord read = reads.get(i); + + if ( read.getReadLength() > maxReadLen ) { + throw new RuntimeException("Expectedly long read, please increase maxium read len with maxReadLen parameter: " + read.format()); + } + SAMReadGroupRecord readGroup = read.getHeader().getReadGroup((String)read.getAttribute("RG")); if ( isSupportedReadGroup(readGroup) && - !read.getReadNegativeStrandFlag() && + //!read.getReadNegativeStrandFlag() && (READ_GROUP.equals("none") || read.getAttribute("RG") != null && read.getAttribute("RG").equals(READ_GROUP)) && (read.getMappingQuality() >= MIN_MAPPING_QUALITY)) { - //(random_genrator.nextFloat() <= DOWNSAMPLE_FRACTION) int offset = offsets.get(i); int numBases = read.getReadLength(); if ( offset > 0 && offset < (numBases-1) ) { // skip first and last bases because they suck and they don't have a dinuc count - int qual = (int)read.getBaseQualities()[offset]; - if (qual > 0 && qual <= MAX_QUAL_SCORE) { - // previous base is the next base in terms of machine chemistry if this is a negative strand - char base = (char)read.getReadBases()[offset]; - char prevBase = (char)read.getReadBases()[offset -1]; - int dinuc_index = bases2dinucIndex(prevBase, base, false); - //char prevBase = (char)read.getReadBases()[offset + (read.getReadNegativeStrandFlag() ? 1 : -1)]; - //int dinuc_index = bases2dinucIndex(prevBase, base, read.getReadNegativeStrandFlag()); - - // Convert offset into cycle position which means reversing the position of reads on the negative strand - //int cycle = read.getReadNegativeStrandFlag() ? numBases - offset - 1 : offset; - //data[cycle][qual][dinuc_index].inc(base,ref); - data.get(readGroup.getReadGroupId())[offset][qual][dinuc_index].inc(base,ref); - } + counted_bases += updateDataFromRead(readGroup.getReadGroupId(), read, offset, ref); } } } counted_sites += 1; - }else{ + } else { skipped_sites += 1; //System.out.println(dbsnp.toSimpleString()+" "+new ReadBackedPileup(ref, context).getPileupString()); } return 1; } + private int updateDataFromRead( String rg, SAMRecord read, int offset, char ref ) { + int cycle = offset; + byte[] bases = read.getReadBases(); + byte[] quals = read.getBaseQualities(); + + char base = (char)bases[offset]; + char prevBase = (char)bases[offset - 1]; + + if (read.getReadNegativeStrandFlag()) { + ref = (char)BaseUtils.simpleComplement(ref); + base = (char)BaseUtils.simpleComplement(base); + prevBase = (char)BaseUtils.simpleComplement((char)bases[offset+1]); + cycle = read.getReadLength() - (offset + 1); + } + + int qual = quals[offset]; + if ( qual > 0 && qual <= QualityUtils.MAX_QUAL_SCORE ) { + // previous base is the next base in terms of machine chemistry if this is a negative strand + int dinuc_index = RecalData.bases2dinucIndex(prevBase, base, false); + //System.out.printf("Adding b_offset=%c offset=%d cycle=%d qual=%d dinuc=%c%c ref_match=%c comp=%c%n", (char)read.getReadBases()[offset], offset, cycle, qual, prevBase, base, ref, (char)BaseUtils.simpleComplement(ref)); + getRecalData(rg, cycle, qual, dinuc_index).inc(base,ref); + return 1; + } else { + return 0; + } + } + public void onTraversalDone(Integer result) { PrintStream covars_out; try { covars_out = new PrintStream(OUTPUT_FILEROOT+".covars.out"); - if (flattenData.size() > 0) - covars_out.println(flattenData.get(0).headerString()); - for ( RecalData datum : flattenData ) { - covars_out.println(datum); + covars_out.println(RecalData.headerString()); + for (SAMReadGroupRecord readGroup : this.getToolkit().getEngine().getSAMHeader().getReadGroups()) { + for ( RecalData datum : getRecalData(readGroup.getReadGroupId()) ) { + covars_out.println(datum); + } } } catch (FileNotFoundException e) { System.err.println("FileNotFoundException: " + e.getMessage()); @@ -181,6 +170,7 @@ public class CovariateCounterWalker extends LocusWalker { qualityDiffVsDinucleotide(); out.printf("Counted sites: %d%n", counted_sites); + out.printf("Counted bases: %d%n", counted_bases); out.printf("Skipped sites: %d%n", skipped_sites); out.printf("Fraction skipped: 1/%.0f%n", (double)counted_sites / skipped_sites); @@ -189,29 +179,40 @@ public class CovariateCounterWalker extends LocusWalker { void writeTrainingData() { PrintStream dinuc_out = null; + PrintStream table_out = null; try { dinuc_out = new PrintStream( OUTPUT_FILEROOT+".covariate_counts.csv"); - dinuc_out.println("rg,dn,logitQ,pos,indicator,count"); + dinuc_out.println("rg,dn,logitQ,pos,indicator,count"); for (SAMReadGroupRecord readGroup : this.getToolkit().getEngine().getSAMHeader().getReadGroups()) { for ( int dinuc_index=0; dinuc_index 0) - dinuc_out.format("%s,%s,%d,%d,%d,%d%n", readGroup.getReadGroupId(), dinucIndex2bases(dinuc_index), datum.qual, datum.pos, 0, datum.N - datum.B); + dinuc_out.format("%s,%s,%d,%d,%d,%d%n", readGroup.getReadGroupId(), RecalData.dinucIndex2bases(dinuc_index), datum.qual, datum.pos, 0, datum.N - datum.B); if (datum.B > 0) - dinuc_out.format("%s,%s,%d,%d,%d,%d%n", readGroup.getReadGroupId(), dinucIndex2bases(dinuc_index), datum.qual, datum.pos, 1, datum.B); + dinuc_out.format("%s,%s,%d,%d,%d,%d%n", readGroup.getReadGroupId(), RecalData.dinucIndex2bases(dinuc_index), datum.qual, datum.pos, 1, datum.B); } } } } + + if ( outputRawData ) { + table_out = new PrintStream( OUTPUT_FILEROOT+".raw_data.csv"); + for (SAMReadGroupRecord readGroup : this.getToolkit().getEngine().getSAMHeader().getReadGroups()) { + for ( RecalData datum: getRecalData(readGroup.getReadGroupId()) ) { + if ( datum.N > 0 ) + table_out.format("%s%n", datum.toCSVString()); + } + } + } } catch (FileNotFoundException e) { System.err.println("FileNotFoundException: " + e.getMessage()); return; } finally { - if (dinuc_out != null) - dinuc_out.close(); + if (dinuc_out != null) dinuc_out.close(); + if (table_out != null) table_out.close(); } } @@ -248,21 +249,19 @@ public class CovariateCounterWalker extends LocusWalker { ByCycleFile.printf("cycle,Qemp-obs,Qemp,Qobs,B,N%n"); RecalData All = new RecalData(0,0,readGroup.getReadGroupId(),""); MeanReportedQuality AllReported = new MeanReportedQuality(); - for (int c=0; c < MAX_READ_LENGTH+1; c++) { + for (int c=0; c < maxReadLen+1; c++) { ByCycle.add(new RecalData(c, -1,readGroup.getReadGroupId(),"-")); ByCycleReportedQ.add(new MeanReportedQuality()); } - for ( RecalData datum: flattenData ) { - if( !datum.readGroup.equals(readGroup.getReadGroupId()) ) - continue; + for ( RecalData datum: getRecalData(readGroup.getReadGroupId()) ) { ByCycle.get(datum.pos).inc(datum.N, datum.B); ByCycleReportedQ.get(datum.pos).inc(datum.qual, datum.N); All.inc(datum.N, datum.B); AllReported.inc(datum.qual, datum.N); } - for (int c=0; c < MAX_READ_LENGTH+1; c++) { + for (int c=0; c < maxReadLen+1; c++) { double empiricalQual = -10 * Math.log10((double)ByCycle.get(c).B / ByCycle.get(c).N); double reportedQual = ByCycleReportedQ.get(c).result(); ByCycleFile.printf("%d, %f, %f, %f, %d, %d%n", c, empiricalQual-reportedQual, empiricalQual, reportedQual, ByCycle.get(c).B, ByCycle.get(c).N); @@ -287,14 +286,12 @@ public class CovariateCounterWalker extends LocusWalker { RecalData All = new RecalData(0,0,readGroup.getReadGroupId(),""); MeanReportedQuality AllReported = new MeanReportedQuality(); for (int c=0; c < NDINUCS; c++) { - ByCycle.add(new RecalData(-1, -1,readGroup.getReadGroupId(),dinucIndex2bases(c))); + ByCycle.add(new RecalData(-1, -1,readGroup.getReadGroupId(),RecalData.dinucIndex2bases(c))); ByCycleReportedQ.add(new MeanReportedQuality()); } - for ( RecalData datum: flattenData ) { - if( !datum.readGroup.equals(readGroup.getReadGroupId()) ) - continue; - int dinucIndex = string2dinucIndex(datum.dinuc); //bases2dinucIndex(datum.dinuc.charAt(0), datum.dinuc.charAt(1), false); + for ( RecalData datum: getRecalData(readGroup.getReadGroupId()) ) { + int dinucIndex = RecalData.string2dinucIndex(datum.dinuc); //bases2dinucIndex(datum.dinuc.charAt(0), datum.dinuc.charAt(1), false); ByCycle.get(dinucIndex).inc(datum.N, datum.B); ByCycleReportedQ.get(dinucIndex).inc(datum.qual, datum.N); All.inc(datum.N, datum.B); @@ -325,14 +322,12 @@ public class CovariateCounterWalker extends LocusWalker { ByQualFile.printf("Qrep,Qemp,Qrep_avg,B,N%n"); RecalData All = new RecalData(0,0,readGroup.getReadGroupId(),""); MeanReportedQuality AllReported = new MeanReportedQuality(); - for (int q=0; q { //out.printf("%2d%6d%3d %2d %s%n", datum.qual, datum.N, datum.pos, datum.qual, datum.dinuc); } - for (int q=0; q { return 0; } - public int bases2dinucIndex(char prevBase, char base, boolean Complement) { - if (!Complement) { - return nuc2num[prevBase] * 4 + nuc2num[base]; - }else{ - return (3 - nuc2num[prevBase]) * 4 + (3 - nuc2num[base]); - } - } - - public String dinucIndex2bases(int index) { - char data[] = {num2nuc[index / 4], num2nuc[index % 4]}; - return new String( data ); - } - - public int string2dinucIndex(String s) { - return bases2dinucIndex(s.charAt(0), s.charAt(1), false); - } - - static { - nuc2num = new int[128]; - nuc2num['A'] = 0; - nuc2num['C'] = 1; - nuc2num['G'] = 2; - nuc2num['T'] = 3; - nuc2num['a'] = 0; - nuc2num['c'] = 1; - nuc2num['g'] = 2; - nuc2num['t'] = 3; - - num2nuc = new char[4]; - num2nuc[0] = 'A'; - num2nuc[1] = 'C'; - num2nuc[2] = 'G'; - num2nuc[3] = 'T'; - } Random random_genrator; // Print out data for regression public CovariateCounterWalker() throws FileNotFoundException { diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/LogisticRecalibrationWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/LogisticRecalibrationWalker.java index ebac1f1ad..7835eaecf 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/LogisticRecalibrationWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/LogisticRecalibrationWalker.java @@ -139,6 +139,9 @@ public class LogisticRecalibrationWalker extends ReadWalker %d",