Separately track time spent doing user and system CPU work

-- Allows us to ID (by proxy) time spent doing IO
-- Refactor StateMonitoringThreadFactory to use its own enum, not Thread.State
-- Reliable unit tests across mac and unix
This commit is contained in:
Mark DePristo 2012-08-22 10:28:27 -04:00
parent 18060f237b
commit f876c51277
2 changed files with 99 additions and 66 deletions

View File

@ -25,6 +25,7 @@ package org.broadinstitute.sting.utils.threading;
import com.google.java.contract.Ensures; import com.google.java.contract.Ensures;
import com.google.java.contract.Invariant; import com.google.java.contract.Invariant;
import com.google.java.contract.Requires;
import org.apache.log4j.Logger; import org.apache.log4j.Logger;
import org.apache.log4j.Priority; import org.apache.log4j.Priority;
import org.broadinstitute.sting.utils.AutoFormattingTime; import org.broadinstitute.sting.utils.AutoFormattingTime;
@ -33,11 +34,11 @@ import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo; import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean; import java.lang.management.ThreadMXBean;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumMap; import java.util.EnumMap;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
/** /**
* Create activeThreads, collecting statistics about their running state over time * Create activeThreads, collecting statistics about their running state over time
@ -51,20 +52,36 @@ import java.util.concurrent.ThreadFactory;
@Invariant({ @Invariant({
"activeThreads.size() <= nThreadsToCreate", "activeThreads.size() <= nThreadsToCreate",
"countDownLatch.getCount() <= nThreadsToCreate", "countDownLatch.getCount() <= nThreadsToCreate",
"nThreadsToCreated <= nThreadsToCreate" "nThreadsCreated <= nThreadsToCreate"
}) })
public class StateMonitoringThreadFactory implements ThreadFactory { public class StateMonitoringThreadFactory implements ThreadFactory {
protected static final boolean DEBUG = false; protected static final boolean DEBUG = true;
private static Logger logger = Logger.getLogger(StateMonitoringThreadFactory.class); private static Logger logger = Logger.getLogger(StateMonitoringThreadFactory.class);
public static final List<Thread.State> TRACKED_STATES = Arrays.asList(Thread.State.BLOCKED, Thread.State.RUNNABLE, Thread.State.WAITING);
public enum State {
BLOCKING("blocking on synchronized data structure"),
WAITING("waiting on some other thread"),
USER_CPU("doing productive CPU work"),
WAITING_FOR_IO("waiting for I/O");
private final String userFriendlyName;
private State(String userFriendlyName) {
this.userFriendlyName = userFriendlyName;
}
public String getUserFriendlyName() {
return userFriendlyName;
}
}
// todo -- it would be nice to not have to specify upfront the number of threads. // todo -- it would be nice to not have to specify upfront the number of threads.
// todo -- can we dynamically increment countDownLatch? It seems not... // todo -- can we dynamically increment countDownLatch? It seems not...
final int nThreadsToCreate; final int nThreadsToCreate;
final List<Thread> activeThreads; final List<Thread> activeThreads;
final EnumMap<Thread.State, Long> times = new EnumMap<Thread.State, Long>(Thread.State.class); final EnumMap<State, Long> times = new EnumMap<State, Long>(State.class);
int nThreadsToCreated = 0; int nThreadsCreated = 0;
/** /**
* The bean used to get the thread info about blocked and waiting times * The bean used to get the thread info about blocked and waiting times
@ -78,16 +95,6 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
*/ */
final CountDownLatch countDownLatch; final CountDownLatch countDownLatch;
/**
* Instead of RUNNABLE we want to print running. This map goes from Thread.State names to human readable ones
*/
final static EnumMap<Thread.State, String> PRETTY_NAMES = new EnumMap<Thread.State, String>(Thread.State.class);
static {
PRETTY_NAMES.put(Thread.State.RUNNABLE, "running");
PRETTY_NAMES.put(Thread.State.BLOCKED, "blocked");
PRETTY_NAMES.put(Thread.State.WAITING, "waiting");
}
/** /**
* Create a new factory generating threads whose runtime and contention * Create a new factory generating threads whose runtime and contention
* behavior is tracked in this factory. * behavior is tracked in this factory.
@ -102,7 +109,7 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
activeThreads = new ArrayList<Thread>(nThreadsToCreate); activeThreads = new ArrayList<Thread>(nThreadsToCreate);
// initialize times to 0 // initialize times to 0
for ( final Thread.State state : Thread.State.values() ) for ( final State state : State.values() )
times.put(state, 0l); times.put(state, 0l);
// get the bean, and start tracking // get the bean, and start tracking
@ -113,17 +120,22 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
logger.warn("Thread contention monitoring not supported, we cannot track GATK multi-threaded efficiency"); logger.warn("Thread contention monitoring not supported, we cannot track GATK multi-threaded efficiency");
//bean.setThreadCpuTimeEnabled(true); //bean.setThreadCpuTimeEnabled(true);
if ( bean.isThreadCpuTimeSupported() )
bean.setThreadCpuTimeEnabled(true);
else
logger.warn("Thread CPU monitoring not supported, we cannot track GATK multi-threaded efficiency");
countDownLatch = new CountDownLatch(nThreadsToCreate); countDownLatch = new CountDownLatch(nThreadsToCreate);
} }
/** /**
* Get the time spent in state across all threads created by this factory * Get the time spent in state across all threads created by this factory
* *
* @param state on of the TRACKED_STATES * @param state to get information about
* @return the time in milliseconds * @return the time in milliseconds
*/ */
@Ensures({"result >= 0", "TRACKED_STATES.contains(state)"}) @Ensures({"result >= 0"})
public synchronized long getStateTime(final Thread.State state) { public synchronized long getStateTime(final State state) {
return times.get(state); return times.get(state);
} }
@ -145,8 +157,8 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
* *
* @return the fraction (0.0-1.0) of time spent in state over all state times of all threads * @return the fraction (0.0-1.0) of time spent in state over all state times of all threads
*/ */
@Ensures({"result >= 0.0", "result <= 1.0", "TRACKED_STATES.contains(state)"}) @Ensures({"result >= 0.0", "result <= 1.0"})
public synchronized double getStateFraction(final Thread.State state) { public synchronized double getStateFraction(final State state) {
return getStateTime(state) / (1.0 * Math.max(getTotalTime(), 1)); return getStateTime(state) / (1.0 * Math.max(getTotalTime(), 1));
} }
@ -156,10 +168,15 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
*/ */
@Ensures("result >= 0") @Ensures("result >= 0")
public int getNThreadsCreated() { public int getNThreadsCreated() {
return nThreadsToCreated; return nThreadsCreated;
} }
public void waitForAllThreadsToComplete() throws InterruptedException { /**
* Only useful for testing, so that we can wait for all of the threads in the factory to complete running
*
* @throws InterruptedException
*/
protected void waitForAllThreadsToComplete() throws InterruptedException {
countDownLatch.await(); countDownLatch.await();
} }
@ -168,7 +185,7 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
final StringBuilder b = new StringBuilder(); final StringBuilder b = new StringBuilder();
b.append("total ").append(getTotalTime()).append(" "); b.append("total ").append(getTotalTime()).append(" ");
for ( final Thread.State state : TRACKED_STATES ) { for ( final State state : State.values() ) {
b.append(state).append(" ").append(getStateTime(state)).append(" "); b.append(state).append(" ").append(getStateTime(state)).append(" ");
} }
@ -193,17 +210,17 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
*/ */
public synchronized void printUsageInformation(final Logger logger, final Priority priority) { public synchronized void printUsageInformation(final Logger logger, final Priority priority) {
logger.log(priority, "Number of activeThreads used: " + getNThreadsCreated()); logger.log(priority, "Number of activeThreads used: " + getNThreadsCreated());
logger.log(priority, "Total runtime " + new AutoFormattingTime(getTotalTime() / 1000.0)); logger.log(priority, "Total runtime " + new AutoFormattingTime(TimeUnit.MILLISECONDS.toSeconds(getTotalTime())));
for ( final Thread.State state : TRACKED_STATES ) { for ( final State state : State.values() ) {
logger.log(priority, String.format(" Fraction of time spent %s is %.2f (%s)", logger.log(priority, String.format(" Fraction of time spent %s is %.2f (%s)",
prettyName(state), getStateFraction(state), new AutoFormattingTime(getStateTime(state) / 1000.0))); state.getUserFriendlyName(),
getStateFraction(state),
new AutoFormattingTime(getStateTime(state) / 1000.0)));
} }
logger.log(priority, String.format("Efficiency of multi-threading: %.2f%% of time spent doing productive work", logger.log(priority, String.format("CPU efficiency : %.2f%% of time spent doing productive work",
getStateFraction(Thread.State.RUNNABLE) * 100)); getStateFraction(State.USER_CPU) * 100));
} logger.log(priority, String.format("I/O inefficiency: %.2f%% of time spent waiting on I/O",
getStateFraction(State.WAITING_FOR_IO) * 100));
private String prettyName(final Thread.State state) {
return PRETTY_NAMES.get(state);
} }
/** /**
@ -216,13 +233,13 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
@Ensures({ @Ensures({
"activeThreads.size() > old(activeThreads.size())", "activeThreads.size() > old(activeThreads.size())",
"activeThreads.contains(result)", "activeThreads.contains(result)",
"nThreadsToCreated == old(nThreadsToCreated) + 1" "nThreadsCreated == old(nThreadsCreated) + 1"
}) })
public synchronized Thread newThread(final Runnable runnable) { public synchronized Thread newThread(final Runnable runnable) {
if ( activeThreads.size() >= nThreadsToCreate) if ( activeThreads.size() >= nThreadsToCreate)
throw new IllegalStateException("Attempting to create more activeThreads than allowed by constructor argument nThreadsToCreate " + nThreadsToCreate); throw new IllegalStateException("Attempting to create more activeThreads than allowed by constructor argument nThreadsToCreate " + nThreadsToCreate);
nThreadsToCreated++; nThreadsCreated++;
final Thread myThread = new TrackingThread(runnable); final Thread myThread = new TrackingThread(runnable);
activeThreads.add(myThread); activeThreads.add(myThread);
return myThread; return myThread;
@ -234,8 +251,7 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
* This method updates all of the key timing and tracking information in the factory so that * This method updates all of the key timing and tracking information in the factory so that
* thread can be retired. After this call the factory shouldn't have a pointer to the thread any longer * thread can be retired. After this call the factory shouldn't have a pointer to the thread any longer
* *
* @param thread * @param thread the thread whose information we are updating
* @param runtimeInMilliseconds
*/ */
@Ensures({ @Ensures({
"activeThreads.size() < old(activeThreads.size())", "activeThreads.size() < old(activeThreads.size())",
@ -243,16 +259,24 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
"getTotalTime() >= old(getTotalTime())", "getTotalTime() >= old(getTotalTime())",
"countDownLatch.getCount() < old(countDownLatch.getCount())" "countDownLatch.getCount() < old(countDownLatch.getCount())"
}) })
private synchronized void threadIsDone(final Thread thread, final long runtimeInMilliseconds) { private synchronized void threadIsDone(final Thread thread) {
if ( DEBUG ) logger.warn(" Countdown " + countDownLatch.getCount() + " in thread " + Thread.currentThread().getName()); if ( DEBUG ) logger.warn(" Countdown " + countDownLatch.getCount() + " in thread " + Thread.currentThread().getName());
if ( DEBUG ) logger.warn("UpdateThreadInfo called"); if ( DEBUG ) logger.warn("UpdateThreadInfo called");
final long threadID = thread.getId();
final ThreadInfo info = bean.getThreadInfo(thread.getId()); final ThreadInfo info = bean.getThreadInfo(thread.getId());
final long totalTimeNano = bean.getThreadCpuTime(threadID);
final long userTimeNano = bean.getThreadUserTime(threadID);
final long systemTimeNano = totalTimeNano - userTimeNano;
final long userTimeInMilliseconds = nanoToMilli(userTimeNano);
final long systemTimeInMilliseconds = nanoToMilli(systemTimeNano);
if ( info != null ) { if ( info != null ) {
if ( DEBUG ) logger.warn("Updating thread total runtime " + runtimeInMilliseconds + " of which blocked " + info.getBlockedTime() + " and waiting " + info.getWaitedTime()); if ( DEBUG ) logger.warn("Updating thread with user runtime " + userTimeInMilliseconds + " and system runtime " + systemTimeInMilliseconds + " of which blocked " + info.getBlockedTime() + " and waiting " + info.getWaitedTime());
incTimes(Thread.State.BLOCKED, info.getBlockedTime()); incTimes(State.BLOCKING, info.getBlockedTime());
incTimes(Thread.State.WAITING, info.getWaitedTime()); incTimes(State.WAITING, info.getWaitedTime());
incTimes(Thread.State.RUNNABLE, runtimeInMilliseconds - info.getWaitedTime() - info.getBlockedTime()); incTimes(State.USER_CPU, userTimeInMilliseconds);
incTimes(State.WAITING_FOR_IO, systemTimeInMilliseconds);
} }
// remove the thread from the list of active activeThreads // remove the thread from the list of active activeThreads
@ -270,10 +294,16 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
* @param state * @param state
* @param by * @param by
*/ */
private synchronized void incTimes(final Thread.State state, final long by) { @Requires({"state != null", "by >= 0"})
@Ensures("getTotalTime() == old(getTotalTime()) + by")
private synchronized void incTimes(final State state, final long by) {
times.put(state, times.get(state) + by); times.put(state, times.get(state) + by);
} }
private static long nanoToMilli(final long timeInNano) {
return TimeUnit.NANOSECONDS.toMillis(timeInNano);
}
/** /**
* A wrapper around Thread that tracks the runtime of the thread and calls threadIsDone() when complete * A wrapper around Thread that tracks the runtime of the thread and calls threadIsDone() when complete
*/ */
@ -284,10 +314,8 @@ public class StateMonitoringThreadFactory implements ThreadFactory {
@Override @Override
public void run() { public void run() {
final long startTime = System.currentTimeMillis();
super.run(); super.run();
final long endTime = System.currentTimeMillis(); threadIsDone(this);
threadIsDone(this, endTime - startTime);
} }
} }
} }

