Clean, documented implementation of ThreadFactory that monitors running / blocking / waiting time of threads it creates

-- Expanded unit tests
-- Support for clean logging of results to logger
-- Refactored MyTime into AutoFormattingTime in Utils, out of TraversalEngine, for cleanliness and reuse
-- Added docs and contracts to StateMonitoringThreadFactory
This commit is contained in:
Mark DePristo 2012-08-14 16:27:30 -04:00
parent be3230a1fd
commit 9459e6203a
5 changed files with 265 additions and 101 deletions

View File

@ -62,54 +62,6 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
}
/**
 * Tiny value holder that prints a duration (given in seconds) in the most readable unit
 */
private static class MyTime {
    double t;       // the duration, in seconds
    int precision;  // number of digits printed after the decimal point

    public MyTime(double t) {
        this(t, 1);
    }

    public MyTime(double t, int precision) {
        this.t = t;
        this.precision = precision;
    }

    /**
     * Render the duration scaled to seconds, minutes, hours, days or weeks;
     * for example 10000 s is printed as 2.8 h
     *
     * @return the scaled value followed by its one-letter unit (s, m, h, d or w)
     */
    public String toString() {
        // move up to the next unit only while the value exceeds that step's cutoff
        final double[] cutoffs  = { 120, 120, 100, 20 };
        final double[] divisors = {  60,  60,  24,  7 };
        final String[] units    = { "m", "h", "d", "w" };

        double scaled = t;
        String suffix = "s";
        for ( int step = 0; step < units.length && scaled > cutoffs[step]; step++ ) {
            scaled /= divisors[step];
            suffix = units[step];
        }
        return String.format("%6."+precision+"f %s", scaled, suffix);
    }
}
/** lock object to ensure updates to history are consistent across threads */
private static final Object lock = new Object();
LinkedList<ProcessingHistory> history = new LinkedList<ProcessingHistory>();
@ -280,20 +232,20 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
ProcessingHistory last = updateHistory(loc,cumulativeMetrics);
final MyTime elapsed = new MyTime(last.elapsedSeconds);
final MyTime bpRate = new MyTime(secondsPerMillionBP(last));
final MyTime unitRate = new MyTime(secondsPerMillionElements(last));
final AutoFormattingTime elapsed = new AutoFormattingTime(last.elapsedSeconds);
final AutoFormattingTime bpRate = new AutoFormattingTime(secondsPerMillionBP(last));
final AutoFormattingTime unitRate = new AutoFormattingTime(secondsPerMillionElements(last));
final double fractionGenomeTargetCompleted = calculateFractionGenomeTargetCompleted(last);
final MyTime estTotalRuntime = new MyTime(elapsed.t / fractionGenomeTargetCompleted);
final MyTime timeToCompletion = new MyTime(estTotalRuntime.t - elapsed.t);
final AutoFormattingTime estTotalRuntime = new AutoFormattingTime(elapsed.getTimeInSeconds() / fractionGenomeTargetCompleted);
final AutoFormattingTime timeToCompletion = new AutoFormattingTime(estTotalRuntime.getTimeInSeconds() - elapsed.getTimeInSeconds());
if ( printProgress ) {
lastProgressPrintTime = curTime;
// dynamically change the update rate so that short running jobs receive frequent updates while longer jobs receive fewer updates
if ( estTotalRuntime.t > TWELVE_HOURS_IN_SECONDS )
if ( estTotalRuntime.getTimeInSeconds() > TWELVE_HOURS_IN_SECONDS )
PROGRESS_PRINT_FREQUENCY = 60 * 1000; // in milliseconds
else if ( estTotalRuntime.t > TWO_HOURS_IN_SECONDS )
else if ( estTotalRuntime.getTimeInSeconds() > TWO_HOURS_IN_SECONDS )
PROGRESS_PRINT_FREQUENCY = 30 * 1000; // in milliseconds
else
PROGRESS_PRINT_FREQUENCY = 10 * 1000; // in milliseconds
@ -308,8 +260,9 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
lastPerformanceLogPrintTime = curTime;
synchronized(performanceLogLock) {
performanceLog.printf("%.2f\t%d\t%.2e\t%d\t%.2e\t%.2e\t%.2f\t%.2f%n",
elapsed.t, nRecords, unitRate.t, last.bpProcessed, bpRate.t,
fractionGenomeTargetCompleted, estTotalRuntime.t, timeToCompletion.t);
elapsed.getTimeInSeconds(), nRecords, unitRate.getTimeInSeconds(), last.bpProcessed,
bpRate.getTimeInSeconds(), fractionGenomeTargetCompleted, estTotalRuntime.getTimeInSeconds(),
timeToCompletion.getTimeInSeconds());
}
}
}

