diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/IntervalCleanerWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/IntervalCleanerWalker.java index 8c77754bf..f660ef3cd 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/IntervalCleanerWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/IntervalCleanerWalker.java @@ -1,28 +1,27 @@ package org.broadinstitute.sting.playground.gatk.walkers; +import org.broadinstitute.sting.utils.Pair; import org.broadinstitute.sting.gatk.refdata.*; import org.broadinstitute.sting.gatk.walkers.LocusWindowWalker; import org.broadinstitute.sting.gatk.walkers.WalkerName; import org.broadinstitute.sting.gatk.LocusContext; import org.broadinstitute.sting.utils.cmdLine.Argument; import org.broadinstitute.sting.playground.indels.*; -import org.broadinstitute.sting.playground.utils.CountedObject; -import org.broadinstitute.sting.playground.utils.CountedObjectComparatorAdapter; import net.sf.samtools.*; import java.util.ArrayList; import java.util.List; -import java.util.TreeSet; import java.io.File; @WalkerName("IntervalCleaner") -public class IntervalCleanerWalker extends LocusWindowWalker { +public class IntervalCleanerWalker extends LocusWindowWalker { @Argument(fullName="maxReadLength", shortName="maxRead", doc="max read length", required=false, defaultValue="-1") public int maxReadLength; @Argument(fullName="OutputCleaned", shortName="O", required=true, doc="Output file (sam or bam) for improved (realigned) reads") public String OUT; + public static final int MAX_QUAL = 99; private SAMFileWriter writer; @@ -34,12 +33,15 @@ public class IntervalCleanerWalker extends LocusWindowWalker { public Integer map(RefMetaDataTracker tracker, String ref, LocusContext context) { List reads = context.getReads(); ArrayList goodReads = new ArrayList(); - long leftmostIndex = context.getLocation().getStart(); for ( SAMRecord read : reads ) { if ( read.getReadLength() <= maxReadLength ) goodReads.add(read); } + clean(goodReads, ref, context.getLocation().getStart()); + //bruteForceClean(goodReads, ref, context.getLocation().getStart()); + //testCleanWithDeletion(); + //testCleanWithInsertion(); return 1; } @@ -54,22 +56,505 @@ public class IntervalCleanerWalker extends LocusWindowWalker { public void onTraversalDone(Integer result) { out.println("Saw " + result + " intervals"); + writer.close(); + } + + private static int mismatchQualitySumCigar(AlignedRead aRead, String ref, int refIndex) { + String read = aRead.getReadString(); + String quals = aRead.getBaseQualityString(); + Cigar c = aRead.getCigar(); + + int sum = 0; + int readIndex = 0; + for ( int i = 0 ; i < c.numCigarElements() ; i++ ) { + CigarElement ce = c.getCigarElement(i); + switch( ce.getOperator() ) { + case M: + for ( int j = 0 ; j < ce.getLength() ; j++, refIndex++, readIndex++ ) { + if ( Character.toUpperCase(read.charAt(readIndex)) != Character.toUpperCase(ref.charAt(refIndex)) ) + sum += (int)quals.charAt(readIndex) - 33; + } + break; + case I: + readIndex += ce.getLength(); + break; + case D: + refIndex += ce.getLength(); + break; + default: throw new RuntimeException("Unrecognized cigar element"); + } + } + return sum; + } + + private static int mismatchQualitySum(AlignedRead aRead, String ref, int refIndex) { + String read = aRead.getReadString(); + String quals = aRead.getBaseQualityString(); + + int sum = 0; + for ( int readIndex = 0 ; readIndex < read.length() ; readIndex++, refIndex++ ) { + if ( refIndex > ref.length() ) + sum += MAX_QUAL; + else if ( Character.toUpperCase(read.charAt(readIndex)) != Character.toUpperCase(ref.charAt(refIndex)) ) + sum += (int)quals.charAt(readIndex) - 33; + } + return sum; } private void clean(List reads, String reference, long leftmostIndex) { - // total mismatches across all reads - //int totalMismatches = 0; - //TreeSet< CountedObject > all_indels = new TreeSet< CountedObject >( - // new CountedObjectComparatorAdapter(new IntervalComparator())); + ArrayList refReads = new ArrayList(); + ArrayList altReads = new ArrayList(); + ArrayList altAlignmentsToTest = new ArrayList(); + int totalMismatchSum = 0; + + // decide which reads potentially need to be cleaned for ( SAMRecord read : reads ) { - System.out.println(read.getReadString()); - System.out.println(reference.substring(read.getAlignmentStart()-(int)leftmostIndex, read.getAlignmentEnd()-(int)leftmostIndex+1)); - //totalMismatches += AlignmentUtils.numMismatches(read, reference); - //System.out.println(totalMismatches + "\n"); + AlignedRead aRead = new AlignedRead(read); + int mismatchScore = mismatchQualitySum(aRead, reference, read.getAlignmentStart()-(int)leftmostIndex); + + // if this doesn't match perfectly to the reference, let's try to clean it + if ( mismatchScore > 0 ) { + altReads.add(aRead); + altAlignmentsToTest.add(true); + totalMismatchSum += mismatchScore; + aRead.setMismatchScoreToReference(mismatchScore); + } + // otherwise, we can emit it as is + else { + refReads.add(read); + } } + Consensus bestConsensus = null; + // for each alternative consensus to test, align it to the reference and create an alternative consensus + for ( int index = 0; index < altAlignmentsToTest.size(); index++ ) { + if ( altAlignmentsToTest.get(index) ) { + // do a pairwise alignment against the reference + AlignedRead aRead = altReads.get(index); + SWPairwiseAlignment swConsensus = new SWPairwiseAlignment(reference, aRead.getReadString()); + int idx = swConsensus.getAlignmentStart2wrt1(); + + // create the new consensus + StringBuffer sb = new StringBuffer(); + sb.append(reference.substring(0, idx)); + sb.append(aRead.getReadString()); + Cigar c = swConsensus.getCigar(); + + int indelCount = 0; + for ( int i = 0 ; i < c.numCigarElements() ; i++ ) { + CigarElement ce = c.getCigarElement(i); + switch( ce.getOperator() ) { + case D: + indelCount++; + case M: + idx += ce.getLength(); + break; + case I: + indelCount++; + break; + } + } + // make sure that there is at most only a single indel! + if ( indelCount > 1 ) + continue; + + sb.append(reference.substring(idx)); + String altConsensus = sb.toString(); + + // for each imperfect match to the reference, score it against this alternative + Consensus consensus = new Consensus(altConsensus, c, swConsensus.getAlignmentStart2wrt1()); + for ( int j = 0; j < altReads.size(); j++ ) { + if (j == index) { + consensus.readIndexes.add(new Pair(j, swConsensus.getAlignmentStart2wrt1())); + continue; + } + AlignedRead toTest = altReads.get(j); + Pair altAlignment = findBestOffset(altConsensus, toTest); + + // the mismatch score is the min of its alignment vs. the reference and vs. the alternate + int myScore = altAlignment.getSecond(); + if ( myScore >= toTest.getMismatchScoreToReference() ) + myScore = toTest.getMismatchScoreToReference(); + // keep track of reads that align better to the alternate consensus + else + consensus.readIndexes.add(new Pair(j, altAlignment.getFirst())); + + logger.info(aRead.getReadString() + " vs. " + toTest.getReadString() + " => " + myScore + " - " + altAlignment.getFirst()); + consensus.mismatchSum += myScore; + if ( myScore == 0 ) + // we already know that this is its consensus, so don't bother testing it later + altAlignmentsToTest.set(j, false); + } + logger.info(aRead.getReadString() + " " + consensus.mismatchSum); + if ( bestConsensus == null || bestConsensus.mismatchSum > consensus.mismatchSum) { + bestConsensus = consensus; + logger.info(aRead.getReadString() + " " + consensus.mismatchSum); + } + } + } + + // if the best alternate consensus has a smaller sum of quality score mismatches, then clean! + if ( bestConsensus.mismatchSum < totalMismatchSum ) { + logger.info("CLEAN: " + bestConsensus.str); + + // clean the appropriate reads + for ( Pair indexPair : bestConsensus.readIndexes ) + updateRead(bestConsensus.cigar, bestConsensus.positionOnReference, indexPair.getSecond(), altReads.get(indexPair.getFirst()), (int)leftmostIndex); + + // write them out + for ( SAMRecord rec : refReads ) + writer.addAlignment(rec); + for ( AlignedRead aRec : altReads ) + writer.addAlignment(aRec.getRead()); + } + } + + private Pair findBestOffset(String ref, AlignedRead read) { + int attempts = ref.length() - read.getReadLength() + 1; + int bestScore = mismatchQualitySum(read, ref, 0); + int bestIndex = 0; + for ( int i = 1; i < attempts; i++ ) { + // we can't get better than 0! + if ( bestScore == 0 ) + return new Pair(bestIndex, 0); + int score = mismatchQualitySum(read, ref, i); + if ( score < bestScore ) { + bestScore = score; + bestIndex = i; + } + } + return new Pair(bestIndex, bestScore); + } + + private void updateRead(Cigar altCigar, int altPosOnRef, int myPosOnAlt, AlignedRead aRead, int leftmostIndex) { + Cigar readCigar = new Cigar(); + + // special case: there is no indel + if ( altCigar.getCigarElements().size() == 1 ) { + aRead.getRead().setAlignmentStart(leftmostIndex + myPosOnAlt); + readCigar.add(new CigarElement(aRead.getReadLength(), CigarOperator.M)); + aRead.getRead().setCigar(readCigar); + return; + } + + CigarElement altCE1 = altCigar.getCigarElement(0); + CigarElement altCE2 = altCigar.getCigarElement(1); + + // the easiest thing to do is to take each case separately + int endOfFirstBlock = altPosOnRef + altCE1.getLength(); + boolean sawAlignmentStart = false; + + // for reads starting before the indel + if ( myPosOnAlt < endOfFirstBlock) { + aRead.getRead().setAlignmentStart(leftmostIndex + myPosOnAlt); + sawAlignmentStart = true; + + // for reads ending before the indel + if ( myPosOnAlt + aRead.getReadLength() <= endOfFirstBlock) { + readCigar.add(new CigarElement(aRead.getReadLength(), CigarOperator.M)); + aRead.getRead().setCigar(readCigar); + return; + } + readCigar.add(new CigarElement(endOfFirstBlock - myPosOnAlt, CigarOperator.M)); + } + + int indelOffsetOnRef = 0, indelOffsetOnRead = 0; + // forward along the indel + if ( altCE2.getOperator() == CigarOperator.I ) { + // for reads that end in an insertion + if ( myPosOnAlt + aRead.getReadLength() < endOfFirstBlock + altCE2.getLength() ) { + readCigar.add(new CigarElement(myPosOnAlt + aRead.getReadLength() - endOfFirstBlock, CigarOperator.I)); + aRead.getRead().setCigar(readCigar); + return; + } + + // for reads that start in an insertion + if ( !sawAlignmentStart && myPosOnAlt < endOfFirstBlock + altCE2.getLength() ) { + aRead.getRead().setAlignmentStart(leftmostIndex + endOfFirstBlock); + readCigar.add(new CigarElement(myPosOnAlt - endOfFirstBlock, CigarOperator.I)); + indelOffsetOnRead = myPosOnAlt - endOfFirstBlock; + sawAlignmentStart = true; + } else if ( sawAlignmentStart ) { + readCigar.add(altCE2); + indelOffsetOnRead = altCE2.getLength(); + } + } else if ( altCE2.getOperator() == CigarOperator.D ) { + readCigar.add(altCE2); + indelOffsetOnRef = altCE2.getLength(); + } else { + throw new RuntimeException("Operator of middle block is not I or D: " + altCE2.getOperator()); + } + + // for reads that start after the indel + if ( !sawAlignmentStart ) { + aRead.getRead().setAlignmentStart(leftmostIndex + myPosOnAlt + indelOffsetOnRef - indelOffsetOnRead); + readCigar.add(new CigarElement(aRead.getReadLength(), CigarOperator.M)); + aRead.getRead().setCigar(readCigar); + return; + } + + int readRemaining = aRead.getReadLength(); + for ( CigarElement ce : readCigar.getCigarElements() ) { + if ( ce.getOperator() != CigarOperator.D ) + readRemaining -= ce.getLength(); + } + readCigar.add(new CigarElement(readRemaining, CigarOperator.M)); + aRead.getRead().setCigar(readCigar); + } + + private class AlignedRead { + SAMRecord read; + int mismatchScoreToReference; + + public AlignedRead(SAMRecord read) { + this.read = read; + mismatchScoreToReference = 0; + } + + public SAMRecord getRead() { + return read; + } + + public String getReadString() { + return read.getReadString(); + } + + public int getReadLength() { + return read.getReadLength(); + } + + public Cigar getCigar() { + return read.getCigar(); + } + + public void setCigar(Cigar cigar) { + read.setCigar(cigar); + } + + public String getBaseQualityString() { + return read.getBaseQualityString(); + } + + public void setMismatchScoreToReference(int score) { + mismatchScoreToReference = score; + } + + public int getMismatchScoreToReference() { + return mismatchScoreToReference; + } + } + + private class Consensus { + public String str; + public int mismatchSum; + public int positionOnReference; + public Cigar cigar; + public ArrayList> readIndexes; + + public Consensus(String str, Cigar cigar, int positionOnReference) { + this.str = str; + this.cigar = cigar; + this.positionOnReference = positionOnReference; + mismatchSum = 0; + readIndexes = new ArrayList>(); + } + } + + private void testCleanWithInsertion() { + String reference = "AAAAAACCCCCCAAAAAA"; + // the alternate reference is: "AAAAAACCCTTCCCAAAAAA"; + ArrayList reads = new ArrayList(); + SAMFileHeader header = getToolkit().getSamReader().getFileHeader(); + SAMRecord r1 = new SAMRecord(header); + r1.setReadName("1"); + r1.setReadString("AACCCCCC"); + r1.setAlignmentStart(4); + r1.setBaseQualityString("BBBBBBBB"); + SAMRecord r2 = new SAMRecord(header); + r2.setReadName("2"); + r2.setReadString("AAAACCCT"); + r2.setAlignmentStart(2); + r2.setBaseQualityString("BBBBBBBB"); + SAMRecord r3 = new SAMRecord(header); + r3.setReadName("3"); + r3.setReadString("CTTC"); + r3.setAlignmentStart(10); + r3.setBaseQualityString("BBBB"); + SAMRecord r4 = new SAMRecord(header); + r4.setReadName("4"); + r4.setReadString("TCCCAA"); + r4.setAlignmentStart(8); + r4.setBaseQualityString("BBBBBB"); + SAMRecord r5 = new SAMRecord(header); + r5.setReadName("5"); + r5.setReadString("AAAGAACC"); + r5.setAlignmentStart(0); + r5.setBaseQualityString("BBBBBBBB"); + SAMRecord r6 = new SAMRecord(header); + r6.setReadName("6"); + r6.setReadString("CCAAAGAA"); + r6.setAlignmentStart(10); + r6.setBaseQualityString("BBBBBBBB"); + SAMRecord r7 = new SAMRecord(header); + r7.setReadName("7"); + r7.setReadString("AACCCTTCCC"); + r7.setAlignmentStart(4); + r7.setBaseQualityString("BBBBBBBBBB"); + reads.add(r1); + reads.add(r2); + reads.add(r3); + reads.add(r4); + reads.add(r5); + reads.add(r6); + reads.add(r7); + clean(reads, reference, 0); + } + + private void testCleanWithDeletion() { + String reference = "AAAAAACCCTTCCCAAAAAA"; + // the alternate reference is: "AAAAAACCCCCCAAAAAA"; + ArrayList reads = new ArrayList(); + SAMFileHeader header = getToolkit().getSamReader().getFileHeader(); + SAMRecord r1 = new SAMRecord(header); + r1.setReadName("1"); + r1.setReadString("ACCCTTCC"); + r1.setAlignmentStart(5); + r1.setBaseQualityString("BBBBBBBB"); + SAMRecord r2 = new SAMRecord(header); + r2.setReadName("2"); + r2.setReadString("AAAACCCC"); + r2.setAlignmentStart(2); + r2.setBaseQualityString("BBBBBBBB"); + SAMRecord r3 = new SAMRecord(header); + r3.setReadName("3"); + r3.setReadString("CCCC"); + r3.setAlignmentStart(6); + r3.setBaseQualityString("BBBB"); + SAMRecord r4 = new SAMRecord(header); + r4.setReadName("4"); + r4.setReadString("CCCCAA"); + r4.setAlignmentStart(10); + r4.setBaseQualityString("BBBBBB"); + SAMRecord r5 = new SAMRecord(header); + r5.setReadName("5"); + r5.setReadString("AAAGAACC"); + r5.setAlignmentStart(0); + r5.setBaseQualityString("BBBBBBBB"); + SAMRecord r6 = new SAMRecord(header); + r6.setReadName("6"); + r6.setReadString("CCAAAGAA"); + r6.setAlignmentStart(10); + r6.setBaseQualityString("BBBBBBBB"); + SAMRecord r7 = new SAMRecord(header); + r7.setReadName("7"); + r7.setReadString("AAAACCCG"); + r7.setAlignmentStart(2); + r7.setBaseQualityString("BBBBBBBB"); + SAMRecord r8 = new SAMRecord(header); + r8.setReadName("8"); + r8.setReadString("AACCCCCC"); + r8.setAlignmentStart(4); + r8.setBaseQualityString("BBBBBBBB"); + reads.add(r1); + reads.add(r2); + reads.add(r3); + reads.add(r4); + reads.add(r5); + reads.add(r6); + reads.add(r7); + reads.add(r8); + clean(reads, reference, 0); + } + + private void bruteForceClean(List reads, String reference, long leftmostIndex) { + + ArrayList refReads = new ArrayList(); + ArrayList altReads = new ArrayList(); + int totalMismatchSum = 0; + + // decide which reads potentially need to be cleaned + for ( SAMRecord read : reads ) { + AlignedRead aRead = new AlignedRead(read); + int mismatchScore = mismatchQualitySum(aRead, reference, read.getAlignmentStart()-(int)leftmostIndex); + + // if this doesn't match perfectly to the reference, let's try to clean it + if ( mismatchScore > 0 ) { + altReads.add(aRead); + totalMismatchSum += mismatchScore; + aRead.setMismatchScoreToReference(mismatchScore); + } + // otherwise, we can emit it as is + else { + refReads.add(read); + } + } + + Consensus bestConsensus = null; + + // for each alternative consensus to test, align it to the reference and create an alternative consensus + for ( int indelSize = 1; indelSize <= 5; indelSize++ ) { + for ( int index = 1; index < reference.length(); index++ ) { + for ( int inOrDel = 0; inOrDel < 2; inOrDel++ ) { + + // create the new consensus + Cigar c = new Cigar(); + c.add(new CigarElement(index, CigarOperator.M)); + StringBuffer sb = new StringBuffer(); + sb.append(reference.substring(0, index)); + if ( inOrDel == 0 ) { + c.add(new CigarElement(indelSize, CigarOperator.D)); + c.add(new CigarElement(reference.length()-index-indelSize, CigarOperator.M)); + if ( reference.length() > index+indelSize ) + sb.append(reference.substring(index+indelSize)); + } else { + c.add(new CigarElement(indelSize, CigarOperator.I)); + c.add(new CigarElement(reference.length()-index+indelSize, CigarOperator.M)); + for ( int i = 0; i < indelSize; i++ ) + sb.append("A"); + sb.append(reference.substring(index)); + } + String altConsensus = sb.toString(); + + // for each imperfect match to the reference, score it against this alternative + Consensus consensus = new Consensus(altConsensus, c, 0); + for ( int j = 0; j < altReads.size(); j++ ) { + AlignedRead toTest = altReads.get(j); + Pair altAlignment = findBestOffset(altConsensus, toTest); + + // the mismatch score is the min of its alignment vs. the reference and vs. the alternate + int myScore = altAlignment.getSecond(); + if ( myScore >= toTest.getMismatchScoreToReference() ) + myScore = toTest.getMismatchScoreToReference(); + // keep track of reads that align better to the alternate consensus + else + consensus.readIndexes.add(new Pair(j, altAlignment.getFirst())); + + consensus.mismatchSum += myScore; + } + if ( bestConsensus == null || bestConsensus.mismatchSum > consensus.mismatchSum) { + bestConsensus = consensus; + logger.info(altConsensus + " " + consensus.mismatchSum); + } + } + } + } + + // if the best alternate consensus has a smaller sum of quality score mismatches, then clean! + if ( bestConsensus.mismatchSum < totalMismatchSum ) { + logger.info("CLEAN: " + bestConsensus.str); + + // clean the appropriate reads + for ( Pair indexPair : bestConsensus.readIndexes ) + updateRead(bestConsensus.cigar, bestConsensus.positionOnReference, indexPair.getSecond(), altReads.get(indexPair.getFirst()), (int)leftmostIndex); + + // write them out + for ( SAMRecord rec : refReads ) + writer.addAlignment(rec); + for ( AlignedRead aRec : altReads ) + writer.addAlignment(aRec.getRead()); + } } } \ No newline at end of file