Optimization: from 10k reads/sec - 22k reads/sec..

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1819 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
hanna 2009-10-13 18:07:15 +00:00
parent 77499e35ac
commit db642fd08b
8 changed files with 158 additions and 139 deletions

View File

@ -47,13 +47,15 @@ public class AlignerTestHarness {
for(SAMRecord read: reader) { for(SAMRecord read: reader) {
count++; count++;
//if( count > 100000 ) break; if( count > 100000 ) break;
//if( count < 366000 ) continue; //if( count < 366000 ) continue;
//if( count != 2 ) continue; //if( count != 2 ) continue;
//if( !read.getReadName().endsWith("SL-XBC:1:82:506:404#0") ) //if( !read.getReadName().endsWith("SL-XBC:1:82:506:404#0") )
// continue; // continue;
//if( !read.getReadName().endsWith("SL-XBC:1:36:30:1926#0") ) //if( !read.getReadName().endsWith("SL-XBC:1:36:30:1926#0") )
// continue; // continue;
//if( !read.getReadName().endsWith("SL-XBC:1:60:1342:1340#0") )
// continue;
SAMRecord alignmentCleaned = null; SAMRecord alignmentCleaned = null;
try { try {

View File

@ -159,7 +159,7 @@ public class BWAAligner implements Aligner {
lowerBounds.get(alignment.position+1).width, lowerBounds.get(alignment.position+1).width,
alignment.loBound, alignment.loBound,
alignment.hiBound); alignment.hiBound);
*/ */
// Temporary -- look ahead to see if the next alignment is bounded. // Temporary -- look ahead to see if the next alignment is bounded.
boolean allowDifferences = mismatches > 0; boolean allowDifferences = mismatches > 0;
@ -240,11 +240,11 @@ public class BWAAligner implements Aligner {
private List<BWAAlignment> createMatchedAlignments( BWT bwt, BWAAlignment alignment, byte[] bases, boolean allowMismatch ) { private List<BWAAlignment> createMatchedAlignments( BWT bwt, BWAAlignment alignment, byte[] bases, boolean allowMismatch ) {
List<BWAAlignment> newAlignments = new ArrayList<BWAAlignment>(); List<BWAAlignment> newAlignments = new ArrayList<BWAAlignment>();
List<Base> baseChoices = new ArrayList<Base>(); List<Byte> baseChoices = new ArrayList<Byte>();
Base thisBase = Base.fromASCII(bases[alignment.position+1]); Byte thisBase = Bases.fromASCII(bases[alignment.position+1]);
if( allowMismatch ) if( allowMismatch )
baseChoices.addAll(EnumSet.allOf(Base.class)); baseChoices.addAll(Bases.allOf());
else else
baseChoices.add(thisBase); baseChoices.add(thisBase);
@ -258,7 +258,7 @@ public class BWAAligner implements Aligner {
} }
} }
for(Base base: baseChoices) { for(byte base: baseChoices) {
BWAAlignment newAlignment = alignment.clone(); BWAAlignment newAlignment = alignment.clone();
newAlignment.loBound = bwt.counts(base) + bwt.occurrences(base,alignment.loBound-1) + 1; newAlignment.loBound = bwt.counts(base) + bwt.occurrences(base,alignment.loBound-1) + 1;
@ -270,7 +270,7 @@ public class BWAAligner implements Aligner {
newAlignment.position++; newAlignment.position++;
newAlignment.addState(AlignmentState.MATCH_MISMATCH); newAlignment.addState(AlignmentState.MATCH_MISMATCH);
if( base.toASCII() != bases[newAlignment.position] ) if( Bases.fromASCII(bases[newAlignment.position]) == null || base != Bases.fromASCII(bases[newAlignment.position]) )
newAlignment.mismatches++; newAlignment.mismatches++;
newAlignments.add(newAlignment); newAlignments.add(newAlignment);
@ -300,7 +300,7 @@ public class BWAAligner implements Aligner {
*/ */
private List<BWAAlignment> createDeletionAlignments( BWT bwt, BWAAlignment alignment) { private List<BWAAlignment> createDeletionAlignments( BWT bwt, BWAAlignment alignment) {
List<BWAAlignment> newAlignments = new ArrayList<BWAAlignment>(); List<BWAAlignment> newAlignments = new ArrayList<BWAAlignment>();
for(Base base: EnumSet.allOf(Base.class)) { for(byte base: Bases.instance) {
BWAAlignment newAlignment = alignment.clone(); BWAAlignment newAlignment = alignment.clone();
newAlignment.loBound = bwt.counts(base) + bwt.occurrences(base,alignment.loBound-1) + 1; newAlignment.loBound = bwt.counts(base) + bwt.occurrences(base,alignment.loBound-1) + 1;
@ -326,7 +326,7 @@ public class BWAAligner implements Aligner {
*/ */
private void exactMatch( BWAAlignment alignment, byte[] bases, BWT bwt ) { private void exactMatch( BWAAlignment alignment, byte[] bases, BWT bwt ) {
while( ++alignment.position < bases.length ) { while( ++alignment.position < bases.length ) {
Base base = Base.fromASCII(bases[alignment.position]); byte base = Bases.fromASCII(bases[alignment.position]);
alignment.loBound = bwt.counts(base) + bwt.occurrences(base,alignment.loBound-1) + 1; alignment.loBound = bwt.counts(base) + bwt.occurrences(base,alignment.loBound-1) + 1;
alignment.hiBound = bwt.counts(base) + bwt.occurrences(base,alignment.hiBound); alignment.hiBound = bwt.counts(base) + bwt.occurrences(base,alignment.hiBound);
if( alignment.loBound > alignment.hiBound ) if( alignment.loBound > alignment.hiBound )

View File

@ -3,7 +3,7 @@ package org.broadinstitute.sting.alignment.bwa;
import java.util.List; import java.util.List;
import java.util.ArrayList; import java.util.ArrayList;
import org.broadinstitute.sting.alignment.bwa.bwt.Base; import org.broadinstitute.sting.alignment.bwa.bwt.Bases;
import org.broadinstitute.sting.alignment.bwa.bwt.BWT; import org.broadinstitute.sting.alignment.bwa.bwt.BWT;
/** /**
@ -53,7 +53,7 @@ public class LowerBound {
int loIndex = 0, hiIndex = bwt.length(), mismatches = 0; int loIndex = 0, hiIndex = bwt.length(), mismatches = 0;
for( int i = bases.length-1; i >= 0; i-- ) { for( int i = bases.length-1; i >= 0; i-- ) {
Base base = Base.fromASCII(bases[i]); Byte base = Bases.fromASCII(bases[i]);
// Ignore non-ACGT bases. // Ignore non-ACGT bases.
if( base != null ) { if( base != null ) {

View File

@ -68,11 +68,8 @@ public class BWT {
* @param base The base. * @param base The base.
* @return Total counts for all bases lexicographically smaller than this base. * @return Total counts for all bases lexicographically smaller than this base.
*/ */
public int counts(Base base) { public int counts(byte base) {
if( base.toPack() - 1 >= 0 ) return counts.getCumulative(base);
return counts.getCumulative(Base.fromPack(base.toPack()-1));
else
return 0;
} }
/** /**
@ -81,7 +78,7 @@ public class BWT {
* @param index The position to search within the BWT. * @param index The position to search within the BWT.
* @return Total counts for all bases lexicographically smaller than this base. * @return Total counts for all bases lexicographically smaller than this base.
*/ */
public int occurrences(Base base,int index) { public int occurrences(byte base,int index) {
// If the index is above the SA-1[0], remap it to the appropriate coordinate space. // If the index is above the SA-1[0], remap it to the appropriate coordinate space.
if( index > inverseSA0 ) index--; if( index > inverseSA0 ) index--;
@ -89,7 +86,7 @@ public class BWT {
int position = index % SEQUENCE_BLOCK_SIZE; int position = index % SEQUENCE_BLOCK_SIZE;
int accumulator = block.occurrences.get(base); int accumulator = block.occurrences.get(base);
for(int i = 0; i <= position; i++) { for(int i = 0; i <= position; i++) {
if(base == Base.fromASCII(block.sequence[i])) if(base == block.sequence[i])
accumulator++; accumulator++;
} }
return accumulator; return accumulator;
@ -124,7 +121,7 @@ public class BWT {
sequenceBlocks[block] = new SequenceBlock(blockStart,blockLength,occurrences.clone(),subsequence); sequenceBlocks[block] = new SequenceBlock(blockStart,blockLength,occurrences.clone(),subsequence);
for( byte base: subsequence ) for( byte base: subsequence )
occurrences.increment(Base.fromASCII(base)); occurrences.increment(base);
} }
return sequenceBlocks; return sequenceBlocks;

View File

@ -1,97 +0,0 @@
package org.broadinstitute.sting.alignment.bwa.bwt;
import java.util.EnumSet;
import java.util.Map;
import java.util.HashMap;
/**
* Enhanced enum representation of a base.
*
* @author mhanna
* @version 0.1
*/
public enum Base
{
A((byte)'A',0),
C((byte)'C',1),
G((byte)'G',2),
T((byte)'T',3);
/**
* The ASCII representation of a given base.
*/
private final byte ascii;
/**
* The 2-bit packed value of the base.
*/
private final int pack;
/**
* Representation of the base broken down by packed value.
*/
private static final Map<Integer,Base> basesByPack = new HashMap<Integer,Base>();
/**
* Representation of the base broken down by ASCII code.
*/
private static final Map<Byte,Base> basesByASCII = new HashMap<Byte,Base>();
static {
for(Base base : EnumSet.allOf(Base.class)) {
basesByPack.put(base.pack,base);
basesByASCII.put(base.ascii,base);
}
}
/**
* Create a new base with the given ascii representation and
* pack value.
* @param ascii ASCII representation of a given base.
* @param pack Packed value of a given base.
*/
private Base( byte ascii, int pack ) {
this.ascii = ascii;
this.pack = pack;
}
/**
* Get the given base from the packed representation.
* @param pack Packed representation.
* @return base.
*/
public static Base fromPack( int pack ) { return basesByPack.get(pack); }
/**
* Convert the given base to its packed value.
* @return Packed value.
*/
public int toPack() { return pack; }
/**
* Convert the given base to its packed value.
* @param ascii ASCII representation of the base.
* @return Packed value.
*/
public static int toPack( byte ascii ) { return basesByASCII.get(ascii).pack; }
/**
* Get the given base from the ASCII representation.
* @param ascii ASCII representation.
* @return base.
*/
public static Base fromASCII( byte ascii ) { return basesByASCII.get(ascii); }
/**
* Convert the given base to its ASCII value.
* @return ASCII value.
*/
public byte toASCII() { return ascii; }
/**
* Convert the given base to its ASCII value.
* @param pack The packed representation of the base.
* @return ASCII value.
*/
public static byte toASCII( int pack ) { return basesByPack.get(pack).ascii; }
}

View File

@ -0,0 +1,108 @@
package org.broadinstitute.sting.alignment.bwa.bwt;
import org.broadinstitute.sting.utils.StingException;
import java.util.*;
/**
* Enhanced enum representation of a base.
*
* @author mhanna
* @version 0.1
*/
public class Bases implements Iterable<Byte>
{
public static byte A = 'A';
public static byte C = 'C';
public static byte G = 'G';
public static byte T = 'T';
public static final Bases instance = new Bases();
private static final List<Byte> allBases;
/**
* Representation of the base broken down by packed value.
*/
private static final Map<Integer,Byte> basesByPack = new HashMap<Integer,Byte>();
static {
List<Byte> bases = new ArrayList<Byte>();
bases.add(A);
bases.add(C);
bases.add(G);
bases.add(T);
allBases = Collections.unmodifiableList(bases);
for(int i = 0; i < allBases.size(); i++)
basesByPack.put(i,allBases.get(i));
}
/**
* Create a new base with the given ascii representation and
* pack value.
*/
private Bases() {
}
/**
* Return all possible bases.
* @return Byte representation of all bases.
*/
public static Collection<Byte> allOf() {
return allBases;
}
/**
* Gets the number of known bases.
* @return The number of known bases.
*/
public static int size() {
return allBases.size();
}
/**
* Gets an iterator over the total number of known base types.
* @return Iterator over all known bases.
*/
public Iterator<Byte> iterator() {
return basesByPack.values().iterator();
}
/**
* Get the given base from the packed representation.
* @param pack Packed representation.
* @return base.
*/
public static byte fromPack( int pack ) { return basesByPack.get(pack); }
/**
* Convert the given base to its packed value.
* @param ascii ASCII representation of the base.
* @return Packed value.
*/
public static int toPack( byte ascii )
{
for( Map.Entry<Integer,Byte> entry: basesByPack.entrySet() ) {
if( entry.getValue().equals(ascii) )
return entry.getKey();
}
throw new StingException(String.format("Base %c is an invalid base to pack", (char)ascii));
}
/**
* Convert the ASCII representation of a base to its 'normalized' representation.
* @param base The base itself.
* @return The byte, if present. Null if unknown.
*/
public static Byte fromASCII( byte base ) {
Byte found = null;
for( Byte normalized: allBases ) {
if( normalized.equals(base) ) {
found = normalized;
break;
}
}
return found;
}
}

View File

@ -2,7 +2,9 @@ package org.broadinstitute.sting.alignment.bwa.bwt;
import org.broadinstitute.sting.utils.StingException; import org.broadinstitute.sting.utils.StingException;
import java.util.EnumSet; import java.util.HashMap;
import java.util.Map;
/** /**
* Counts of how many bases of each type have been seen. * Counts of how many bases of each type have been seen.
* *
@ -13,7 +15,7 @@ public class Counts implements Cloneable {
/** /**
* Internal representation of counts, broken down by pack value. * Internal representation of counts, broken down by pack value.
*/ */
private int[] counts = new int[EnumSet.allOf(Base.class).size()]; private Map<Byte,Integer> counts = new HashMap<Byte,Integer>();
/** /**
* Create an empty Counts object with values A=0,C=0,G=0,T=0. * Create an empty Counts object with values A=0,C=0,G=0,T=0.
@ -26,13 +28,17 @@ public class Counts implements Cloneable {
* @param cumulative Whether the counts are cumulative, (count_G=numA+numC+numG,for example). * @param cumulative Whether the counts are cumulative, (count_G=numA+numC+numG,for example).
*/ */
public Counts( int[] data, boolean cumulative ) { public Counts( int[] data, boolean cumulative ) {
for( Base base: EnumSet.allOf(Base.class)) for( byte base: Bases.instance)
counts[base.toPack()] = data[base.toPack()]; counts.put(base,data[Bases.toPack(base)]);
// De-cumulatize data as necessary. // De-cumulatize data as necessary.
if(cumulative) { if(cumulative) {
for( int i = EnumSet.allOf(Base.class).size()-1; i > 0; i-- ) int previousCount = 0;
counts[i] -= counts[i-1]; for( byte base: Bases.instance ) {
int count = counts.get(base);
counts.put(base,count-previousCount);
previousCount = count;
}
} }
} }
@ -42,9 +48,11 @@ public class Counts implements Cloneable {
* @return Array of count values. * @return Array of count values.
*/ */
public int[] toArray(boolean cumulative) { public int[] toArray(boolean cumulative) {
int[] countArray = counts.clone(); int[] countArray = new int[counts.size()];
for(byte base: Bases.instance)
countArray[Bases.toPack(base)] = counts.get(base);
if(cumulative) { if(cumulative) {
for( int i = 1; i < counts.length; i++ ) for( int i = 1; i < countArray.length; i++ )
countArray[i] += countArray[i-1]; countArray[i] += countArray[i-1];
} }
return countArray; return countArray;
@ -62,8 +70,7 @@ public class Counts implements Cloneable {
catch(CloneNotSupportedException ex) { catch(CloneNotSupportedException ex) {
throw new StingException("Unable to clone counts object", ex); throw new StingException("Unable to clone counts object", ex);
} }
other.counts = new int[counts.length]; other.counts = new HashMap<Byte,Integer>(counts);
System.arraycopy(counts,0,other.counts,0,counts.length);
return other; return other;
} }
@ -71,8 +78,8 @@ public class Counts implements Cloneable {
* Increment the number of bases seen at the given location. * Increment the number of bases seen at the given location.
* @param base Base to increment. * @param base Base to increment.
*/ */
public void increment(Base base) { public void increment(byte base) {
counts[base.toPack()]++; counts.put(base,counts.get(base)+1);
} }
/** /**
@ -82,8 +89,8 @@ public class Counts implements Cloneable {
* @param base Base for which to query counts. * @param base Base for which to query counts.
* @return Number of bases of this type seen. * @return Number of bases of this type seen.
*/ */
public int get(Base base) { public int get(byte base) {
return counts[base.toPack()]; return counts.get(base);
} }
/** /**
@ -93,10 +100,12 @@ public class Counts implements Cloneable {
* @param base Base for which to query counts. * @param base Base for which to query counts.
* @return Number of bases of this type seen. * @return Number of bases of this type seen.
*/ */
public int getCumulative(Base base) { public int getCumulative(byte base) {
int accum = 0; int accum = 0;
for(int i = 0; i <= base.toPack(); i++) for( byte current: Bases.allOf() ) {
accum += counts[i]; if(base == current) break;
accum += counts.get(current);
}
return accum; return accum;
} }
@ -106,7 +115,7 @@ public class Counts implements Cloneable {
*/ */
public int getTotal() { public int getTotal() {
int accumulator = 0; int accumulator = 0;
for( int count : counts ) for( int count : counts.values() )
accumulator += count; accumulator += count;
return accumulator; return accumulator;
} }

View File

@ -61,7 +61,7 @@ public class CreateBWTFromReference {
private Counts countOccurrences( String sequence ) { private Counts countOccurrences( String sequence ) {
Counts occurrences = new Counts(); Counts occurrences = new Counts();
for( char base: sequence.toCharArray() ) for( char base: sequence.toCharArray() )
occurrences.increment(Base.fromASCII((byte)base)); occurrences.increment((byte)base);
return occurrences; return occurrences;
} }
@ -146,10 +146,10 @@ public class CreateBWTFromReference {
// Count the occurences of each given base. // Count the occurences of each given base.
Counts occurrences = creator.countOccurrences(sequence); Counts occurrences = creator.countOccurrences(sequence);
System.out.printf("Occurrences: a=%d, c=%d, g=%d, t=%d%n",occurrences.getCumulative(Base.A), System.out.printf("Occurrences: a=%d, c=%d, g=%d, t=%d%n",occurrences.getCumulative(Bases.A),
occurrences.getCumulative(Base.C), occurrences.getCumulative(Bases.C),
occurrences.getCumulative(Base.G), occurrences.getCumulative(Bases.G),
occurrences.getCumulative(Base.T)); occurrences.getCumulative(Bases.T));
// Generate the suffix array and print diagnostics. // Generate the suffix array and print diagnostics.
int[] suffixArrayData = creator.createSuffixArray(sequence); int[] suffixArrayData = creator.createSuffixArray(sequence);