View File

@ -0,0 +1,53 @@
package org.broadinstitute.sting.utils;
/**
 * Simple utility class that makes it convenient to print unit adjusted times
 *
 * The duration is held in seconds. toString() scales it to seconds (s), minutes (m),
 * hours (h), days (d) or weeks (w), moving to the next larger unit once the magnitude
 * exceeds 120 s, 120 m, 100 h and 20 d respectively.
 */
public class AutoFormattingTime {
    double timeInSeconds; // in Seconds
    int precision;        // number of digits after the decimal point used by toString()

    /**
     * @param timeInSeconds the duration to represent, in seconds
     * @param precision     the number of digits to print after the decimal point, must be >= 0
     * @throws IllegalArgumentException if precision is negative
     */
    public AutoFormattingTime(double timeInSeconds, int precision) {
        // fail here rather than with an obscure format exception the first time toString() is called
        if ( precision < 0 )
            throw new IllegalArgumentException("precision must be >= 0 but got " + precision);
        this.timeInSeconds = timeInSeconds;
        this.precision = precision;
    }

    /**
     * Create a time that prints with a single digit after the decimal point
     *
     * @param timeInSeconds the duration to represent, in seconds
     */
    public AutoFormattingTime(double timeInSeconds) {
        this(timeInSeconds, 1);
    }

    /**
     * @return the unscaled duration, in seconds
     */
    public double getTimeInSeconds() {
        return timeInSeconds;
    }

    /**
     * Instead of 10000 s, returns 2.8 h
     *
     * The unit is chosen from the magnitude of the duration, so negative durations are
     * scaled the same way as positive ones (-10000 s prints as -2.8 h).
     *
     * @return the scaled duration, right-aligned in a field of width 6, followed by a space and its unit
     */
    @Override
    public String toString() {
        double unitTime = timeInSeconds;
        String unit = "s";

        if ( Math.abs(unitTime) > 120 ) {
            unitTime /= 60; // minutes
            unit = "m";
            if ( Math.abs(unitTime) > 120 ) {
                unitTime /= 60; // hours
                unit = "h";
                if ( Math.abs(unitTime) > 100 ) {
                    unitTime /= 24; // days
                    unit = "d";
                    if ( Math.abs(unitTime) > 20 ) {
                        unitTime /= 7; // weeks
                        unit = "w";
                    }
                }
            }
        }

        return String.format("%6."+precision+"f %s", unitTime, unit);
    }
}

View File

