diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java index 4ef255524..2593fc72e 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java @@ -62,54 +62,6 @@ public abstract class TraversalEngine,Provide } - /** - * Simple utility class that makes it convenient to print unit adjusted times - */ - private static class MyTime { - double t; // in Seconds - int precision; // for format - - public MyTime(double t, int precision) { - this.t = t; - this.precision = precision; - } - - public MyTime(double t) { - this(t, 1); - } - - /** - * Instead of 10000 s, returns 2.8 hours - * @return - */ - public String toString() { - double unitTime = t; - String unit = "s"; - - if ( t > 120 ) { - unitTime = t / 60; // minutes - unit = "m"; - - if ( unitTime > 120 ) { - unitTime /= 60; // hours - unit = "h"; - - if ( unitTime > 100 ) { - unitTime /= 24; // days - unit = "d"; - - if ( unitTime > 20 ) { - unitTime /= 7; // days - unit = "w"; - } - } - } - } - - return String.format("%6."+precision+"f %s", unitTime, unit); - } - } - /** lock object to sure updates to history are consistent across threads */ private static final Object lock = new Object(); LinkedList history = new LinkedList(); @@ -280,20 +232,20 @@ public abstract class TraversalEngine,Provide ProcessingHistory last = updateHistory(loc,cumulativeMetrics); - final MyTime elapsed = new MyTime(last.elapsedSeconds); - final MyTime bpRate = new MyTime(secondsPerMillionBP(last)); - final MyTime unitRate = new MyTime(secondsPerMillionElements(last)); + final AutoFormattingTime elapsed = new AutoFormattingTime(last.elapsedSeconds); + final AutoFormattingTime bpRate = new AutoFormattingTime(secondsPerMillionBP(last)); + final AutoFormattingTime unitRate = new AutoFormattingTime(secondsPerMillionElements(last)); final double fractionGenomeTargetCompleted = calculateFractionGenomeTargetCompleted(last); - final MyTime estTotalRuntime = new MyTime(elapsed.t / fractionGenomeTargetCompleted); - final MyTime timeToCompletion = new MyTime(estTotalRuntime.t - elapsed.t); + final AutoFormattingTime estTotalRuntime = new AutoFormattingTime(elapsed.getTimeInSeconds() / fractionGenomeTargetCompleted); + final AutoFormattingTime timeToCompletion = new AutoFormattingTime(estTotalRuntime.getTimeInSeconds() - elapsed.getTimeInSeconds()); if ( printProgress ) { lastProgressPrintTime = curTime; // dynamically change the update rate so that short running jobs receive frequent updates while longer jobs receive fewer updates - if ( estTotalRuntime.t > TWELVE_HOURS_IN_SECONDS ) + if ( estTotalRuntime.getTimeInSeconds() > TWELVE_HOURS_IN_SECONDS ) PROGRESS_PRINT_FREQUENCY = 60 * 1000; // in milliseconds - else if ( estTotalRuntime.t > TWO_HOURS_IN_SECONDS ) + else if ( estTotalRuntime.getTimeInSeconds() > TWO_HOURS_IN_SECONDS ) PROGRESS_PRINT_FREQUENCY = 30 * 1000; // in milliseconds else PROGRESS_PRINT_FREQUENCY = 10 * 1000; // in milliseconds @@ -308,8 +260,9 @@ public abstract class TraversalEngine,Provide lastPerformanceLogPrintTime = curTime; synchronized(performanceLogLock) { performanceLog.printf("%.2f\t%d\t%.2e\t%d\t%.2e\t%.2e\t%.2f\t%.2f%n", - elapsed.t, nRecords, unitRate.t, last.bpProcessed, bpRate.t, - fractionGenomeTargetCompleted, estTotalRuntime.t, timeToCompletion.t); + elapsed.getTimeInSeconds(), nRecords, unitRate.getTimeInSeconds(), last.bpProcessed, + bpRate.getTimeInSeconds(), fractionGenomeTargetCompleted, estTotalRuntime.getTimeInSeconds(), + timeToCompletion.getTimeInSeconds()); } } } diff --git a/public/java/src/org/broadinstitute/sting/utils/AutoFormattingTime.java b/public/java/src/org/broadinstitute/sting/utils/AutoFormattingTime.java new file mode 100644 index 000000000..8964c16cb --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/utils/AutoFormattingTime.java @@ -0,0 +1,53 @@ +package org.broadinstitute.sting.utils; + +/** + * Simple utility class that makes it convenient to print unit adjusted times + */ +public class AutoFormattingTime { + double timeInSeconds; // in Seconds + int precision; // for format + + public AutoFormattingTime(double timeInSeconds, int precision) { + this.timeInSeconds = timeInSeconds; + this.precision = precision; + } + + public AutoFormattingTime(double timeInSeconds) { + this(timeInSeconds, 1); + } + + public double getTimeInSeconds() { + return timeInSeconds; + } + + /** + * Instead of 10000 s, returns 2.8 hours + * @return + */ + public String toString() { + double unitTime = timeInSeconds; + String unit = "s"; + + if ( timeInSeconds > 120 ) { + unitTime = timeInSeconds / 60; // minutes + unit = "m"; + + if ( unitTime > 120 ) { + unitTime /= 60; // hours + unit = "h"; + + if ( unitTime > 100 ) { + unitTime /= 24; // days + unit = "d"; + + if ( unitTime > 20 ) { + unitTime /= 7; // days + unit = "w"; + } + } + } + } + + return String.format("%6."+precision+"f %s", unitTime, unit); + } +} diff --git a/public/java/src/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactory.java b/public/java/src/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactory.java index 1e0988bb7..39d5c1497 100644 --- a/public/java/src/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactory.java +++ b/public/java/src/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactory.java @@ -23,7 +23,11 @@ */ package org.broadinstitute.sting.utils.threading; +import com.google.java.contract.Ensures; +import com.google.java.contract.Invariant; import org.apache.log4j.Logger; +import org.apache.log4j.Priority; +import org.broadinstitute.sting.utils.AutoFormattingTime; import java.lang.management.ManagementFactory; import java.lang.management.ThreadInfo; @@ -36,7 +40,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadFactory; /** - * Create threads, collecting statistics about their running state over time + * Create activeThreads, collecting statistics about their running state over time * * Uses a ThreadMXBean to capture info via ThreadInfo * @@ -44,34 +48,91 @@ import java.util.concurrent.ThreadFactory; * Date: 8/14/12 * Time: 8:47 AM */ +@Invariant({ + "activeThreads.size() <= nThreadsToCreate", + "countDownLatch.getCount() <= nThreadsToCreate", + "nThreadsToCreated <= nThreadsToCreate" +}) public class StateMonitoringThreadFactory implements ThreadFactory { protected static final boolean DEBUG = false; private static Logger logger = Logger.getLogger(StateMonitoringThreadFactory.class); public static final List TRACKED_STATES = Arrays.asList(Thread.State.BLOCKED, Thread.State.RUNNABLE, Thread.State.WAITING); - final int threadsToCreate; - final List threads; + // todo -- it would be nice to not have to specify upfront the number of threads. + // todo -- can we dynamically increment countDownLatch? It seems not... + final int nThreadsToCreate; + final List activeThreads; final EnumMap times = new EnumMap(Thread.State.class); + + int nThreadsToCreated = 0; + + /** + * The bean used to get the thread info about blocked and waiting times + */ final ThreadMXBean bean; - final CountDownLatch activeThreads; - public StateMonitoringThreadFactory(final int threadsToCreate) { - if ( threadsToCreate <= 0 ) throw new IllegalArgumentException("threadsToCreate <= 0: " + threadsToCreate); + /** + * Counts down the number of active activeThreads whose runtime info hasn't been incorporated into + * times. Counts down from nThreadsToCreate to 0, at which point any code waiting + * on the final times is freed to run. + */ + final CountDownLatch countDownLatch; - this.threadsToCreate = threadsToCreate; - threads = new ArrayList(threadsToCreate); - for ( final Thread.State state : Thread.State.values() ) - times.put(state, 0l); - bean = ManagementFactory.getThreadMXBean(); - bean.setThreadContentionMonitoringEnabled(true); - bean.setThreadCpuTimeEnabled(true); - activeThreads = new CountDownLatch(threadsToCreate); + /** + * Instead of RUNNABLE we want to print running. This map goes from Thread.State names to human readable ones + */ + final static EnumMap PRETTY_NAMES = new EnumMap(Thread.State.class); + static { + PRETTY_NAMES.put(Thread.State.RUNNABLE, "running"); + PRETTY_NAMES.put(Thread.State.BLOCKED, "blocked"); + PRETTY_NAMES.put(Thread.State.WAITING, "waiting"); } + /** + * Create a new factory generating threads whose runtime and contention + * behavior is tracked in this factory. + * + * @param nThreadsToCreate the number of threads we will create in the factory before it's considered complete + * // TODO -- remove argument when we figure out how to implement this capability + */ + public StateMonitoringThreadFactory(final int nThreadsToCreate) { + if ( nThreadsToCreate <= 0 ) throw new IllegalArgumentException("nThreadsToCreate <= 0: " + nThreadsToCreate); + + this.nThreadsToCreate = nThreadsToCreate; + activeThreads = new ArrayList(nThreadsToCreate); + + // initialize times to 0 + for ( final Thread.State state : Thread.State.values() ) + times.put(state, 0l); + + // get the bean, and start tracking + bean = ManagementFactory.getThreadMXBean(); + if ( bean.isThreadContentionMonitoringSupported() ) + bean.setThreadContentionMonitoringEnabled(true); + else + logger.warn("Thread contention monitoring not supported, we cannot track GATK multi-threaded efficiency"); + //bean.setThreadCpuTimeEnabled(true); + + countDownLatch = new CountDownLatch(nThreadsToCreate); + } + + /** + * Get the time spent in state across all threads created by this factory + * + * @param state on of the TRACKED_STATES + * @return the time in milliseconds + */ + @Ensures({"result >= 0", "TRACKED_STATES.contains(state)"}) public synchronized long getStateTime(final Thread.State state) { return times.get(state); } + /** + * Get the total time spent in all states across all threads created by this factory + * + * @return the time in milliseconds + */ + @Ensures({"result >= 0"}) public synchronized long getTotalTime() { long total = 0; for ( final long time : times.values() ) @@ -79,16 +140,27 @@ public class StateMonitoringThreadFactory implements ThreadFactory { return total; } + /** + * Get the fraction of time spent in state across all threads created by this factory + * + * @return the fraction (0.0-1.0) of time spent in state over all state times of all threads + */ + @Ensures({"result >= 0.0", "result <= 1.0", "TRACKED_STATES.contains(state)"}) public synchronized double getStateFraction(final Thread.State state) { - return getStateTime(state) / (1.0 * getTotalTime()); + return getStateTime(state) / (1.0 * Math.max(getTotalTime(), 1)); } - public int getNThreads() { - return threads.size(); + /** + * How many threads have been created by this factory so far? + * @return + */ + @Ensures("result >= 0") + public int getNThreadsCreated() { + return nThreadsToCreated; } public void waitForAllThreadsToComplete() throws InterruptedException { - activeThreads.await(); + countDownLatch.await(); } @Override @@ -103,33 +175,108 @@ public class StateMonitoringThreadFactory implements ThreadFactory { return b.toString(); } - @Override - public synchronized Thread newThread(final Runnable runnable) { - if ( threads.size() >= threadsToCreate ) - throw new IllegalStateException("Attempting to create more threads than allowed by constructor argument threadsToCreate " + threadsToCreate); + /** + * Print usage information about threads from this factory to logger + * with the INFO priority + * + * @param logger + */ + public synchronized void printUsageInformation(final Logger logger) { + printUsageInformation(logger, Priority.INFO); + } + /** + * Print usage information about threads from this factory to logger + * with the provided priority + * + * @param logger + */ + public synchronized void printUsageInformation(final Logger logger, final Priority priority) { + logger.log(priority, "Number of activeThreads used: " + getNThreadsCreated()); + logger.log(priority, "Total runtime " + new AutoFormattingTime(getTotalTime() / 1000.0)); + for ( final Thread.State state : TRACKED_STATES ) { + logger.log(priority, String.format(" Fraction of time spent %s is %.2f (%s)", + prettyName(state), getStateFraction(state), new AutoFormattingTime(getStateTime(state) / 1000.0))); + } + logger.log(priority, String.format("Efficiency of multi-threading: %.2f%% of time spent doing productive work", + getStateFraction(Thread.State.RUNNABLE) * 100)); + } + + private String prettyName(final Thread.State state) { + return PRETTY_NAMES.get(state); + } + + /** + * Create a new thread from this factory + * + * @param runnable + * @return + */ + @Override + @Ensures({ + "activeThreads.size() > old(activeThreads.size())", + "activeThreads.contains(result)", + "nThreadsToCreated == old(nThreadsToCreated) + 1" + }) + public synchronized Thread newThread(final Runnable runnable) { + if ( activeThreads.size() >= nThreadsToCreate) + throw new IllegalStateException("Attempting to create more activeThreads than allowed by constructor argument nThreadsToCreate " + nThreadsToCreate); + + nThreadsToCreated++; final Thread myThread = new TrackingThread(runnable); - threads.add(myThread); + activeThreads.add(myThread); return myThread; } - // TODO -- add polling capability - - private synchronized void updateThreadInfo(final Thread thread, final long runtime) { + /** + * Update the information about completed thread that ran for runtime in milliseconds + * + * This method updates all of the key timing and tracking information in the factory so that + * thread can be retired. After this call the factory shouldn't have a pointer to the thread any longer + * + * @param thread + * @param runtimeInMilliseconds + */ + @Ensures({ + "activeThreads.size() < old(activeThreads.size())", + "! activeThreads.contains(thread)", + "getTotalTime() >= old(getTotalTime())", + "countDownLatch.getCount() < old(countDownLatch.getCount())" + }) + private synchronized void threadIsDone(final Thread thread, final long runtimeInMilliseconds) { + if ( DEBUG ) logger.warn(" Countdown " + countDownLatch.getCount() + " in thread " + Thread.currentThread().getName()); if ( DEBUG ) logger.warn("UpdateThreadInfo called"); + final ThreadInfo info = bean.getThreadInfo(thread.getId()); if ( info != null ) { - if ( DEBUG ) logger.warn("Updating thread total runtime " + runtime + " of which blocked " + info.getBlockedTime() + " and waiting " + info.getWaitedTime()); + if ( DEBUG ) logger.warn("Updating thread total runtime " + runtimeInMilliseconds + " of which blocked " + info.getBlockedTime() + " and waiting " + info.getWaitedTime()); incTimes(Thread.State.BLOCKED, info.getBlockedTime()); incTimes(Thread.State.WAITING, info.getWaitedTime()); - incTimes(Thread.State.RUNNABLE, runtime - info.getWaitedTime() - info.getBlockedTime()); + incTimes(Thread.State.RUNNABLE, runtimeInMilliseconds - info.getWaitedTime() - info.getBlockedTime()); } + + // remove the thread from the list of active activeThreads + if ( ! activeThreads.remove(thread) ) + throw new IllegalStateException("Thread " + thread + " not in list of active activeThreads"); + + // one less thread is live for those blocking on all activeThreads to be complete + countDownLatch.countDown(); + if ( DEBUG ) logger.warn(" -> Countdown " + countDownLatch.getCount() + " in thread " + Thread.currentThread().getName()); } + /** + * Helper function that increments the times counter by by for state + * + * @param state + * @param by + */ private synchronized void incTimes(final Thread.State state, final long by) { times.put(state, times.get(state) + by); } + /** + * A wrapper around Thread that tracks the runtime of the thread and calls threadIsDone() when complete + */ private class TrackingThread extends Thread { private TrackingThread(Runnable runnable) { super(runnable); @@ -140,10 +287,7 @@ public class StateMonitoringThreadFactory implements ThreadFactory { final long startTime = System.currentTimeMillis(); super.run(); final long endTime = System.currentTimeMillis(); - if ( DEBUG ) logger.warn(" Countdown " + activeThreads.getCount() + " in thread " + Thread.currentThread().getName()); - updateThreadInfo(this, endTime - startTime); - activeThreads.countDown(); - if ( DEBUG ) logger.warn(" -> Countdown " + activeThreads.getCount() + " in thread " + Thread.currentThread().getName()); + threadIsDone(this, endTime - startTime); } } } diff --git a/public/java/src/org/broadinstitute/sting/utils/threading/package-info.java b/public/java/src/org/broadinstitute/sting/utils/threading/package-info.java index dc350920e..d72dad471 100644 --- a/public/java/src/org/broadinstitute/sting/utils/threading/package-info.java +++ b/public/java/src/org/broadinstitute/sting/utils/threading/package-info.java @@ -1,4 +1,4 @@ /** - * Provides tools for managing threads, thread pools, and parallelization in general. + * Provides tools for managing activeThreads, thread pools, and parallelization in general. */ package org.broadinstitute.sting.utils.threading; diff --git a/public/java/test/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactoryUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactoryUnitTest.java index 6fc852bbf..c22b49c23 100755 --- a/public/java/test/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactoryUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactoryUnitTest.java @@ -23,6 +23,7 @@ */ package org.broadinstitute.sting.utils.threading; +import org.apache.log4j.Priority; import org.broadinstitute.sting.BaseTest; import org.broadinstitute.sting.utils.Utils; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; @@ -30,7 +31,6 @@ import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -40,6 +40,7 @@ import java.util.concurrent.*; * Tests for the state monitoring thread factory. */ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { + // the duration of the tests -- 100 ms is tolerable given the number of tests we are doing private final static long THREAD_TARGET_DURATION_IN_MILLISECOND = 100; final static Object GLOBAL_LOCK = new Object(); @@ -68,10 +69,16 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { } } + /** + * Test helper threading class that puts the thread into RUNNING, BLOCKED, or WAITING state as + * requested for input argument + */ private static class StateTestThread implements Callable { private final Thread.State stateToImplement; private StateTestThread(final Thread.State stateToImplement) { + if ( ! StateMonitoringThreadFactory.TRACKED_STATES.contains(stateToImplement) ) + throw new IllegalArgumentException("Unexpected state " + stateToImplement); this.stateToImplement = stateToImplement; } @@ -92,6 +99,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { case BLOCKED: if ( StateMonitoringThreadFactory.DEBUG ) logger.warn("Blocking..."); synchronized (GLOBAL_LOCK) { + // the GLOBAL_LOCK must be held by the unit test itself for this to properly block if ( StateMonitoringThreadFactory.DEBUG ) logger.warn(" ... done blocking"); } return 0.0; @@ -103,7 +111,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { @DataProvider(name = "StateTest") public Object[][] createStateTest() { - for ( final int nThreads : Arrays.asList(1, 2, 3, 4, 5) ) { + for ( final int nThreads : Arrays.asList(1, 2, 3, 4) ) { for (final List states : Utils.makeCombinations(StateMonitoringThreadFactory.TRACKED_STATES, nThreads) ) { //if ( Collections.frequency(states, Thread.State.BLOCKED) > 0) new StateTest(states); @@ -125,7 +133,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { for ( final Thread.State threadToRunState : test.getStatesForThreads() ) threadPool.submit(new StateTestThread(threadToRunState)); - // lock has to be here for the whole running of the threads but end before the sleep so the blocked threads + // lock has to be here for the whole running of the activeThreads but end before the sleep so the blocked activeThreads // can block for their allotted time threadPool.shutdown(); Thread.sleep(THREAD_TARGET_DURATION_IN_MILLISECOND); @@ -133,29 +141,35 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { //logger.warn(" Releasing lock"); threadPool.awaitTermination(10, TimeUnit.SECONDS); //logger.warn(" done awaiting termination"); - //logger.warn(" waiting for all threads to complete"); + //logger.warn(" waiting for all activeThreads to complete"); factory.waitForAllThreadsToComplete(); - //logger.warn(" done waiting for threads"); + //logger.warn(" done waiting for activeThreads"); // make sure we counted everything properly final long totalTime = factory.getTotalTime(); final long minTime = (THREAD_TARGET_DURATION_IN_MILLISECOND - 10) * test.getNStates(); + final long maxTime = (THREAD_TARGET_DURATION_IN_MILLISECOND + 10) * test.getNStates(); //logger.warn("Testing total time"); Assert.assertTrue(totalTime >= minTime, "Factory results not properly accumulated: totalTime = " + totalTime + " < minTime = " + minTime); + Assert.assertTrue(totalTime <= maxTime, "Factory results not properly accumulated: totalTime = " + totalTime + " > maxTime = " + maxTime); for (final Thread.State state : StateMonitoringThreadFactory.TRACKED_STATES ) { final double min = test.minStateFraction(state); final double max = test.maxStateFraction(state); final double obs = factory.getStateFraction(state); - logger.warn(" Checking " + state - + " min " + String.format("%.2f", min) - + " max " + String.format("%.2f", max) - + " obs " + String.format("%.2f", obs) - + " factor = " + factory); +// logger.warn(" Checking " + state +// + " min " + String.format("%.2f", min) +// + " max " + String.format("%.2f", max) +// + " obs " + String.format("%.2f", obs) +// + " factor = " + factory); Assert.assertTrue(obs >= min, "Too little time spent in state " + state + " obs " + obs + " min " + min); Assert.assertTrue(obs <= max, "Too much time spent in state " + state + " obs " + obs + " max " + min); } - Assert.assertEquals(factory.getNThreads(), test.getNStates()); + // we actually ran the expected number of activeThreads + Assert.assertEquals(factory.getNThreadsCreated(), test.getNStates()); + + // should be called to ensure we don't format / NPE on output + factory.printUsageInformation(logger, Priority.INFO); } } \ No newline at end of file