diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/CovariateCounterWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/CovariateCounterWalker.java index 24fd165b4..e5b3649a6 100644 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/CovariateCounterWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/CovariateCounterWalker.java @@ -56,11 +56,13 @@ public class CovariateCounterWalker extends LocusWalker { long B; int pos; int qual; + String readGroup; String dinuc; - public RecalData(int pos, int qual, String dinuc ) { + public RecalData(int pos, int qual, String readGroup, String dinuc ) { this.pos = pos; this.qual = qual; + this.readGroup = readGroup; this.dinuc = dinuc; } @@ -95,7 +97,7 @@ public class CovariateCounterWalker extends LocusWalker { for ( int j = 0; j < MAX_QUAL_SCORE+1; j++) { for ( int k = 0; k < NDINUCS; k++) { String dinuc = dinucIndex2bases(k); - RecalData datum = new RecalData(i, j, dinuc); + RecalData datum = new RecalData(i, j, readGroup.getReadGroupId(), dinuc); data.get(readGroup.getReadGroupId())[i][j][k] = datum; flattenData.add(datum); } @@ -115,8 +117,7 @@ public class CovariateCounterWalker extends LocusWalker { if ( "ILLUMINA".equalsIgnoreCase(readGroup.getAttribute("PL").toString()) && !read.getReadNegativeStrandFlag() && (READ_GROUP.equals("none") || read.getAttribute("RG") != null && read.getAttribute("RG").equals(READ_GROUP)) && - (read.getMappingQuality() >= MIN_MAPPING_QUALITY) && - (DOWNSAMPLE_FRACTION == 1.0 || random_genrator.nextFloat() < DOWNSAMPLE_FRACTION)) { + (read.getMappingQuality() >= MIN_MAPPING_QUALITY)) { //(random_genrator.nextFloat() <= DOWNSAMPLE_FRACTION) int offset = offsets.get(i); int numBases = read.getReadLength(); @@ -178,7 +179,7 @@ public class CovariateCounterWalker extends LocusWalker { for (SAMReadGroupRecord readGroup : this.getToolkit().getEngine().getSAMHeader().getReadGroups()) { for ( int dinuc_index=0; dinuc_index 0) dinuc_out.format("%s,%s,%d,%d,%d,%d%n", readGroup.getReadGroupId(), dinucIndex2bases(dinuc_index), datum.qual, datum.pos, 0, datum.N - datum.B); if (datum.B > 0) @@ -229,10 +230,10 @@ public class CovariateCounterWalker extends LocusWalker { ArrayList ByCycle = new ArrayList(); ArrayList ByCycleReportedQ = new ArrayList(); ByCycleFile.printf("cycle,Qemp-obs,Qemp,Qobs,B,N%n"); - RecalData All = new RecalData(0,0,""); + RecalData All = new RecalData(0,0,readGroup.getReadGroupId(),""); MeanReportedQuality AllReported = new MeanReportedQuality(); for (int c=0; c < MAX_READ_LENGTH+1; c++) { - ByCycle.add(new RecalData(c, -1, "-")); + ByCycle.add(new RecalData(c, -1,readGroup.getReadGroupId(),"-")); ByCycleReportedQ.add(new MeanReportedQuality()); } @@ -265,10 +266,10 @@ public class CovariateCounterWalker extends LocusWalker { ArrayList ByCycle = new ArrayList(); ArrayList ByCycleReportedQ = new ArrayList(); ByDinucFile.printf("dinuc,Qemp-obs,Qemp,Qobs,B,N%n"); - RecalData All = new RecalData(0,0,""); + RecalData All = new RecalData(0,0,readGroup.getReadGroupId(),""); MeanReportedQuality AllReported = new MeanReportedQuality(); for (int c=0; c < NDINUCS; c++) { - ByCycle.add(new RecalData(-1, -1, dinucIndex2bases(c))); + ByCycle.add(new RecalData(-1, -1,readGroup.getReadGroupId(),dinucIndex2bases(c))); ByCycleReportedQ.add(new MeanReportedQuality()); } @@ -302,10 +303,10 @@ public class CovariateCounterWalker extends LocusWalker { ArrayList ByQ = new ArrayList(); ArrayList ByQReportedQ = new ArrayList(); ByQualFile.printf("Qrep,Qemp,Qrep_avg,B,N%n"); - RecalData All = new RecalData(0,0,""); + RecalData All = new RecalData(0,0,readGroup.getReadGroupId(),""); MeanReportedQuality AllReported = new MeanReportedQuality(); for (int q=0; q regressors = new HashMap(); + Map, LogisticRegressor> regressors = new HashMap, LogisticRegressor>(); private static Logger logger = Logger.getLogger(LogisticRecalibrationWalker.class); public void initialize() { @@ -29,26 +29,28 @@ public class LogisticRecalibrationWalker extends ReadWalker(readGroup,dinuc), regressor); System.out.printf("Vals = %s%n", Utils.join(",", vals)); - for ( int i = 1; i < vals.length; i++ ) { + for ( int i = 2; i <= (vals.length-2); i++ ) { Pair ij = mapping.get(i-1); try { double c = Double.parseDouble(vals[i]); regressor.setCoefficient(ij.first, ij.second, c); - System.out.printf("Setting coefficient %s => %s = %f%n", dinuc, ij, c); + System.out.printf("Setting coefficient %s,%s => %s = %f%n", readGroup, dinuc, ij, c); } catch ( NumberFormatException e ) { Utils.scareUser("Badly formed logistic regression header at " + vals[i] + " line: " + line ); } } } - for ( Map.Entry e : regressors.entrySet() ) { - String dinuc = e.getKey(); + for ( Map.Entry, LogisticRegressor> e : regressors.entrySet() ) { + String readGroup = e.getKey().first; + String dinuc = e.getKey().second; LogisticRegressor regressor = e.getValue(); - logger.debug(String.format("Regressor: %s => %s", dinuc, regressor)); + logger.debug(String.format("Regressor: %s,%s => %s", readGroup, dinuc, regressor)); } //System.exit(1); @@ -62,10 +64,12 @@ public class LogisticRecalibrationWalker extends ReadWalker> mapping = new ArrayList>(); String[] elts = headerLine.split("\\s+"); - if ( ! elts[0].toLowerCase().startsWith("dinuc") ) // checking only start of first field because dinuc will be followed by a version number to be checekde later - Utils.scareUser("Badly formatted Logistic regression header, upper left keyword must be dinuc: " + elts[0] + " line: " + headerLine); + if ( ! "rg".equalsIgnoreCase(elts[0]) ) + Utils.scareUser("Badly formatted Logistic regression header, upper left keyword must be rg: " + elts[0] + " line: " + headerLine); + if ( ! elts[1].toLowerCase().startsWith("dinuc") ) // checking only start of first field because dinuc will be followed by a version number to be checekde later + Utils.scareUser("Badly formatted Logistic regression header, second left keyword must be dinuc: " + elts[1] + " line: " + headerLine); - for ( int k = 1; k < elts.length; k++ ) { + for ( int k = 2; k < elts.length; k++ ) { String paramStr = elts[k]; String[] ij = paramStr.split(","); if ( ij.length != 2 ) { @@ -87,6 +91,7 @@ public class LogisticRecalibrationWalker extends ReadWalker(readGroup,dinuc)); byte newQual; if ( regressor != null ) { // no N or some other unexpected bp in the stream diff --git a/java/src/org/broadinstitute/sting/utils/Pair.java b/java/src/org/broadinstitute/sting/utils/Pair.java index cd0b09671..3f358a735 100644 --- a/java/src/org/broadinstitute/sting/utils/Pair.java +++ b/java/src/org/broadinstitute/sting/utils/Pair.java @@ -20,6 +20,48 @@ public class Pair { */ public Y getSecond() { return second; } + /** + * Calculate whether this pair object is equal to another object. + * @param o The other object (hopefully a pair). + * @return True if the two are equal; false otherwise. + */ + @Override + public boolean equals( Object o ) { + if( o == null ) + return false; + if( !(o instanceof Pair) ) + return false; + + Pair other = (Pair)o; + + // Check to see whether one is null but not the other. + if( this.first == null && other.first != null ) return false; + if( this.second == null && other.second != null ) return false; + + // Check to see whether the values are equal. + // If the param of equals is null, it should by contract return false. + if( this.first != null && !this.first.equals(other.first) ) return false; + if( this.second != null && !this.second.equals(other.second) ) return false; + + return true; + } + + /** + * Basic hashcode function. Assume hashcodes of first and second are + * randomly distributed and return the XOR of the two. + * @return Randomly distributed hashcode of the pair. + */ + @Override + public int hashCode() { + if( second == null && first == null ) + return 0; + if( second == null ) + return first.hashCode(); + if( first == null ) + return second.hashCode(); + return first.hashCode() ^ second.hashCode(); + } + public String toString() { return first+","+second; } diff --git a/python/LogisticRegressionByReadGroup.py b/python/LogisticRegressionByReadGroup.py new file mode 100755 index 000000000..45dd6e9a9 --- /dev/null +++ b/python/LogisticRegressionByReadGroup.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python + +import os,sys + +R_exe="/broad/tools/apps/R-2.6.0/bin/Rscript" +logistic_regression_script="/humgen/gsa-scr1/hanna/src/StingWorking/R/logistic_regression.R" + +def exit(msg,errorlevel): + "Exit the program with the specified message and error code." + print msg + sys.exit(errorlevel) + +def open_source_file(source_filename): + "Open the source file with the given name. Make sure it's readable." + if not os.access(source_filename,os.R_OK): + exit("Unable to read covariate counts file '" + sys.argv[1] + "'",1) + if not source_filename.endswith('.csv'): + exit("Source file is in incorrect format. Must be csv.") + return open(source_filename,'r') + +def read_header(source_file): + "Read the header from the given source file. Do basic validation." + header = source_file.readline().split(',') + if header[0] != 'rg' or header[1] != 'dn': + exit("Input file is in invalid format. First two columns should be ,",1) + return header + +def create_intermediate_file(source_file,read_group,dinuc): + "Create an intermediate file for a particular read group / dinuc from a given source file" + base = source_file.name[:source_file.name.rfind('.csv')] + intermediate_filename = "%s.%s.%s.csv" % (base,read_group,dinuc) + intermediate_file = open(intermediate_filename,"w") + intermediate_file.write(','.join(header[2:])) + return intermediate_file + +def open_target_file(target_filename): + "Open a target file and write out the header." + target_file = open(target_filename,'w') + target_file.write("rg\tdinuc\t") + for p in range(5): + for q in range(5): + target_file.write("%d,%d\t" % (p,q)) + target_file.write("\n") + return target_file + +def process_file(source_file,read_group,dinuc,target_file): + base = source_file.name[:source_file.name.rfind('.csv')] + '.' + read_group + regression_command = ' '.join((R_exe,logistic_regression_script,base,base,dinuc)) + print "Running " + regression_command + os.system(regression_command) + parameters_filename = '.'.join((base,dinuc,'parameters')) + if not os.access(parameters_filename,os.R_OK): + exit("Unable to read output of R from file " + parameters_filename) + parameters_file = open(parameters_filename,'r') + parameters = ' '.join([line.rstrip('\n') for line in parameters_file]).split(' ') + target_file.write('\t'.join([read_group,dinuc]+parameters)+'\n') + os.remove('.'.join((base,dinuc,'csv'))) + os.remove('.'.join((base,dinuc,'parameters'))) + +class LogisticRegressionRunner: + def __init__(self,source_filename,read_group,dinuc): + __base = source_file.name[:source_file.name.rfind('.csv')] + __read_group = read_group + __dinuc = dinuc + + +if len(sys.argv) < 3: + exit("Usage: logistic_regression ",1) + +source_file = open_source_file(sys.argv[1]) +target_file = open_target_file(sys.argv[2]) + +header = read_header(source_file) + +intermediate_file = None +read_group = None +dinuc = None + +for data_line in source_file: + data_line = data_line.strip() + if len(data_line) == 0: + continue + data = data_line.split(',') + if read_group != data[0] or dinuc != data[1]: + if intermediate_file: + intermediate_file.close() + process_file(source_file,read_group,dinuc,target_file) + read_group,dinuc = data[0:2] + intermediate_file = create_intermediate_file(source_file,read_group,dinuc) + intermediate_file.write(','.join(data[2:])+'\n') + +if intermediate_file: + intermediate_file.close() + process_file(source_file,read_group,dinuc,target_file) + +source_file.close() +target_file.close()