View File

@ -41,30 +41,30 @@ import java.util.concurrent.*;
*/ */
public class StateMonitoringThreadFactoryUnitTest extends BaseTest { public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
// the duration of the tests -- 100 ms is tolerable given the number of tests we are doing // the duration of the tests -- 100 ms is tolerable given the number of tests we are doing
private final static long THREAD_TARGET_DURATION_IN_MILLISECOND = 100; private final static long THREAD_TARGET_DURATION_IN_MILLISECOND = 1000;
final static Object GLOBAL_LOCK = new Object(); final static Object GLOBAL_LOCK = new Object();
private class StateTest extends TestDataProvider { private class StateTest extends TestDataProvider {
private final double TOLERANCE = 0.1; // willing to tolerate a 10% error private final double TOLERANCE = 0.1; // willing to tolerate a 10% error
final List<Thread.State> statesForThreads; final List<StateMonitoringThreadFactory.State> statesForThreads;
public StateTest(final List<Thread.State> statesForThreads) { public StateTest(final List<StateMonitoringThreadFactory.State> statesForThreads) {
super(StateTest.class); super(StateTest.class);
this.statesForThreads = statesForThreads; this.statesForThreads = statesForThreads;
setName("StateTest " + Utils.join(",", statesForThreads)); setName("StateTest " + Utils.join(",", statesForThreads));
} }
public List<Thread.State> getStatesForThreads() { public List<StateMonitoringThreadFactory.State> getStatesForThreads() {
return statesForThreads; return statesForThreads;
} }
public int getNStates() { return statesForThreads.size(); } public int getNStates() { return statesForThreads.size(); }
public double maxStateFraction(final Thread.State state) { return fraction(state) + TOLERANCE; } public double maxStateFraction(final StateMonitoringThreadFactory.State state) { return fraction(state) + TOLERANCE; }
public double minStateFraction(final Thread.State state) { return fraction(state) - TOLERANCE; } public double minStateFraction(final StateMonitoringThreadFactory.State state) { return fraction(state) - TOLERANCE; }
private double fraction(final Thread.State state) { private double fraction(final StateMonitoringThreadFactory.State state) {
return Collections.frequency(statesForThreads, state) / (1.0 * statesForThreads.size()); return Collections.frequency(statesForThreads, state) / (1.0 * statesForThreads.size());
} }
} }
@ -74,18 +74,16 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
* requested for input argument * requested for input argument
*/ */
private static class StateTestThread implements Callable<Double> { private static class StateTestThread implements Callable<Double> {
private final Thread.State stateToImplement; private final StateMonitoringThreadFactory.State stateToImplement;
private StateTestThread(final Thread.State stateToImplement) { private StateTestThread(final StateMonitoringThreadFactory.State stateToImplement) {
if ( ! StateMonitoringThreadFactory.TRACKED_STATES.contains(stateToImplement) )
throw new IllegalArgumentException("Unexpected state " + stateToImplement);
this.stateToImplement = stateToImplement; this.stateToImplement = stateToImplement;
} }
@Override @Override
public Double call() throws Exception { public Double call() throws Exception {
switch ( stateToImplement ) { switch ( stateToImplement ) {
case RUNNABLE: case USER_CPU:
// do some work until we get to THREAD_TARGET_DURATION_IN_MILLISECOND // do some work until we get to THREAD_TARGET_DURATION_IN_MILLISECOND
double sum = 0.0; double sum = 0.0;
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
@ -96,13 +94,17 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
case WAITING: case WAITING:
Thread.currentThread().sleep(THREAD_TARGET_DURATION_IN_MILLISECOND); Thread.currentThread().sleep(THREAD_TARGET_DURATION_IN_MILLISECOND);
return 0.0; return 0.0;
case BLOCKED: case BLOCKING:
if ( StateMonitoringThreadFactory.DEBUG ) logger.warn("Blocking..."); if ( StateMonitoringThreadFactory.DEBUG ) logger.warn("Blocking...");
synchronized (GLOBAL_LOCK) { synchronized (GLOBAL_LOCK) {
// the GLOBAL_LOCK must be held by the unit test itself for this to properly block // the GLOBAL_LOCK must be held by the unit test itself for this to properly block
if ( StateMonitoringThreadFactory.DEBUG ) logger.warn(" ... done blocking"); if ( StateMonitoringThreadFactory.DEBUG ) logger.warn(" ... done blocking");
} }
return 0.0; return 0.0;
case WAITING_FOR_IO:
// TODO -- implement me
// shouldn't ever get here, throw an exception
throw new ReviewedStingException("WAITING_FOR_IO testing currently not implemented, until we figure out how to force a system call block");
default: default:
throw new ReviewedStingException("Unexpected thread test state " + stateToImplement); throw new ReviewedStingException("Unexpected thread test state " + stateToImplement);
} }
@ -111,8 +113,11 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
@DataProvider(name = "StateTest") @DataProvider(name = "StateTest")
public Object[][] createStateTest() { public Object[][] createStateTest() {
for ( final int nThreads : Arrays.asList(1, 2, 3, 4) ) { for ( final int nThreads : Arrays.asList(3) ) {
for (final List<Thread.State> states : Utils.makePermutations(StateMonitoringThreadFactory.TRACKED_STATES, nThreads, true) ) { //final List<StateMonitoringThreadFactory.State> allStates = Arrays.asList(StateMonitoringThreadFactory.State.WAITING_FOR_IO);
final List<StateMonitoringThreadFactory.State> allStates = Arrays.asList(StateMonitoringThreadFactory.State.USER_CPU, StateMonitoringThreadFactory.State.WAITING, StateMonitoringThreadFactory.State.BLOCKING);
//final List<StateMonitoringThreadFactory.State> allStates = Arrays.asList(StateMonitoringThreadFactory.State.values());
for (final List<StateMonitoringThreadFactory.State> states : Utils.makePermutations(allStates, nThreads, true) ) {
//if ( Collections.frequency(states, Thread.State.BLOCKED) > 0) //if ( Collections.frequency(states, Thread.State.BLOCKED) > 0)
new StateTest(states); new StateTest(states);
} }
@ -121,7 +126,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
return StateTest.getTests(StateTest.class); return StateTest.getTests(StateTest.class);
} }
@Test(enabled = false, dataProvider = "StateTest") @Test(enabled = true, dataProvider = "StateTest")
public void testStateTest(final StateTest test) throws InterruptedException { public void testStateTest(final StateTest test) throws InterruptedException {
// allows us to test blocking // allows us to test blocking
final StateMonitoringThreadFactory factory = new StateMonitoringThreadFactory(test.getNStates()); final StateMonitoringThreadFactory factory = new StateMonitoringThreadFactory(test.getNStates());
@ -130,7 +135,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
logger.warn("Running " + test); logger.warn("Running " + test);
synchronized (GLOBAL_LOCK) { synchronized (GLOBAL_LOCK) {
//logger.warn(" Have lock"); //logger.warn(" Have lock");
for ( final Thread.State threadToRunState : test.getStatesForThreads() ) for ( final StateMonitoringThreadFactory.State threadToRunState : test.getStatesForThreads() )
threadPool.submit(new StateTestThread(threadToRunState)); threadPool.submit(new StateTestThread(threadToRunState));
// lock has to be here for the whole running of the activeThreads but end before the sleep so the blocked activeThreads // lock has to be here for the whole running of the activeThreads but end before the sleep so the blocked activeThreads
@ -153,7 +158,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
Assert.assertTrue(totalTime >= minTime, "Factory results not properly accumulated: totalTime = " + totalTime + " < minTime = " + minTime); Assert.assertTrue(totalTime >= minTime, "Factory results not properly accumulated: totalTime = " + totalTime + " < minTime = " + minTime);
Assert.assertTrue(totalTime <= maxTime, "Factory results not properly accumulated: totalTime = " + totalTime + " > maxTime = " + maxTime); Assert.assertTrue(totalTime <= maxTime, "Factory results not properly accumulated: totalTime = " + totalTime + " > maxTime = " + maxTime);
for (final Thread.State state : StateMonitoringThreadFactory.TRACKED_STATES ) { for (final StateMonitoringThreadFactory.State state : StateMonitoringThreadFactory.State.values() ) {
final double min = test.minStateFraction(state); final double min = test.minStateFraction(state);
final double max = test.maxStateFraction(state); final double max = test.maxStateFraction(state);
final double obs = factory.getStateFraction(state); final double obs = factory.getStateFraction(state);
@ -170,6 +175,6 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
Assert.assertEquals(factory.getNThreadsCreated(), test.getNStates()); Assert.assertEquals(factory.getNThreadsCreated(), test.getNStates());
// should be called to ensure we don't format / NPE on output // should be called to ensure we don't format / NPE on output
factory.printUsageInformation(logger, Priority.INFO); factory.printUsageInformation(logger, Priority.WARN);
} }
} }