Initial implementation of ThreadFactory that monitors running / blocking / waiting time of threads it creates

-- Created makeCombinations utility function (very useful!).  Moved template from VariantContextTestProvider
-- UnitTests for basic functionality
This commit is contained in:
Mark DePristo 2012-08-14 15:02:45 -04:00
parent fc1bd82011
commit be3230a1fd
4 changed files with 342 additions and 14 deletions

View File

@ -732,6 +732,36 @@ public class Utils {
}
}
/**
* Make all combinations of N size of objects
*
* if objects = [A, B, C]
* if N = 1 => [[A], [B], [C]]
* if N = 2 => [[A, A], [B, A], [C, A], [A, B], [B, B], [C, B], [A, C], [B, C], [C, C]]
*
* @param objects
* @param n
* @param <T>
* @return
*/
public static <T> List<List<T>> makeCombinations(final List<T> objects, final int n) {
final List<List<T>> combinations = new ArrayList<List<T>>();
if ( n == 1 ) {
for ( final T o : objects )
combinations.add(Collections.singletonList(o));
} else {
final List<List<T>> sub = makeCombinations(objects, n - 1);
for ( List<T> subI : sub ) {
for ( final T a : objects ) {
combinations.add(Utils.cons(a, subI));
}
}
}
return combinations;
}
/**
* Convenience function that formats the novelty rate as a %.2f string
*

View File

@ -0,0 +1,149 @@
/*
* The MIT License
*
* Copyright (c) 2009 The Broad Institute
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
package org.broadinstitute.sting.utils.threading;
import org.apache.log4j.Logger;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadFactory;
/**
* Create threads, collecting statistics about their running state over time
*
* Uses a ThreadMXBean to capture info via ThreadInfo
*
* User: depristo
* Date: 8/14/12
* Time: 8:47 AM
*/
public class StateMonitoringThreadFactory implements ThreadFactory {
protected static final boolean DEBUG = false;
private static Logger logger = Logger.getLogger(StateMonitoringThreadFactory.class);
public static final List<Thread.State> TRACKED_STATES = Arrays.asList(Thread.State.BLOCKED, Thread.State.RUNNABLE, Thread.State.WAITING);
final int threadsToCreate;
final List<Thread> threads;
final EnumMap<Thread.State, Long> times = new EnumMap<Thread.State, Long>(Thread.State.class);
final ThreadMXBean bean;
final CountDownLatch activeThreads;
public StateMonitoringThreadFactory(final int threadsToCreate) {
if ( threadsToCreate <= 0 ) throw new IllegalArgumentException("threadsToCreate <= 0: " + threadsToCreate);
this.threadsToCreate = threadsToCreate;
threads = new ArrayList<Thread>(threadsToCreate);
for ( final Thread.State state : Thread.State.values() )
times.put(state, 0l);
bean = ManagementFactory.getThreadMXBean();
bean.setThreadContentionMonitoringEnabled(true);
bean.setThreadCpuTimeEnabled(true);
activeThreads = new CountDownLatch(threadsToCreate);
}
public synchronized long getStateTime(final Thread.State state) {
return times.get(state);
}
public synchronized long getTotalTime() {
long total = 0;
for ( final long time : times.values() )
total += time;
return total;
}
public synchronized double getStateFraction(final Thread.State state) {
return getStateTime(state) / (1.0 * getTotalTime());
}
public int getNThreads() {
return threads.size();
}
public void waitForAllThreadsToComplete() throws InterruptedException {
activeThreads.await();
}
@Override
public synchronized String toString() {
final StringBuilder b = new StringBuilder();
b.append("total ").append(getTotalTime()).append(" ");
for ( final Thread.State state : TRACKED_STATES ) {
b.append(state).append(" ").append(getStateTime(state)).append(" ");
}
return b.toString();
}
@Override
public synchronized Thread newThread(final Runnable runnable) {
if ( threads.size() >= threadsToCreate )
throw new IllegalStateException("Attempting to create more threads than allowed by constructor argument threadsToCreate " + threadsToCreate);
final Thread myThread = new TrackingThread(runnable);
threads.add(myThread);
return myThread;
}
// TODO -- add polling capability
private synchronized void updateThreadInfo(final Thread thread, final long runtime) {
if ( DEBUG ) logger.warn("UpdateThreadInfo called");
final ThreadInfo info = bean.getThreadInfo(thread.getId());
if ( info != null ) {
if ( DEBUG ) logger.warn("Updating thread total runtime " + runtime + " of which blocked " + info.getBlockedTime() + " and waiting " + info.getWaitedTime());
incTimes(Thread.State.BLOCKED, info.getBlockedTime());
incTimes(Thread.State.WAITING, info.getWaitedTime());
incTimes(Thread.State.RUNNABLE, runtime - info.getWaitedTime() - info.getBlockedTime());
}
}
private synchronized void incTimes(final Thread.State state, final long by) {
times.put(state, times.get(state) + by);
}
private class TrackingThread extends Thread {
private TrackingThread(Runnable runnable) {
super(runnable);
}
@Override
public void run() {
final long startTime = System.currentTimeMillis();
super.run();
final long endTime = System.currentTimeMillis();
if ( DEBUG ) logger.warn(" Countdown " + activeThreads.getCount() + " in thread " + Thread.currentThread().getName());
updateThreadInfo(this, endTime - startTime);
activeThreads.countDown();
if ( DEBUG ) logger.warn(" -> Countdown " + activeThreads.getCount() + " in thread " + Thread.currentThread().getName());
}
}
}

