From f876c5127742646854511e6b223145e68323aa34 Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Wed, 22 Aug 2012 10:28:27 -0400 Subject: [PATCH] Separately track time spent doing user and system CPU work -- Allows us to ID (by proxy) time spent doing IO -- Refactor StateMonitoryingThreadFactory to use it's own enum, not Thread.State -- Reliable unit tests across mac and unix --- .../StateMonitoringThreadFactory.java | 122 +++++++++++------- .../StateMonitoringThreadFactoryUnitTest.java | 43 +++--- 2 files changed, 99 insertions(+), 66 deletions(-) diff --git a/public/java/src/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactory.java b/public/java/src/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactory.java index 39d5c1497..a62501f08 100644 --- a/public/java/src/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactory.java +++ b/public/java/src/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactory.java @@ -25,6 +25,7 @@ package org.broadinstitute.sting.utils.threading; import com.google.java.contract.Ensures; import com.google.java.contract.Invariant; +import com.google.java.contract.Requires; import org.apache.log4j.Logger; import org.apache.log4j.Priority; import org.broadinstitute.sting.utils.AutoFormattingTime; @@ -33,11 +34,11 @@ import java.lang.management.ManagementFactory; import java.lang.management.ThreadInfo; import java.lang.management.ThreadMXBean; import java.util.ArrayList; -import java.util.Arrays; import java.util.EnumMap; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; /** * Create activeThreads, collecting statistics about their running state over time @@ -51,20 +52,36 @@ import java.util.concurrent.ThreadFactory; @Invariant({ "activeThreads.size() <= nThreadsToCreate", "countDownLatch.getCount() <= nThreadsToCreate", - "nThreadsToCreated <= nThreadsToCreate" + "nThreadsCreated <= nThreadsToCreate" }) public class StateMonitoringThreadFactory implements ThreadFactory { - protected static final boolean DEBUG = false; + protected static final boolean DEBUG = true; private static Logger logger = Logger.getLogger(StateMonitoringThreadFactory.class); - public static final List TRACKED_STATES = Arrays.asList(Thread.State.BLOCKED, Thread.State.RUNNABLE, Thread.State.WAITING); + + public enum State { + BLOCKING("blocking on synchronized data structure"), + WAITING("waiting on some other thread"), + USER_CPU("doing productive CPU work"), + WAITING_FOR_IO("waiting for I/O"); + + private final String userFriendlyName; + + private State(String userFriendlyName) { + this.userFriendlyName = userFriendlyName; + } + + public String getUserFriendlyName() { + return userFriendlyName; + } + } // todo -- it would be nice to not have to specify upfront the number of threads. // todo -- can we dynamically increment countDownLatch? It seems not... final int nThreadsToCreate; final List activeThreads; - final EnumMap times = new EnumMap(Thread.State.class); + final EnumMap times = new EnumMap(State.class); - int nThreadsToCreated = 0; + int nThreadsCreated = 0; /** * The bean used to get the thread info about blocked and waiting times @@ -78,16 +95,6 @@ public class StateMonitoringThreadFactory implements ThreadFactory { */ final CountDownLatch countDownLatch; - /** - * Instead of RUNNABLE we want to print running. This map goes from Thread.State names to human readable ones - */ - final static EnumMap PRETTY_NAMES = new EnumMap(Thread.State.class); - static { - PRETTY_NAMES.put(Thread.State.RUNNABLE, "running"); - PRETTY_NAMES.put(Thread.State.BLOCKED, "blocked"); - PRETTY_NAMES.put(Thread.State.WAITING, "waiting"); - } - /** * Create a new factory generating threads whose runtime and contention * behavior is tracked in this factory. @@ -102,7 +109,7 @@ public class StateMonitoringThreadFactory implements ThreadFactory { activeThreads = new ArrayList(nThreadsToCreate); // initialize times to 0 - for ( final Thread.State state : Thread.State.values() ) + for ( final State state : State.values() ) times.put(state, 0l); // get the bean, and start tracking @@ -113,17 +120,22 @@ public class StateMonitoringThreadFactory implements ThreadFactory { logger.warn("Thread contention monitoring not supported, we cannot track GATK multi-threaded efficiency"); //bean.setThreadCpuTimeEnabled(true); + if ( bean.isThreadCpuTimeSupported() ) + bean.setThreadCpuTimeEnabled(true); + else + logger.warn("Thread CPU monitoring not supported, we cannot track GATK multi-threaded efficiency"); + countDownLatch = new CountDownLatch(nThreadsToCreate); } /** * Get the time spent in state across all threads created by this factory * - * @param state on of the TRACKED_STATES + * @param state to get information about * @return the time in milliseconds */ - @Ensures({"result >= 0", "TRACKED_STATES.contains(state)"}) - public synchronized long getStateTime(final Thread.State state) { + @Ensures({"result >= 0"}) + public synchronized long getStateTime(final State state) { return times.get(state); } @@ -145,8 +157,8 @@ public class StateMonitoringThreadFactory implements ThreadFactory { * * @return the fraction (0.0-1.0) of time spent in state over all state times of all threads */ - @Ensures({"result >= 0.0", "result <= 1.0", "TRACKED_STATES.contains(state)"}) - public synchronized double getStateFraction(final Thread.State state) { + @Ensures({"result >= 0.0", "result <= 1.0"}) + public synchronized double getStateFraction(final State state) { return getStateTime(state) / (1.0 * Math.max(getTotalTime(), 1)); } @@ -156,10 +168,15 @@ public class StateMonitoringThreadFactory implements ThreadFactory { */ @Ensures("result >= 0") public int getNThreadsCreated() { - return nThreadsToCreated; + return nThreadsCreated; } - public void waitForAllThreadsToComplete() throws InterruptedException { + /** + * Only useful for testing, so that we can wait for all of the threads in the factory to complete running + * + * @throws InterruptedException + */ + protected void waitForAllThreadsToComplete() throws InterruptedException { countDownLatch.await(); } @@ -168,7 +185,7 @@ public class StateMonitoringThreadFactory implements ThreadFactory { final StringBuilder b = new StringBuilder(); b.append("total ").append(getTotalTime()).append(" "); - for ( final Thread.State state : TRACKED_STATES ) { + for ( final State state : State.values() ) { b.append(state).append(" ").append(getStateTime(state)).append(" "); } @@ -193,17 +210,17 @@ public class StateMonitoringThreadFactory implements ThreadFactory { */ public synchronized void printUsageInformation(final Logger logger, final Priority priority) { logger.log(priority, "Number of activeThreads used: " + getNThreadsCreated()); - logger.log(priority, "Total runtime " + new AutoFormattingTime(getTotalTime() / 1000.0)); - for ( final Thread.State state : TRACKED_STATES ) { + logger.log(priority, "Total runtime " + new AutoFormattingTime(TimeUnit.MILLISECONDS.toSeconds(getTotalTime()))); + for ( final State state : State.values() ) { logger.log(priority, String.format(" Fraction of time spent %s is %.2f (%s)", - prettyName(state), getStateFraction(state), new AutoFormattingTime(getStateTime(state) / 1000.0))); + state.getUserFriendlyName(), + getStateFraction(state), + new AutoFormattingTime(getStateTime(state) / 1000.0))); } - logger.log(priority, String.format("Efficiency of multi-threading: %.2f%% of time spent doing productive work", - getStateFraction(Thread.State.RUNNABLE) * 100)); - } - - private String prettyName(final Thread.State state) { - return PRETTY_NAMES.get(state); + logger.log(priority, String.format("CPU efficiency : %.2f%% of time spent doing productive work", + getStateFraction(State.USER_CPU) * 100)); + logger.log(priority, String.format("I/O inefficiency: %.2f%% of time spent waiting on I/O", + getStateFraction(State.WAITING_FOR_IO) * 100)); } /** @@ -216,13 +233,13 @@ public class StateMonitoringThreadFactory implements ThreadFactory { @Ensures({ "activeThreads.size() > old(activeThreads.size())", "activeThreads.contains(result)", - "nThreadsToCreated == old(nThreadsToCreated) + 1" + "nThreadsCreated == old(nThreadsCreated) + 1" }) public synchronized Thread newThread(final Runnable runnable) { if ( activeThreads.size() >= nThreadsToCreate) throw new IllegalStateException("Attempting to create more activeThreads than allowed by constructor argument nThreadsToCreate " + nThreadsToCreate); - nThreadsToCreated++; + nThreadsCreated++; final Thread myThread = new TrackingThread(runnable); activeThreads.add(myThread); return myThread; @@ -234,8 +251,7 @@ public class StateMonitoringThreadFactory implements ThreadFactory { * This method updates all of the key timing and tracking information in the factory so that * thread can be retired. After this call the factory shouldn't have a pointer to the thread any longer * - * @param thread - * @param runtimeInMilliseconds + * @param thread the thread whose information we are updating */ @Ensures({ "activeThreads.size() < old(activeThreads.size())", @@ -243,16 +259,24 @@ public class StateMonitoringThreadFactory implements ThreadFactory { "getTotalTime() >= old(getTotalTime())", "countDownLatch.getCount() < old(countDownLatch.getCount())" }) - private synchronized void threadIsDone(final Thread thread, final long runtimeInMilliseconds) { + private synchronized void threadIsDone(final Thread thread) { if ( DEBUG ) logger.warn(" Countdown " + countDownLatch.getCount() + " in thread " + Thread.currentThread().getName()); if ( DEBUG ) logger.warn("UpdateThreadInfo called"); + final long threadID = thread.getId(); final ThreadInfo info = bean.getThreadInfo(thread.getId()); + final long totalTimeNano = bean.getThreadCpuTime(threadID); + final long userTimeNano = bean.getThreadUserTime(threadID); + final long systemTimeNano = totalTimeNano - userTimeNano; + final long userTimeInMilliseconds = nanoToMilli(userTimeNano); + final long systemTimeInMilliseconds = nanoToMilli(systemTimeNano); + if ( info != null ) { - if ( DEBUG ) logger.warn("Updating thread total runtime " + runtimeInMilliseconds + " of which blocked " + info.getBlockedTime() + " and waiting " + info.getWaitedTime()); - incTimes(Thread.State.BLOCKED, info.getBlockedTime()); - incTimes(Thread.State.WAITING, info.getWaitedTime()); - incTimes(Thread.State.RUNNABLE, runtimeInMilliseconds - info.getWaitedTime() - info.getBlockedTime()); + if ( DEBUG ) logger.warn("Updating thread with user runtime " + userTimeInMilliseconds + " and system runtime " + systemTimeInMilliseconds + " of which blocked " + info.getBlockedTime() + " and waiting " + info.getWaitedTime()); + incTimes(State.BLOCKING, info.getBlockedTime()); + incTimes(State.WAITING, info.getWaitedTime()); + incTimes(State.USER_CPU, userTimeInMilliseconds); + incTimes(State.WAITING_FOR_IO, systemTimeInMilliseconds); } // remove the thread from the list of active activeThreads @@ -270,10 +294,16 @@ public class StateMonitoringThreadFactory implements ThreadFactory { * @param state * @param by */ - private synchronized void incTimes(final Thread.State state, final long by) { + @Requires({"state != null", "by >= 0"}) + @Ensures("getTotalTime() == old(getTotalTime()) + by") + private synchronized void incTimes(final State state, final long by) { times.put(state, times.get(state) + by); } + private static long nanoToMilli(final long timeInNano) { + return TimeUnit.NANOSECONDS.toMillis(timeInNano); + } + /** * A wrapper around Thread that tracks the runtime of the thread and calls threadIsDone() when complete */ @@ -284,10 +314,8 @@ public class StateMonitoringThreadFactory implements ThreadFactory { @Override public void run() { - final long startTime = System.currentTimeMillis(); super.run(); - final long endTime = System.currentTimeMillis(); - threadIsDone(this, endTime - startTime); + threadIsDone(this); } } } diff --git a/public/java/test/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactoryUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactoryUnitTest.java index 5a606c50e..b41070a14 100755 --- a/public/java/test/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactoryUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/threading/StateMonitoringThreadFactoryUnitTest.java @@ -41,30 +41,30 @@ import java.util.concurrent.*; */ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { // the duration of the tests -- 100 ms is tolerable given the number of tests we are doing - private final static long THREAD_TARGET_DURATION_IN_MILLISECOND = 100; + private final static long THREAD_TARGET_DURATION_IN_MILLISECOND = 1000; final static Object GLOBAL_LOCK = new Object(); private class StateTest extends TestDataProvider { private final double TOLERANCE = 0.1; // willing to tolerate a 10% error - final List statesForThreads; + final List statesForThreads; - public StateTest(final List statesForThreads) { + public StateTest(final List statesForThreads) { super(StateTest.class); this.statesForThreads = statesForThreads; setName("StateTest " + Utils.join(",", statesForThreads)); } - public List getStatesForThreads() { + public List getStatesForThreads() { return statesForThreads; } public int getNStates() { return statesForThreads.size(); } - public double maxStateFraction(final Thread.State state) { return fraction(state) + TOLERANCE; } - public double minStateFraction(final Thread.State state) { return fraction(state) - TOLERANCE; } + public double maxStateFraction(final StateMonitoringThreadFactory.State state) { return fraction(state) + TOLERANCE; } + public double minStateFraction(final StateMonitoringThreadFactory.State state) { return fraction(state) - TOLERANCE; } - private double fraction(final Thread.State state) { + private double fraction(final StateMonitoringThreadFactory.State state) { return Collections.frequency(statesForThreads, state) / (1.0 * statesForThreads.size()); } } @@ -74,18 +74,16 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { * requested for input argument */ private static class StateTestThread implements Callable { - private final Thread.State stateToImplement; + private final StateMonitoringThreadFactory.State stateToImplement; - private StateTestThread(final Thread.State stateToImplement) { - if ( ! StateMonitoringThreadFactory.TRACKED_STATES.contains(stateToImplement) ) - throw new IllegalArgumentException("Unexpected state " + stateToImplement); + private StateTestThread(final StateMonitoringThreadFactory.State stateToImplement) { this.stateToImplement = stateToImplement; } @Override public Double call() throws Exception { switch ( stateToImplement ) { - case RUNNABLE: + case USER_CPU: // do some work until we get to THREAD_TARGET_DURATION_IN_MILLISECOND double sum = 0.0; final long startTime = System.currentTimeMillis(); @@ -96,13 +94,17 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { case WAITING: Thread.currentThread().sleep(THREAD_TARGET_DURATION_IN_MILLISECOND); return 0.0; - case BLOCKED: + case BLOCKING: if ( StateMonitoringThreadFactory.DEBUG ) logger.warn("Blocking..."); synchronized (GLOBAL_LOCK) { // the GLOBAL_LOCK must be held by the unit test itself for this to properly block if ( StateMonitoringThreadFactory.DEBUG ) logger.warn(" ... done blocking"); } return 0.0; + case WAITING_FOR_IO: + // TODO -- implement me + // shouldn't ever get here, throw an exception + throw new ReviewedStingException("WAITING_FOR_IO testing currently not implemented, until we figure out how to force a system call block"); default: throw new ReviewedStingException("Unexpected thread test state " + stateToImplement); } @@ -111,8 +113,11 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { @DataProvider(name = "StateTest") public Object[][] createStateTest() { - for ( final int nThreads : Arrays.asList(1, 2, 3, 4) ) { - for (final List states : Utils.makePermutations(StateMonitoringThreadFactory.TRACKED_STATES, nThreads, true) ) { + for ( final int nThreads : Arrays.asList(3) ) { + //final List allStates = Arrays.asList(StateMonitoringThreadFactory.State.WAITING_FOR_IO); + final List allStates = Arrays.asList(StateMonitoringThreadFactory.State.USER_CPU, StateMonitoringThreadFactory.State.WAITING, StateMonitoringThreadFactory.State.BLOCKING); + //final List allStates = Arrays.asList(StateMonitoringThreadFactory.State.values()); + for (final List states : Utils.makePermutations(allStates, nThreads, true) ) { //if ( Collections.frequency(states, Thread.State.BLOCKED) > 0) new StateTest(states); } @@ -121,7 +126,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { return StateTest.getTests(StateTest.class); } - @Test(enabled = false, dataProvider = "StateTest") + @Test(enabled = true, dataProvider = "StateTest") public void testStateTest(final StateTest test) throws InterruptedException { // allows us to test blocking final StateMonitoringThreadFactory factory = new StateMonitoringThreadFactory(test.getNStates()); @@ -130,7 +135,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { logger.warn("Running " + test); synchronized (GLOBAL_LOCK) { //logger.warn(" Have lock"); - for ( final Thread.State threadToRunState : test.getStatesForThreads() ) + for ( final StateMonitoringThreadFactory.State threadToRunState : test.getStatesForThreads() ) threadPool.submit(new StateTestThread(threadToRunState)); // lock has to be here for the whole running of the activeThreads but end before the sleep so the blocked activeThreads @@ -153,7 +158,7 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { Assert.assertTrue(totalTime >= minTime, "Factory results not properly accumulated: totalTime = " + totalTime + " < minTime = " + minTime); Assert.assertTrue(totalTime <= maxTime, "Factory results not properly accumulated: totalTime = " + totalTime + " > maxTime = " + maxTime); - for (final Thread.State state : StateMonitoringThreadFactory.TRACKED_STATES ) { + for (final StateMonitoringThreadFactory.State state : StateMonitoringThreadFactory.State.values() ) { final double min = test.minStateFraction(state); final double max = test.maxStateFraction(state); final double obs = factory.getStateFraction(state); @@ -170,6 +175,6 @@ public class StateMonitoringThreadFactoryUnitTest extends BaseTest { Assert.assertEquals(factory.getNThreadsCreated(), test.getNStates()); // should be called to ensure we don't format / NPE on output - factory.printUsageInformation(logger, Priority.INFO); + factory.printUsageInformation(logger, Priority.WARN); } } \ No newline at end of file