@ -23,7 +23,11 @@
*/
package org.broadinstitute.sting.utils.threading;
import com.google.java.contract.Ensures;
import com.google.java.contract.Invariant;
import org.apache.log4j.Logger;
import org.apache.log4j.Priority;
import org.broadinstitute.sting.utils.AutoFormattingTime;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
@ -36,7 +40,7 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadFactory;
/**
* Create threads, collecting statistics about their running state over time
* Create threads, collecting statistics about their running state over time
*
* Uses a ThreadMXBean to capture info via ThreadInfo
*
@ -44,34 +48,91 @@ import java.util.concurrent.ThreadFactory;
* Date: 8/14/12
* Time: 8:47 AM
*/
@Invariant({
"activeThreads.size() <= nThreadsToCreate",
"countDownLatch.getCount() <= nThreadsToCreate",
"nThreadsToCreated <= nThreadsToCreate"
})
public class StateMonitoringThreadFactory implements ThreadFactory {
protected static final boolean DEBUG = false;
private static Logger logger = Logger.getLogger(StateMonitoringThreadFactory.class);
public static final List<Thread.State> TRACKED_STATES = Arrays.asList(Thread.State.BLOCKED, Thread.State.RUNNABLE, Thread.State.WAITING);
final int threadsToCreate;
final List<Thread> threads;
// todo -- it would be nice to not have to specify upfront the number of threads.
// todo -- can we dynamically increment countDownLatch? It seems not...
final int nThreadsToCreate;
final List<Thread> activeThreads;
final EnumMap<Thread.State, Long> times = new EnumMap<Thread.State, Long>(Thread.State.class);
int nThreadsToCreated = 0;
/**
* The bean used to get the thread info about blocked and waiting times
*/
final ThreadMXBean bean;
final CountDownLatch activeThreads;
public StateMonitoringThreadFactory(final int threadsToCreate) {
if ( threadsToCreate <= 0 ) throw new IllegalArgumentException("threadsToCreate <= 0: " + threadsToCreate);
/**
* Counts down the number of active threads whose runtime info hasn't been incorporated into
* times. Counts down from nThreadsToCreate to 0, at which point any code waiting
* on the final times is freed to run.
*/
final CountDownLatch countDownLatch;
this.threadsToCreate = threadsToCreate;
threads = new ArrayList<Thread>(threadsToCreate);
for ( final Thread.State state : Thread.State.values() )
times.put(state, 0l);
bean = ManagementFactory.getThreadMXBean();
bean.setThreadContentionMonitoringEnabled(true);
bean.setThreadCpuTimeEnabled(true);
activeThreads = new CountDownLatch(threadsToCreate);
/**
* Instead of RUNNABLE we want to print running. This map goes from Thread.State names to human readable ones
*/
final static EnumMap<Thread.State, String> PRETTY_NAMES = new EnumMap<Thread.State, String>(Thread.State.class);
static {
PRETTY_NAMES.put(Thread.State.RUNNABLE, "running");
PRETTY_NAMES.put(Thread.State.BLOCKED, "blocked");
PRETTY_NAMES.put(Thread.State.WAITING, "waiting");
}
/**
 * Build a factory whose threads have their runtime and contention behavior
 * accumulated in this factory.
 *
 * @param nThreadsToCreate how many threads will be requested from this factory before it is considered complete
 *                         // TODO -- remove argument when we figure out how to implement this capability
 * @throws IllegalArgumentException if nThreadsToCreate is not positive
 */
public StateMonitoringThreadFactory(final int nThreadsToCreate) {
    if ( nThreadsToCreate <= 0 ) {
        throw new IllegalArgumentException("nThreadsToCreate <= 0: " + nThreadsToCreate);
    }

    this.nThreadsToCreate = nThreadsToCreate;
    this.activeThreads = new ArrayList<Thread>(nThreadsToCreate);
    this.countDownLatch = new CountDownLatch(nThreadsToCreate);

    // every state starts with zero accumulated time
    for ( final Thread.State threadState : Thread.State.values() ) {
        times.put(threadState, 0L);
    }

    // fetch the management bean and turn on contention tracking when the JVM offers it
    this.bean = ManagementFactory.getThreadMXBean();
    if ( ! bean.isThreadContentionMonitoringSupported() ) {
        logger.warn("Thread contention monitoring not supported, we cannot track GATK multi-threaded efficiency");
    } else {
        bean.setThreadContentionMonitoringEnabled(true);
    }
    //bean.setThreadCpuTimeEnabled(true);
}
/**
 * Look up the time accumulated in state, summed over the threads recorded by this factory
 *
 * @param state one of the TRACKED_STATES
 * @return the time in milliseconds
 */
@Ensures({"result >= 0", "TRACKED_STATES.contains(state)"})
public synchronized long getStateTime(final Thread.State state) {
    final Long accumulated = times.get(state);
    return accumulated.longValue();
}
/**
* Get the total time spent in all states across all threads created by this factory
*
* @return the time in milliseconds
*/
@Ensures({"result >= 0"})
public synchronized long getTotalTime() {
long total = 0;
for ( final long time : times.values() )
@ -79,16 +140,27 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
return total;
}
/**
* Get the fraction of time spent in state across all threads created by this factory
*
* @return the fraction (0.0-1.0) of time spent in state over all state times of all threads
*/
@Ensures({"result >= 0.0", "result <= 1.0", "TRACKED_STATES.contains(state)"})
public synchronized double getStateFraction(final Thread.State state) {
return getStateTime(state) / (1.0 * getTotalTime());
return getStateTime(state) / (1.0 * Math.max(getTotalTime(), 1));
}
public int getNThreads() {
return threads.size();
/**
* How many threads have been created by this factory so far?
* @return
*/
@Ensures("result >= 0")
public int getNThreadsCreated() {
return nThreadsToCreated;
}
public void waitForAllThreadsToComplete() throws InterruptedException {
activeThreads.await();
countDownLatch.await();
}
@Override
@ -103,33 +175,108 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
return b.toString();
}
@Override
public synchronized Thread newThread(final Runnable runnable) {
if ( threads.size() >= threadsToCreate )
throw new IllegalStateException("Attempting to create more threads than allowed by constructor argument threadsToCreate " + threadsToCreate);
/**
 * Convenience overload: write the usage information for threads from this
 * factory to logger using the INFO priority
 *
 * @param logger the logger that receives the usage report
 */
public synchronized void printUsageInformation(final Logger logger) {
    final Priority defaultPriority = Priority.INFO;
    printUsageInformation(logger, defaultPriority);
}
/**
 * Print usage information about threads from this factory to logger
 * with the provided priority
 *
 * Reports the number of threads created, the total time, the fraction of time and
 * absolute time spent in each of the TRACKED_STATES, and the percentage of time spent running.
 *
 * @param logger   the logger that receives the usage report
 * @param priority the priority each line of the report is logged with
 */
public synchronized void printUsageInformation(final Logger logger, final Priority priority) {
    logger.log(priority, "Number of threads used: " + getNThreadsCreated());
    // state times are tracked in milliseconds, AutoFormattingTime expects seconds
    logger.log(priority, "Total runtime " + new AutoFormattingTime(getTotalTime() / 1000.0));
    for ( final Thread.State state : TRACKED_STATES ) {
        logger.log(priority, String.format(" Fraction of time spent %s is %.2f (%s)",
                prettyName(state), getStateFraction(state), new AutoFormattingTime(getStateTime(state) / 1000.0)));
    }
    logger.log(priority, String.format("Efficiency of multi-threading: %.2f%% of time spent doing productive work",
            getStateFraction(Thread.State.RUNNABLE) * 100));
}
/**
 * Translate a thread state into the human readable label registered in PRETTY_NAMES
 *
 * @param state the state to translate
 * @return the label for state, or null when PRETTY_NAMES has no entry for it
 */
private String prettyName(final Thread.State state) {
    final String label = PRETTY_NAMES.get(state);
    return label;
}
/**
 * Create a new thread from this factory
 *
 * The thread is registered as active and counted against the nThreadsToCreate limit
 * given to the constructor. The limit applies to the total number of threads ever
 * created, not to the number currently active, since retired threads are removed from
 * activeThreads and each one has already counted down the completion latch.
 *
 * @param runnable the work the new thread will execute
 * @return a new, not yet started, thread whose state times are tracked by this factory
 * @throws IllegalStateException if nThreadsToCreate threads have already been created
 */
@Override
@Ensures({
        "activeThreads.size() > old(activeThreads.size())",
        "activeThreads.contains(result)",
        "nThreadsToCreated == old(nThreadsToCreated) + 1"
})
public synchronized Thread newThread(final Runnable runnable) {
    // compare against the number created so far: activeThreads shrinks as threads finish
    if ( nThreadsToCreated >= nThreadsToCreate )
        throw new IllegalStateException("Attempting to create more threads than allowed by constructor argument nThreadsToCreate " + nThreadsToCreate);

    final Thread myThread = new TrackingThread(runnable);
    activeThreads.add(myThread);
    nThreadsToCreated++;
    return myThread;
}
// TODO -- add polling capability
private synchronized void updateThreadInfo(final Thread thread, final long runtime) {
/**
* Update the information about completed thread that ran for runtime in milliseconds
*
* This method updates all of the key timing and tracking information in the factory so that
* thread can be retired. After this call the factory shouldn't have a pointer to the thread any longer
*
* @param thread
* @param runtimeInMilliseconds
*/
@Ensures({
"activeThreads.size() < old(activeThreads.size())",
"! activeThreads.contains(thread)",
"getTotalTime() >= old(getTotalTime())",
"countDownLatch.getCount() < old(countDownLatch.getCount())"
})
private synchronized void threadIsDone(final Thread thread, final long runtimeInMilliseconds) {
if ( DEBUG ) logger.warn(" Countdown " + countDownLatch.getCount() + " in thread " + Thread.currentThread().getName());
if ( DEBUG ) logger.warn("UpdateThreadInfo called");
final ThreadInfo info = bean.getThreadInfo(thread.getId());
if ( info != null ) {
if ( DEBUG ) logger.warn("Updating thread total runtime " + runtime + " of which blocked " + info.getBlockedTime() + " and waiting " + info.getWaitedTime());
if ( DEBUG ) logger.warn("Updating thread total runtime " + runtimeInMilliseconds + " of which blocked " + info.getBlockedTime() + " and waiting " + info.getWaitedTime());
incTimes(Thread.State.BLOCKED, info.getBlockedTime());
incTimes(Thread.State.WAITING, info.getWaitedTime());
incTimes(Thread.State.RUNNABLE, runtime - info.getWaitedTime() - info.getBlockedTime());
incTimes(Thread.State.RUNNABLE, runtimeInMilliseconds - info.getWaitedTime() - info.getBlockedTime());
}
// remove the thread from the list of active activeThreads
if ( ! activeThreads.remove(thread) )
throw new IllegalStateException("Thread " + thread + " not in list of active activeThreads");
// one less thread is live for those blocking on all activeThreads to be complete
countDownLatch.countDown();
if ( DEBUG ) logger.warn(" -> Countdown " + countDownLatch.getCount() + " in thread " + Thread.currentThread().getName());
}
/**
 * Add by to the time accumulated so far for state
 *
 * @param state the state whose counter is increased
 * @param by the amount added to the counter
 */
private synchronized void incTimes(final Thread.State state, final long by) {
    final long soFar = times.get(state);
    times.put(state, soFar + by);
}
/**
* A wrapper around Thread that tracks the runtime of the thread and calls threadIsDone() when complete
*/
private class TrackingThread extends Thread {
private TrackingThread(Runnable runnable) {
super(runnable);
@ -140,10 +287,7 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
final long startTime = System.currentTimeMillis();
super.run();
final long endTime = System.currentTimeMillis();
if ( DEBUG ) logger.warn(" Countdown " + activeThreads.getCount() + " in thread " + Thread.currentThread().getName());
updateThreadInfo(this, endTime - startTime);
activeThreads.countDown();
if ( DEBUG ) logger.warn(" -> Countdown " + activeThreads.getCount() + " in thread " + Thread.currentThread().getName());
threadIsDone(this, endTime - startTime);
}
}
}

View File

@ -1,4 +1,4 @@
/**
* Provides tools for managing threads, thread pools, and parallelization in general.
* Provides tools for managing threads, thread pools, and parallelization in general.
*/
package org.broadinstitute.sting.utils.threading;

View File

@ -23,6 +23,7 @@
*/
package org.broadinstitute.sting.utils.threading;
import org.apache.log4j.Priority;
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
@ -30,7 +31,6 @@ import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -40,6 +40,7 @@ import java.util.concurrent.*;
* Tests for the state monitoring thread factory.
*/
public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
// the duration of the tests -- 100 ms is tolerable given the number of tests we are doing
private final static long THREAD_TARGET_DURATION_IN_MILLISECOND = 100;
final static Object GLOBAL_LOCK = new Object();
@ -68,10 +69,16 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
}
}
/**
* Test helper threading class that puts the thread into RUNNING, BLOCKED, or WAITING state as
* requested for input argument
*/
private static class StateTestThread implements Callable<Double> {
private final Thread.State stateToImplement;
private StateTestThread(final Thread.State stateToImplement) {
if ( ! StateMonitoringThreadFactory.TRACKED_STATES.contains(stateToImplement) )
throw new IllegalArgumentException("Unexpected state " + stateToImplement);
this.stateToImplement = stateToImplement;
}
@ -92,6 +99,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
case BLOCKED:
if ( StateMonitoringThreadFactory.DEBUG ) logger.warn("Blocking...");
synchronized (GLOBAL_LOCK) {
// the GLOBAL_LOCK must be held by the unit test itself for this to properly block
if ( StateMonitoringThreadFactory.DEBUG ) logger.warn(" ... done blocking");
}
return 0.0;
@ -103,7 +111,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
@DataProvider(name = "StateTest")
public Object[][] createStateTest() {
for ( final int nThreads : Arrays.asList(1, 2, 3, 4, 5) ) {
for ( final int nThreads : Arrays.asList(1, 2, 3, 4) ) {
for (final List<Thread.State> states : Utils.makeCombinations(StateMonitoringThreadFactory.TRACKED_STATES, nThreads) ) {
//if ( Collections.frequency(states, Thread.State.BLOCKED) > 0)
new StateTest(states);
@ -125,7 +133,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
for ( final Thread.State threadToRunState : test.getStatesForThreads() )
threadPool.submit(new StateTestThread(threadToRunState));
// lock has to be here for the whole running of the threads but end before the sleep so the blocked threads
// lock has to be here for the whole running of the activeThreads but end before the sleep so the blocked activeThreads
// can block for their allotted time
threadPool.shutdown();
Thread.sleep(THREAD_TARGET_DURATION_IN_MILLISECOND);
@ -133,29 +141,35 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
//logger.warn(" Releasing lock");
threadPool.awaitTermination(10, TimeUnit.SECONDS);
//logger.warn(" done awaiting termination");
//logger.warn(" waiting for all threads to complete");
//logger.warn(" waiting for all activeThreads to complete");
factory.waitForAllThreadsToComplete();
//logger.warn(" done waiting for threads");
//logger.warn(" done waiting for activeThreads");
// make sure we counted everything properly
final long totalTime = factory.getTotalTime();
final long minTime = (THREAD_TARGET_DURATION_IN_MILLISECOND - 10) * test.getNStates();
final long maxTime = (THREAD_TARGET_DURATION_IN_MILLISECOND + 10) * test.getNStates();
//logger.warn("Testing total time");
Assert.assertTrue(totalTime >= minTime, "Factory results not properly accumulated: totalTime = " + totalTime + " < minTime = " + minTime);
Assert.assertTrue(totalTime <= maxTime, "Factory results not properly accumulated: totalTime = " + totalTime + " > maxTime = " + maxTime);
for (final Thread.State state : StateMonitoringThreadFactory.TRACKED_STATES ) {
final double min = test.minStateFraction(state);
final double max = test.maxStateFraction(state);
final double obs = factory.getStateFraction(state);
logger.warn(" Checking " + state
+ " min " + String.format("%.2f", min)
+ " max " + String.format("%.2f", max)
+ " obs " + String.format("%.2f", obs)
+ " factor = " + factory);
// logger.warn(" Checking " + state
// + " min " + String.format("%.2f", min)
// + " max " + String.format("%.2f", max)
// + " obs " + String.format("%.2f", obs)
// + " factor = " + factory);
Assert.assertTrue(obs >= min, "Too little time spent in state " + state + " obs " + obs + " min " + min);
Assert.assertTrue(obs <= max, "Too much time spent in state " + state + " obs " + obs + " max " + min);
}
Assert.assertEquals(factory.getNThreads(), test.getNStates());
// we actually ran the expected number of activeThreads
Assert.assertEquals(factory.getNThreadsCreated(), test.getNStates());
// should be called to ensure we don't format / NPE on output
factory.printUsageInformation(logger, Priority.INFO);
}
}