View File

@ -0,0 +1,161 @@
/*
* The MIT License
*
* Copyright (c) 2009 The Broad Institute
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
package org.broadinstitute.sting.utils.threading;
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.*;
/**
* Tests for the state monitoring thread factory.
*/
public class StateMonitoringThreadFactoryUnitTest extends BaseTest {
private final static long THREAD_TARGET_DURATION_IN_MILLISECOND = 100;
final static Object GLOBAL_LOCK = new Object();
private class StateTest extends TestDataProvider {
private final double TOLERANCE = 0.1; // willing to tolerate a 10% error
final List<Thread.State> statesForThreads;
public StateTest(final List<Thread.State> statesForThreads) {
super(StateTest.class);
this.statesForThreads = statesForThreads;
setName("StateTest " + Utils.join(",", statesForThreads));
}
public List<Thread.State> getStatesForThreads() {
return statesForThreads;
}
public int getNStates() { return statesForThreads.size(); }
public double maxStateFraction(final Thread.State state) { return fraction(state) + TOLERANCE; }
public double minStateFraction(final Thread.State state) { return fraction(state) - TOLERANCE; }
private double fraction(final Thread.State state) {
return Collections.frequency(statesForThreads, state) / (1.0 * statesForThreads.size());
}
}
private static class StateTestThread implements Callable<Double> {
private final Thread.State stateToImplement;
private StateTestThread(final Thread.State stateToImplement) {
this.stateToImplement = stateToImplement;
}
@Override
public Double call() throws Exception {
switch ( stateToImplement ) {
case RUNNABLE:
// do some work until we get to THREAD_TARGET_DURATION_IN_MILLISECOND
double sum = 0.0;
final long startTime = System.currentTimeMillis();
for ( int i = 1; System.currentTimeMillis() - startTime < (THREAD_TARGET_DURATION_IN_MILLISECOND - 1); i++ ) {
sum += Math.log10(i);
}
return sum;
case WAITING:
Thread.currentThread().sleep(THREAD_TARGET_DURATION_IN_MILLISECOND);
return 0.0;
case BLOCKED:
if ( StateMonitoringThreadFactory.DEBUG ) logger.warn("Blocking...");
synchronized (GLOBAL_LOCK) {
if ( StateMonitoringThreadFactory.DEBUG ) logger.warn(" ... done blocking");
}
return 0.0;
default:
throw new ReviewedStingException("Unexpected thread test state " + stateToImplement);
}
}
}
@DataProvider(name = "StateTest")
public Object[][] createStateTest() {
for ( final int nThreads : Arrays.asList(1, 2, 3, 4, 5) ) {
for (final List<Thread.State> states : Utils.makeCombinations(StateMonitoringThreadFactory.TRACKED_STATES, nThreads) ) {
//if ( Collections.frequency(states, Thread.State.BLOCKED) > 0)
new StateTest(states);
}
}
return StateTest.getTests(StateTest.class);
}
@Test(enabled = true, dataProvider = "StateTest")
public void testStateTest(final StateTest test) throws InterruptedException {
// allows us to test blocking
final StateMonitoringThreadFactory factory = new StateMonitoringThreadFactory(test.getNStates());
final ExecutorService threadPool = Executors.newFixedThreadPool(test.getNStates(), factory);
logger.warn("Running " + test);
synchronized (GLOBAL_LOCK) {
//logger.warn(" Have lock");
for ( final Thread.State threadToRunState : test.getStatesForThreads() )
threadPool.submit(new StateTestThread(threadToRunState));
// lock has to be here for the whole running of the threads but end before the sleep so the blocked threads
// can block for their allotted time
threadPool.shutdown();
Thread.sleep(THREAD_TARGET_DURATION_IN_MILLISECOND);
}
//logger.warn(" Releasing lock");
threadPool.awaitTermination(10, TimeUnit.SECONDS);
//logger.warn(" done awaiting termination");
//logger.warn(" waiting for all threads to complete");
factory.waitForAllThreadsToComplete();
//logger.warn(" done waiting for threads");
// make sure we counted everything properly
final long totalTime = factory.getTotalTime();
final long minTime = (THREAD_TARGET_DURATION_IN_MILLISECOND - 10) * test.getNStates();
//logger.warn("Testing total time");
Assert.assertTrue(totalTime >= minTime, "Factory results not properly accumulated: totalTime = " + totalTime + " < minTime = " + minTime);
for (final Thread.State state : StateMonitoringThreadFactory.TRACKED_STATES ) {
final double min = test.minStateFraction(state);
final double max = test.maxStateFraction(state);
final double obs = factory.getStateFraction(state);
logger.warn(" Checking " + state
+ " min " + String.format("%.2f", min)
+ " max " + String.format("%.2f", max)
+ " obs " + String.format("%.2f", obs)
+ " factor = " + factory);
Assert.assertTrue(obs >= min, "Too little time spent in state " + state + " obs " + obs + " min " + min);
Assert.assertTrue(obs <= max, "Too much time spent in state " + state + " obs " + obs + " max " + min);
}
Assert.assertEquals(factory.getNThreads(), test.getNStates());
}
}

View File

@ -888,20 +888,8 @@ public class VariantContextTestProvider {
}
}
private static final List<List<Allele>> makeAllGenotypes(final List<Allele> alleles, final int highestPloidy) {
final List<List<Allele>> combinations = new ArrayList<List<Allele>>();
if ( highestPloidy == 1 ) {
for ( final Allele a : alleles )
combinations.add(Collections.singletonList(a));
} else {
final List<List<Allele>> sub = makeAllGenotypes(alleles, highestPloidy - 1);
for ( List<Allele> subI : sub ) {
for ( final Allele a : alleles ) {
combinations.add(Utils.cons(a, subI));
}
}
}
return combinations;
private static List<List<Allele>> makeAllGenotypes(final List<Allele> alleles, final int highestPloidy) {
return Utils.makeCombinations(alleles, highestPloidy);
}
public static void assertEquals(final VCFHeader actual, final VCFHeader expected) {