diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/recalibration/RecalData.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/recalibration/RecalData.java
new file mode 100755
index 000000000..eedea5c74
--- /dev/null
+++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/recalibration/RecalData.java
@@ -0,0 +1,143 @@
+package org.broadinstitute.sting.playground.gatk.walkers;
+
+import org.broadinstitute.sting.utils.QualityUtils;
+import org.broadinstitute.sting.utils.Utils;
+
+/**
+ * Observation and mismatch counts for one (read group, position, reported quality, dinucleotide) bin
+ * of the recalibration table.
+ */
+public class RecalData {
+    /** Number of observations accumulated in this bin. */
+    long N;
+    /** Number of those observations counted as mismatches. */
+    long B;
+    int pos;
+    int qual;
+    String readGroup;
+    String dinuc;
+
+    public RecalData(int pos, int qual, String readGroup, String dinuc ) {
+        this.pos = pos;
+        this.qual = qual;
+        this.readGroup = readGroup;
+        this.dinuc = dinuc;
+    }
+
+    /** Adds incN observations and incB mismatches to this bin. */
+    public void inc(long incN, long incB) {
+        N += incN;
+        B += incB;
+    }
+
+    /** Returns the index of this bin's dinucleotide as computed by string2dinucIndex(). */
+    public int getDinucIndex() {
+        return string2dinucIndex(this.dinuc);
+    }
+
+    /**
+     * Records one observation, counted as a mismatch when curBase and ref map to different
+     * nucleotide codes. Case is ignored; any character without an explicit code shares code 0 with 'A'.
+     */
+    public void inc(char curBase, char ref) {
+        inc(1, nuc2num[curBase] == nuc2num[ref] ? 0 : 1);
+    }
+
+    /** Column names for the fields printed by toString(). */
+    public static String headerString() {
+        return ("pos, rg, dinuc, qual, emp_qual, qual_diff, n, b");
+    }
+
+    /**
+     * Returns -10 * log10(B / N), capped at QualityUtils.MAX_QUAL_SCORE.
+     * The result is not meaningful while N is 0.
+     */
+    public double empiricalQualDouble() {
+        double empiricalQual = -10 * Math.log10((double)B / N);
+        if (empiricalQual > QualityUtils.MAX_QUAL_SCORE) empiricalQual = QualityUtils.MAX_QUAL_SCORE;
+        return empiricalQual;
+    }
+
+    /** Returns QualityUtils.probToQual() of the non-mismatch fraction 1 - B / N. */
+    public byte empiricalQualByte() {
+        return QualityUtils.probToQual(1.0 - (double)B / N);
+    }
+
+    public String toString() {
+        double empiricalQual = empiricalQualDouble();
+        return String.format("%3d,%s,%s,%3d,%5.1f,%5.1f,%6d,%6d", pos, readGroup, dinuc, qual, empiricalQual, qual-empiricalQual, N, B);
+    }
+
+    /** Serializes this bin as "readGroup,dinuc,qual,pos,N,B", the layout parsed by fromCSVString(). */
+    public String toCSVString() {
+        return String.format("%s,%s,%d,%d,%d,%d", readGroup, dinuc, qual, pos, N, B);
+    }
+
+    /**
+     * Parses one line in the layout written by toCSVString().
+     *
+     * @param s line of the form "readGroup,dinuc,qual,pos,N,B"
+     * @return the bin described by the line
+     * @throws IllegalArgumentException if the line has fewer than six fields or a numeric field is malformed
+     */
+    public static RecalData fromCSVString(String s) {
+        String[] vals = s.split(",");
+        if ( vals.length < 6 )
+            throw new IllegalArgumentException(String.format("Malformed recalibration line, expected 6 comma-separated fields: [%s]", s));
+
+        String rg = vals[0];
+        String dinuc = vals[1];
+        int qual = Integer.parseInt(vals[2]);
+        int pos = Integer.parseInt(vals[3]);
+
+        RecalData datum = new RecalData(pos, qual, rg, dinuc);
+        // N and B are long fields; parsing them as int would fail on counts above Integer.MAX_VALUE
+        datum.N = Long.parseLong(vals[4]);
+        datum.B = Long.parseLong(vals[5]);
+        return datum;
+    }
+
+    /**
+     * Maps a pair of bases to the index code(prevBase) * 4 + code(base), where A/C/G/T have codes 0-3.
+     * When Complement is true each code is replaced by 3 - code before combining.
+     */
+    public static int bases2dinucIndex(char prevBase, char base, boolean Complement) {
+        if (!Complement) {
+            return nuc2num[prevBase] * 4 + nuc2num[base];
+        }else{
+            return (3 - nuc2num[prevBase]) * 4 + (3 - nuc2num[base]);
+        }
+    }
+
+    /** Returns the two upper-case bases whose uncomplemented dinucleotide index is index. */
+    public static String dinucIndex2bases(int index) {
+        char data[] = {num2nuc[index / 4], num2nuc[index % 4]};
+        return new String( data );
+    }
+
+    /** Returns the uncomplemented dinucleotide index of the first two characters of s. */
+    public static int string2dinucIndex(String s) {
+        return bases2dinucIndex(s.charAt(0), s.charAt(1), false);
+    }
+
+    private static int nuc2num[];
+    private static char num2nuc[];
+
+    static {
+        nuc2num = new int[128];
+        nuc2num['A'] = 0;
+        nuc2num['C'] = 1;
+        nuc2num['G'] = 2;
+        nuc2num['T'] = 3;
+        nuc2num['a'] = 0;
+        nuc2num['c'] = 1;
+        nuc2num['g'] = 2;
+        nuc2num['t'] = 3;
+
+        num2nuc = new char[4];
+        num2nuc[0] = 'A';
+        num2nuc[1] = 'C';
+        num2nuc[2] = 'G';
+        num2nuc[3] = 'T';
+    }
+}
\ No newline at end of file
diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/recalibration/RecalDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/recalibration/RecalDataManager.java
new file mode 100755
index 000000000..1d9a6d909
--- /dev/null
+++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/recalibration/RecalDataManager.java
@@ -0,0 +1,185 @@
+package org.broadinstitute.sting.playground.gatk.walkers;
+
+import org.broadinstitute.sting.gatk.walkers.WalkerName;
+import org.broadinstitute.sting.gatk.walkers.LocusWalker;
+import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
+import org.broadinstitute.sting.gatk.refdata.rodDbSNP;
+import org.broadinstitute.sting.gatk.LocusContext;
+import org.broadinstitute.sting.utils.cmdLine.Argument;
+import org.broadinstitute.sting.utils.Utils;
+import org.broadinstitute.sting.utils.QualityUtils;
+import org.broadinstitute.sting.utils.BaseUtils;
+
+import java.util.*;
+import java.io.PrintStream;
+import java.io.FileNotFoundException;
+
+import net.sf.samtools.SAMReadGroupRecord;
+import net.sf.samtools.SAMRecord;
+
+/**
+ * Holds the RecalData bins of a single read group in a table indexed by position, reported quality
+ * and dinucleotide index, plus a flat list of every bin stored.
+ *
+ * User: mdepristo
+ * Date: Jun 16, 2009
+ * Time: 9:55:10 PM
+ */
+public class RecalDataManager {
+    ArrayList<RecalData> flattenData = new ArrayList<RecalData>();
+    RecalData[][][] data = null;
+    boolean trackPos, trackDinuc;
+    String readGroup;
+    int nDinucs, maxReadLen;
+
+    /**
+     * @param readGroup  read group whose bins this manager holds
+     * @param maxReadLen largest position index; the table has maxReadLen + 1 position rows
+     * @param maxQual    not used: the quality dimension is always QualityUtils.MAX_QUAL_SCORE + 1 wide
+     * @param nDinucs    size of the dinucleotide dimension
+     * @param trackPos   when false, every position is stored at index 0
+     * @param trackDinuc when false, every dinucleotide is stored at index 0
+     */
+    public RecalDataManager(String readGroup,
+                            int maxReadLen, int maxQual, int nDinucs,
+                            boolean trackPos, boolean trackDinuc) {
+        data = new RecalData[maxReadLen+1][QualityUtils.MAX_QUAL_SCORE+1][nDinucs];
+        this.readGroup = readGroup;
+        this.trackPos = trackPos;
+        this.trackDinuc = trackDinuc;
+        this.maxReadLen = maxReadLen;
+        this.nDinucs = nDinucs;
+    }
+
+    public int getPosIndex(int pos) {
+        return trackPos ? pos : 0;
+    }
+
+    public int getDinucIndex(int dinuc) {
+        return trackDinuc ? dinuc : 0;
+    }
+
+    /**
+     * Stores datum in the table and in the flat list.
+     *
+     * @throws RuntimeException if datum belongs to another read group or its cell is already occupied
+     */
+    public void addDatum(RecalData datum) {
+        if ( ! datum.readGroup.equals(this.readGroup) ) {
+            throw new RuntimeException(String.format("BUG: adding incorrect read group datum %s to RecalDataManager for %s", datum.readGroup, this.readGroup));
+        }
+
+        RecalData existing = getRecalData(datum.pos, datum.qual, datum.getDinucIndex());
+        if ( existing != null )
+            throw new RuntimeException(String.format("Duplicate entry discovered: %s vs. %s", existing, datum));
+
+        int posIndex = getPosIndex(datum.pos);
+        int internalDinucIndex = getDinucIndex(datum.getDinucIndex());
+        data[posIndex][datum.qual][internalDinucIndex] = datum;
+        flattenData.add(datum);
+    }
+
+    /** Returns the bin stored for the given coordinates, or null if none has been stored. */
+    public RecalData getRecalData(int pos, int qual, int dinuc_index) {
+        return expandingGetRecalData(pos, qual, dinuc_index, false);
+    }
+
+    /**
+     * Returns the bin stored for the given coordinates. When the cell is empty and expandP is true,
+     * a new empty bin is created, stored and returned; otherwise null is returned for an empty cell.
+     */
+    public RecalData expandingGetRecalData(int pos, int qual, int dinuc_index, boolean expandP) {
+        int posIndex = getPosIndex(pos);
+        int internalDinucIndex = getDinucIndex(dinuc_index);
+
+        RecalData datum = data[posIndex][qual][internalDinucIndex];
+        if ( datum == null && expandP ) {
+            datum = new RecalData(posIndex, qual, readGroup, RecalData.dinucIndex2bases(dinuc_index));
+            data[posIndex][qual][internalDinucIndex] = datum;
+            flattenData.add(datum);
+        }
+
+        return datum;
+    }
+
+    /**
+     * Collects the table cells matching the given coordinates, where -1 matches every index of that
+     * dimension. Cells that were never filled are included as null entries.
+     */
+    public List<RecalData> select(int pos, int qual, int dinuc_index ) {
+        List<RecalData> l = new LinkedList<RecalData>();
+        for ( int i = 0; i < data.length; i++ ) {
+            if ( i == pos || pos == -1 || ! trackPos ) {
+                for ( int j = 0; j < data[i].length; j++ ) {
+                    if ( j == qual || qual == -1 ) {
+                        for ( int k = 0; k < data[i][j].length; k++ ) {
+                            if ( k == dinuc_index|| dinuc_index == -1 || ! trackDinuc ) {
+                                l.add(data[i][j][k]);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        return l;
+    }
+
+    /**
+     * Aggregates the table over dinucleotides.
+     *
+     * @return one bin with dinuc "**" per (position, quality) pair that has at least one observation
+     */
+    public List<RecalData> getDataByPos() {
+        List<RecalData> l = new ArrayList<RecalData>(data.length);
+        // the table has maxReadLen + 1 position rows; visit them all so the last position is not dropped
+        for ( int pos = 0; pos < data.length; pos++ ) {
+            for ( int qual = 0; qual < QualityUtils.MAX_QUAL_SCORE+1; qual++ ) {
+                RecalData datum = new RecalData(pos, qual, readGroup, "**");
+                for ( int dinucIndex = 0; dinucIndex < nDinucs; dinucIndex++ ) {
+                    RecalData datum2 = getRecalData(pos, qual, dinucIndex);
+                    if ( datum2 != null )
+                        datum.inc(datum2.N, datum2.B);
+                }
+                if ( datum.N > 0 ) l.add(datum);
+            }
+        }
+
+        System.out.printf("getDataByPos => %d%n", l.size());
+        return l;
+    }
+
+    /**
+     * Aggregates the table over positions.
+     *
+     * @return one bin with pos -1 per (dinucleotide, quality) pair that has at least one observation
+     */
+    public List<RecalData> getDataByDinuc() {
+        List<RecalData> l = new ArrayList<RecalData>(nDinucs);
+
+        for ( int dinucIndex = 0; dinucIndex < nDinucs; dinucIndex++ ) {
+            for ( int qual = 0; qual < QualityUtils.MAX_QUAL_SCORE+1; qual++ ) {
+                RecalData datum = new RecalData(-1, qual, readGroup, RecalData.dinucIndex2bases(dinucIndex));
+                System.out.printf("Aggregating [%s]:%n", datum);
+                for ( int pos = 0; pos < data.length; pos++ ) {
+                    RecalData datum2 = getRecalData(pos, qual, dinucIndex);
+                    if ( datum2 != null ) {
+                        System.out.printf(" + [%s]:%n", datum2);
+                        datum.inc(datum2.N, datum2.B);
+                    }
+                }
+                if ( datum.N > 0 ) l.add(datum);
+                System.out.printf(" %s [%s]:%n", datum.N > 0 ? "=>" : "<>", datum);
+            }
+        }
+
+        System.out.printf("getDataByDinuc => %d%n", l.size());
+        return l;
+    }
+
+    /** Returns the live list of every bin stored in this manager, in insertion order. */
+    public List<RecalData> getAll() {
+        return flattenData;
+    }
+}
+
diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/recalibration/TableRecalibrationWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/recalibration/TableRecalibrationWalker.java
new file mode 100755
index 000000000..aa773a1c0
--- /dev/null
+++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/recalibration/TableRecalibrationWalker.java
@@ -0,0 +1,277 @@
+package org.broadinstitute.sting.playground.gatk.walkers;
+
+import net.sf.samtools.*;
+import org.broadinstitute.sting.gatk.walkers.WalkerName;
+import org.broadinstitute.sting.gatk.walkers.ReadWalker;
+import org.broadinstitute.sting.utils.cmdLine.Argument;
+import org.broadinstitute.sting.utils.*;
+import org.apache.log4j.Logger;
+
+import java.util.*;
+import java.io.File;
+import java.io.FileNotFoundException;
+
+/**
+ * Read walker that replaces each read's base qualities with values looked up from a table of
+ * RecalData bins loaded from the params file, then writes the read to a BAM file or to out.
+ */
+@WalkerName("TableRecalibration")
+public class TableRecalibrationWalker extends ReadWalker<SAMRecord, SAMFileWriter> {
+    @Argument(shortName="params", doc="CountCovariates params file", required=true)
+    public String paramsFile;
+
+    @Argument(shortName="outputBAM", doc="output BAM file", required=false)
+    public String outputBamFile = null;
+
+    private static Logger logger = Logger.getLogger(TableRecalibrationWalker.class);
+
+    private final static boolean DEBUG = false;
+
+    // maps from read group -> quality mapping built for that read group
+    HashMap<String, RecalMapping> cache = new HashMap<String, RecalMapping>();
+
+    @Argument(shortName="serial", doc="", required=false)
+    boolean serialRecalibration = false;
+
+    /** Loads every line of the params file as a RecalData bin and builds the per-read-group mappings. */
+    public void initialize() {
+        try {
+            System.out.printf("Reading data...%n");
+            List<RecalData> data = new ArrayList<RecalData>();
+            List<String> lines = new xReadLines(new File(paramsFile)).readLines();
+            for ( String line : lines ) {
+                // each line is rg,dinuc,qual,pos,N,B as parsed by RecalData.fromCSVString
+                data.add(RecalData.fromCSVString(line));
+            }
+            initializeCache(data);
+        } catch ( FileNotFoundException e ) {
+            Utils.scareUser("Cannot read/find parameters file " + paramsFile);
+        }
+    }
+
+    /**
+     * Groups the bins by read group and builds one RecalMapping per read group: a SerialRecalMapping
+     * when serialRecalibration is set, otherwise a CombinatorialRecalMapping.
+     */
+    private void initializeCache(List<RecalData> data) {
+        Set<String> readGroups = new HashSet<String>();
+        Set<String> dinucs = new HashSet<String>();
+        int maxPos = -1;
+        int maxQReported = -1;
+
+        logger.info(String.format("No. params : %d", data.size()));
+        // get the maximum data from the file
+        for ( RecalData datum : data ) {
+            readGroups.add(datum.readGroup);
+            dinucs.add(datum.dinuc);
+            maxPos = Math.max(maxPos, datum.pos);
+            maxQReported = Math.max(maxQReported, datum.qual);
+        }
+        logger.info(String.format("Read groups : %d %s", readGroups.size(), readGroups.toString()));
+        logger.info(String.format("Dinucs : %d %s", dinucs.size(), dinucs.toString()));
+        logger.info(String.format("Max pos : %d", maxPos));
+        logger.info(String.format("Max Q reported : %d", maxQReported));
+
+        // initialize the data structure
+        HashMap<String, RecalDataManager> managers = new HashMap<String, RecalDataManager>();
+        for ( String readGroup : readGroups ) {
+            RecalDataManager manager = new RecalDataManager(readGroup, maxPos, maxQReported, dinucs.size(), true, true);
+            managers.put(readGroup, manager);
+        }
+
+        // fill in the manager structure
+        for ( RecalData datum : data ) {
+            managers.get(datum.readGroup).addDatum(datum);
+        }
+
+        // fill in the table with mapping objects
+        for ( String readGroup : readGroups ) {
+            RecalDataManager manager = managers.get(readGroup);
+            RecalMapping mapper = null;
+            if ( serialRecalibration )
+                mapper = new SerialRecalMapping(manager, dinucs, maxPos, maxQReported);
+            else
+                mapper = new CombinatorialRecalMapping(manager, dinucs, maxPos, maxQReported);
+            cache.put(readGroup, mapper);
+        }
+    }
+
+    /**
+     * Recalibrates the base qualities of one read using the mapping of its read group.
+     *
+     * @param ref  reference bases for the read; not used here
+     * @param read read to recalibrate; its base qualities are replaced in place
+     * @return the same read object
+     * @throws RuntimeException if the read has no RG attribute or no mapping exists for its read group
+     */
+    public SAMRecord map(char[] ref, SAMRecord read) {
+        byte[] bases = read.getReadBases();
+        byte[] quals = read.getBaseQualities();
+        byte[] recalQuals = new byte[quals.length];
+        boolean negativeStrand = read.getReadNegativeStrandFlag();
+
+        // Since we want machine direction reads not corrected positive strand reads, rev comp any negative strand reads
+        if ( negativeStrand ) {
+            bases = BaseUtils.simpleReverseComplement(bases);
+            quals = BaseUtils.reverse(quals);
+        }
+
+        Object readGroupAttr = read.getAttribute("RG");
+        if ( readGroupAttr == null )
+            throw new RuntimeException("Read has no RG attribute, cannot recalibrate: " + read.format());
+        String readGroup = readGroupAttr.toString();
+
+        RecalMapping mapper = cache.get(readGroup);
+        if ( mapper == null )
+            throw new RuntimeException(String.format("No recalibration data for read group %s", readGroup));
+
+        int numBases = read.getReadLength();
+        if ( numBases > 0 )
+            recalQuals[0] = quals[0]; // can't change the first -- no dinuc
+
+        for ( int cycle = 1; cycle < numBases; cycle++ ) { // skip the first base, its qual is already set
+            byte qual = quals[cycle];
+            recalQuals[cycle] = mapper.getNewQual(readGroup, bases[cycle - 1], bases[cycle], cycle, qual);
+        }
+
+        // put the recalibrated quals (not the original ones) back into the read's stored orientation
+        if ( negativeStrand )
+            recalQuals = BaseUtils.reverse(recalQuals);
+
+        read.setBaseQualities(recalQuals);
+        return read;
+    }
+
+    /** Closes the BAM writer, if one was opened by reduceInit(). */
+    public void onTraversalDone(SAMFileWriter output) {
+        if ( output != null ) {
+            output.close();
+        }
+    }
+
+    /** Returns a BAM writer for outputBamFile, or null when no output BAM file was given. */
+    public SAMFileWriter reduceInit() {
+        if ( outputBamFile != null ) { // ! outputBamFile.equals("") ) {
+            SAMFileHeader header = this.getToolkit().getEngine().getSAMHeader();
+            return Utils.createSAMFileWriterWithCompression(header, true, outputBamFile, getToolkit().getBAMCompression());
+        }
+        else {
+            return null;
+        }
+    }
+
+    /**
+     * Writes the recalibrated read to the BAM writer, or prints it to out when there is no writer.
+     */
+    public SAMFileWriter reduce(SAMRecord read, SAMFileWriter output) {
+        if ( output != null ) {
+            output.addAlignment(read);
+        } else {
+            out.println(read.format());
+        }
+
+        return output;
+    }
+}
+
+/** Maps a reported base quality to a recalibrated one. */
+interface RecalMapping {
+    public byte getNewQual(final String readGroup, byte prevBase, byte base, int cycle, byte qual);
+}
+
+/** Looks up the new quality in a per-dinucleotide table indexed by cycle and reported quality. */
+class CombinatorialRecalMapping implements RecalMapping {
+    HashMap<String, byte[][]> cache = new HashMap<String, byte[][]>();
+
+    public CombinatorialRecalMapping(RecalDataManager manager, Set<String> dinucs, int maxPos, int maxQReported ) {
+        // initialize the data structure
+        for ( String dinuc : dinucs ) {
+            byte[][] table = new byte[maxPos+1][maxQReported+1];
+            cache.put(dinuc, table);
+        }
+
+        for ( RecalData datum : manager.getAll() ) {
+            byte [][] table = cache.get(datum.dinuc);
+            if ( table[datum.pos][datum.qual] != 0 )
+                throw new RuntimeException(String.format("Duplicate entry discovered: %s", datum));
+            table[datum.pos][datum.qual] = datum.empiricalQualByte();
+        }
+    }
+
+    /**
+     * Returns the table entry for (dinuc, cycle, qual). The reported qual is returned unchanged when
+     * the dinucleotide contains an 'N' and has no table, or when cycle or qual lies outside the table.
+     *
+     * @throws RuntimeException if no table exists for a dinucleotide that contains no 'N'
+     */
+    public byte getNewQual(final String readGroup, byte prevBase, byte base, int cycle, byte qual) {
+        byte[] bp = {prevBase, base};
+        String dinuc = new String(bp);
+        byte[][] dataTable = cache.get(dinuc);
+
+        if ( dataTable == null && prevBase != 'N' && base != 'N' )
+            throw new RuntimeException(String.format("Unmapped data table at %s %s", readGroup, dinuc));
+
+        // quals above the largest one in the params file have no table column; pass them through
+        boolean inTable = dataTable != null && cycle >= 0 && cycle < dataTable.length && qual >= 0 && qual < dataTable[cycle].length;
+        return inTable ? dataTable[cycle][qual] : qual;
+    }
+}
+
+/** Applies a position-based mapping followed by a dinucleotide-based mapping. */
+class SerialRecalMapping implements RecalMapping {
+    // mapping from dinuc x Q => new Q
+    HashMap<String, byte[]> mappingByDinuc;
+
+    // mapping from pos x Q => new Q
+    byte[][] mappingByPos;
+
+    public SerialRecalMapping(RecalDataManager manager, Set<String> dinucs, int maxPos, int maxQReported ) {
+        mappingByDinuc = new HashMap<String, byte[]>();
+        for ( String dinuc : dinucs ) {
+            byte[] table = new byte[maxQReported+1];
+            mappingByDinuc.put(dinuc, table);
+        }
+        for ( RecalData datum : manager.getDataByDinuc() ) {
+            if ( mappingByDinuc.get(datum.dinuc).length <= datum.qual ) {
+                // arguments follow the message: offending Q first, calculated maximum second
+                throw new RuntimeException(String.format("Unexpectedly massive Q score of %d found, calculated max was %d%n", datum.qual, maxQReported));
+            }
+            mappingByDinuc.get(datum.dinuc)[datum.qual] = datum.empiricalQualByte();
+        }
+
+        // initialize the mapping by position
+        mappingByPos = new byte[maxPos+1][maxQReported+1];
+        for ( RecalData datum : manager.getDataByPos() ) {
+            mappingByPos[datum.pos][datum.qual] = datum.empiricalQualByte();
+        }
+    }
+
+    /**
+     * Maps qual through the by-position table, then maps that result through the by-dinucleotide
+     * table unless either base is 'N'. A cycle or qual outside the by-position table leaves qual
+     * unchanged for the first step.
+     *
+     * @throws RuntimeException if no table exists for a dinucleotide that contains no 'N'
+     */
+    public byte getNewQual(final String readGroup, byte prevBase, byte base, int cycle, byte qual) {
+        byte[] bp = {prevBase, base};
+        String dinuc = new String(bp);
+
+        byte newQual = qual;
+        if ( cycle > 0 && cycle < mappingByPos.length && qual >= 0 && qual < mappingByPos[cycle].length )
+            newQual = mappingByPos[cycle][qual];
+
+        if ( prevBase != 'N' && base != 'N' ) {
+            byte[] dinucTable = mappingByDinuc.get(dinuc);
+            if ( dinucTable == null )
+                throw new RuntimeException(String.format("Unmapped data table at %s %s", readGroup, dinuc));
+
+            // if the qual got mapped too high, assume it's the best we've seen for recalibration purposes
+            int newQualIndex = newQual < dinucTable.length ? newQual : dinucTable.length - 1;
+            newQual = dinucTable[newQualIndex];
+        }
+
+        return newQual;
+    }
+}
\ No newline at end of file