Channel: Planet Apache

Gary Gregory: Java multi-threaded unit testing


With Java 5 and JUnit 4, writing multi-threaded unit tests has never been easier.

Let’s take a brain-dead ID generator as our domain object example:

    /**
     * Generates sequential unique IDs starting with 1, 2, 3, and so on.
     * <p>
     * This class is NOT thread-safe.
     * </p>
     */
    static class BrokenUniqueIdGenerator {
        private long counter = 0;

        public long nextId() {
            return ++counter;
        }
    }
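To see why this counter breaks, it helps to spell out what ++counter does. The sketch below uses a hypothetical ExpandedIdGenerator class (not part of the original post) that splits the increment into a separate read and write, then replays, by hand and on a single thread, one unlucky interleaving of two threads A and B. Replaying it manually makes the lost update deterministic instead of timing-dependent.

```java
// Hypothetical illustration (not from the original post): "++counter"
// split into the separate read, add, and write steps the JVM performs.
class ExpandedIdGenerator {
    private long counter = 0;

    // Step 1: load the current value of the field.
    long read() {
        return counter;
    }

    // Step 3: store a new value into the field.
    void write(long value) {
        counter = value;
    }

    // Behaves like "return ++counter;" when only one thread calls it.
    long nextId() {
        long current = read();   // step 1: read
        long next = current + 1; // step 2: add
        write(next);             // step 3: write
        return next;
    }

    // Replays, on one thread, a bad interleaving of two threads A and B.
    static long[] simulateRace() {
        ExpandedIdGenerator generator = new ExpandedIdGenerator();
        long seenByA = generator.read(); // A reads 0
        long seenByB = generator.read(); // B reads 0 before A has written
        generator.write(seenByA + 1);    // A writes 1 and returns 1
        generator.write(seenByB + 1);    // B overwrites with 1 and also returns 1
        return new long[] { seenByA + 1, seenByB + 1 };
    }

    public static void main(String[] args) {
        long[] ids = simulateRace();
        System.out.println(ids[0] + " " + ids[1]); // prints "1 1": a duplicate ID
    }
}
```

With real threads the same interleaving only happens when the scheduler switches threads between the read and the write, which is why the failure is intermittent.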

This is how you test this class with loads of 1, 2, 4, 8, 16, and 32 threads. We use the java.util.concurrent API to manage the threads that call our domain object concurrently, and JUnit @Test methods to run the tests and verify the results.

    @Test
    public void test01() throws InterruptedException, ExecutionException {
        test(1);
    }

    @Test
    public void test02() throws InterruptedException, ExecutionException {
        test(2);
    }

    @Test
    public void test04() throws InterruptedException, ExecutionException {
        test(4);
    }

    @Test
    public void test08() throws InterruptedException, ExecutionException {
        test(8);
    }

    @Test
    public void test16() throws InterruptedException, ExecutionException {
        test(16);
    }

    @Test
    public void test32() throws InterruptedException, ExecutionException {
        test(32);
    }

    private void test(final int threadCount) throws InterruptedException, ExecutionException {
        final BrokenUniqueIdGenerator domainObject = new BrokenUniqueIdGenerator();
        Callable<Long> task = new Callable<Long>() {
            @Override
            public Long call() {
                return domainObject.nextId();
            }
        };
        List<Callable<Long>> tasks = Collections.nCopies(threadCount, task);
        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
        List<Future<Long>> futures = executorService.invokeAll(tasks);
        // invokeAll has already waited for every task, so release the pool's threads.
        executorService.shutdown();
        List<Long> resultList = new ArrayList<Long>(futures.size());
        // Check for exceptions
        for (Future<Long> future : futures) {
            // Throws an ExecutionException if the task threw an exception.
            resultList.add(future.get());
        }
        // Validate the IDs (JUnit's assertEquals takes the expected value first)
        Assert.assertEquals(threadCount, futures.size());
        List<Long> expectedList = new ArrayList<Long>(threadCount);
        for (long i = 1; i <= threadCount; i++) {
            expectedList.add(i);
        }
        Collections.sort(resultList);
        Assert.assertEquals(expectedList, resultList);
    }

Let’s walk through the test(int threadCount) method. We start by creating our domain object:

        final BrokenUniqueIdGenerator domainObject = new BrokenUniqueIdGenerator();

This class has one method, nextId, which we wrap in a task, an instance of Callable:

        Callable<Long> task = new Callable<Long>() {
            @Override
            public Long call() {
                return domainObject.nextId();
            }
        };

This is just a generic way to fit our API call into the Java concurrency API.

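We then build a list that holds the task once per thread. Collections.nCopies does not copy the task; it returns an immutable list of threadCount references to the same Callable, so every thread calls nextId() on the same shared domainObject:

        List<Callable<Long>> tasks = Collections.nCopies(threadCount, task);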
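A quick standalone check of that claim: the sketch below (a hypothetical NCopiesDemo class, with a placeholder task body instead of the real domainObject.nextId() call) shows that every element of the list is the same instance and that the list rejects modification.

```java
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;

class NCopiesDemo {
    // Builds the task list the same way the test does.
    static List<Callable<Long>> buildTasks(int n) {
        Callable<Long> task = new Callable<Long>() {
            @Override
            public Long call() {
                return 42L; // placeholder; the real test calls domainObject.nextId()
            }
        };
        return Collections.nCopies(n, task);
    }

    public static void main(String[] args) {
        List<Callable<Long>> tasks = buildTasks(4);
        System.out.println(tasks.size());                 // prints 4
        System.out.println(tasks.get(0) == tasks.get(3)); // prints true: one shared instance
        try {
            tasks.add(tasks.get(0));
        } catch (UnsupportedOperationException e) {
            System.out.println("immutable");              // nCopies lists cannot be modified
        }
    }
}
```

Sharing one Callable is safe here because the task itself holds no state; all the shared state lives in the generator under test.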

Next, we create a thread pool. It must have at least as many threads as the thread count we want to test; here we use exactly threadCount:

        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
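Why the pool size matters: a fixed pool queues any tasks it has no free thread for, so an undersized pool quietly turns a concurrent test into a sequential one. The sketch below (a hypothetical PoolSizeDemo class, not from the post) records which pool threads actually ran the tasks.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class PoolSizeDemo {
    // Runs taskCount tasks on a fixed pool of poolSize threads and returns
    // the names of the distinct pool threads that actually executed them.
    static Set<String> threadsUsed(int poolSize, int taskCount) {
        ExecutorService pool = Executors.newFixedThreadPool(poolSize);
        try {
            List<Callable<String>> tasks = new ArrayList<Callable<String>>();
            for (int i = 0; i < taskCount; i++) {
                tasks.add(new Callable<String>() {
                    @Override
                    public String call() {
                        return Thread.currentThread().getName();
                    }
                });
            }
            Set<String> names = new HashSet<String>();
            for (Future<String> future : pool.invokeAll(tasks)) {
                names.add(future.get());
            }
            return names;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // One thread: all 8 tasks run one after another, so they can never race.
        System.out.println(threadsUsed(1, 8).size()); // prints 1
        // Eight threads: the tasks can overlap; how many threads run them is up to the pool.
        System.out.println(threadsUsed(8, 8).size());
    }
}
```

The multi-catch clause needs Java 7 or later; on Java 6 it would be two separate catch blocks.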

We then ask the executor service to run all the tasks concurrently, using threads from the pool:

        List<Future<Long>> futures = executorService.invokeAll(tasks);
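Two properties of invokeAll make the result-checking loop simple: when it returns, every Future is already done, and the futures come back in the same order as the task list, regardless of which task finished first. The sketch below (a hypothetical InvokeAllDemo class) makes later tasks sleep less so they tend to finish first, yet the results still come back in task order.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class InvokeAllDemo {
    // Runs n tasks that return their own index and collects the results in future order.
    static List<Integer> runAll(final int n) {
        ExecutorService pool = Executors.newFixedThreadPool(n);
        try {
            List<Callable<Integer>> tasks = new ArrayList<Callable<Integer>>();
            for (int i = 0; i < n; i++) {
                final int index = i;
                tasks.add(new Callable<Integer>() {
                    @Override
                    public Integer call() throws InterruptedException {
                        // Later tasks sleep less, so they tend to finish first.
                        Thread.sleep((n - index) * 10L);
                        return index;
                    }
                });
            }
            List<Future<Integer>> futures = pool.invokeAll(tasks);
            List<Integer> results = new ArrayList<Integer>();
            for (Future<Integer> future : futures) {
                if (!future.isDone()) {
                    throw new IllegalStateException("invokeAll returned before a task finished");
                }
                results.add(future.get()); // never blocks: the task is already done
            }
            return results;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(runAll(4)); // prints [0, 1, 2, 3]
    }
}
```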

The call to invokeAll blocks until all the tasks have completed. Each task runs on a pool thread, which invokes the task’s call method, which in turn calls our domain object API, nextId().
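The walkthrough stops at invokeAll, but the test then calls future.get() on each Future. If a task threw an exception, get() rethrows it wrapped in an ExecutionException, with the original exception as its cause; that is what the "Check for exceptions" loop relies on. A minimal sketch, using a hypothetical task that always fails:

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class FailingTaskDemo {
    // Submits a task that always throws and returns the cause reported by Future.get().
    static Throwable causeOfFailure() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Future<Long> future = pool.submit(new Callable<Long>() {
                @Override
                public Long call() {
                    throw new IllegalStateException("simulated failure"); // hypothetical failure
                }
            });
            future.get();
            return null; // not reached: the task always throws
        } catch (ExecutionException e) {
            return e.getCause(); // the exception originally thrown inside call()
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        Throwable cause = causeOfFailure();
        // prints "IllegalStateException: simulated failure"
        System.out.println(cause.getClass().getSimpleName() + ": " + cause.getMessage());
    }
}
```

In the JUnit test the ExecutionException simply propagates out of the @Test method, so a task failure shows up as a test error rather than being silently ignored.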

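When you run this test case, it will sometimes pass and sometimes fail. ++counter is not atomic: it reads the field, adds one, and writes the result back, so two threads can read the same value and return the same ID. Since each task calls nextId() only once, the window for that race is small, and failures may be rare.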
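One way to make the race show up more often, offered here as a variation rather than the post's method, is to hold every thread at a CountDownLatch until all are ready and have each thread call nextId() many times. The sketch below (a hypothetical RaceStressDemo class that copies BrokenUniqueIdGenerator; ConcurrentHashMap.newKeySet needs Java 8 or later) counts duplicate IDs. The count is nondeterministic: a nonzero value proves the bug, while zero only means this run did not catch it.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;

class RaceStressDemo {
    // Copy of the post's BrokenUniqueIdGenerator (NOT thread-safe).
    static class BrokenUniqueIdGenerator {
        private long counter = 0;

        public long nextId() {
            return ++counter;
        }
    }

    // Starts all threads together, has each call nextId() callsPerThread times,
    // and returns how many of the returned IDs were duplicates.
    static long duplicates(int threads, final int callsPerThread) {
        final BrokenUniqueIdGenerator generator = new BrokenUniqueIdGenerator();
        final Set<Long> seen = ConcurrentHashMap.newKeySet();
        final CountDownLatch startGate = new CountDownLatch(1);
        Thread[] workers = new Thread[threads];
        for (int t = 0; t < threads; t++) {
            workers[t] = new Thread(new Runnable() {
                @Override
                public void run() {
                    try {
                        startGate.await(); // wait until the main thread opens the gate
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        return;
                    }
                    for (int i = 0; i < callsPerThread; i++) {
                        seen.add(generator.nextId());
                    }
                }
            });
            workers[t].start();
        }
        startGate.countDown(); // release every worker at (nearly) the same moment
        for (Thread worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        long totalCalls = (long) threads * callsPerThread;
        return totalCalls - seen.size();
    }

    public static void main(String[] args) {
        // Nondeterministic: often nonzero on a multi-core machine, but not guaranteed.
        System.out.println("duplicate IDs: " + duplicates(8, 10000));
    }
}
```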

That’s multithreaded testing with Java 5 and JUnit 4. Voila!

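BTW, the proper implementation uses an AtomicLong, whose incrementAndGet() performs the read, increment, and write as one atomic operation: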

    /**
     * Generates sequential unique IDs starting with 1, 2, 3, and so on.
     * <p>
     * This class is thread-safe.
     * </p>
     */
    static class UniqueIdGenerator {
        private final AtomicLong counter = new AtomicLong();

        public long nextId() {
            return counter.incrementAndGet();
        }
    }
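AtomicLong fixes the lost update because the write only succeeds if nobody changed the value since it was read. The sketch below spells that idea out with an explicit compareAndSet retry loop in a hypothetical CasIdGenerator class. It is a conceptual equivalent of incrementAndGet(), not a claim about how the JDK implements it internally, plus a small multi-threaded check (ConcurrentHashMap.newKeySet needs Java 8 or later).

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

class CasIdGenerator {
    private final AtomicLong counter = new AtomicLong();

    // Read, add, then write only if the value is still what we read; otherwise retry.
    long nextId() {
        while (true) {
            long current = counter.get();
            long next = current + 1;
            if (counter.compareAndSet(current, next)) {
                return next;
            }
            // Another thread won the race; loop and try again with the fresh value.
        }
    }

    // Calls nextId() from several threads and returns how many distinct IDs were seen.
    static int distinctIds(int threads, final int callsPerThread) {
        final CasIdGenerator generator = new CasIdGenerator();
        final Set<Long> seen = ConcurrentHashMap.newKeySet();
        Thread[] workers = new Thread[threads];
        for (int t = 0; t < threads; t++) {
            workers[t] = new Thread(new Runnable() {
                @Override
                public void run() {
                    for (int i = 0; i < callsPerThread; i++) {
                        seen.add(generator.nextId());
                    }
                }
            });
            workers[t].start();
        }
        for (Thread worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return seen.size();
    }

    public static void main(String[] args) {
        System.out.println(distinctIds(8, 10000)); // prints 80000: no duplicates
    }
}
```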

The full listing is:

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;

import org.junit.Assert;
import org.junit.Test;

public class MultiThreadedTestCase {

    /**
     * Generates sequential unique IDs starting with 1, 2, 3, and so on.
     * <p>
     * This class is NOT thread-safe.
     * </p>
     */
    static class BrokenUniqueIdGenerator {
        private long counter = 0;

        public long nextId() {
            return ++counter;
        }
    }

    /**
     * Generates sequential unique IDs starting with 1, 2, 3, and so on.
     * <p>
     * This class is thread-safe.
     * </p>
     */
    static class UniqueIdGenerator {
        private final AtomicLong counter = new AtomicLong();

        public long nextId() {
            return counter.incrementAndGet();
        }
    }

    private void test(final int threadCount) throws InterruptedException, ExecutionException {
        final UniqueIdGenerator domainObject = new UniqueIdGenerator();
        Callable<Long> task = new Callable<Long>() {
            @Override
            public Long call() {
                return domainObject.nextId();
            }
        };
        List<Callable<Long>> tasks = Collections.nCopies(threadCount, task);
        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
        List<Future<Long>> futures = executorService.invokeAll(tasks);
        // invokeAll has already waited for every task, so release the pool's threads.
        executorService.shutdown();
        List<Long> resultList = new ArrayList<Long>(futures.size());
        // Check for exceptions
        for (Future<Long> future : futures) {
            // Throws an ExecutionException if the task threw an exception.
            resultList.add(future.get());
        }
        // Validate the IDs (JUnit's assertEquals takes the expected value first)
        Assert.assertEquals(threadCount, futures.size());
        List<Long> expectedList = new ArrayList<Long>(threadCount);
        for (long i = 1; i <= threadCount; i++) {
            expectedList.add(i);
        }
        Collections.sort(resultList);
        Assert.assertEquals(expectedList, resultList);
    }

    @Test
    public void test01() throws InterruptedException, ExecutionException {
        test(1);
    }

    @Test
    public void test02() throws InterruptedException, ExecutionException {
        test(2);
    }

    @Test
    public void test04() throws InterruptedException, ExecutionException {
        test(4);
    }

    @Test
    public void test08() throws InterruptedException, ExecutionException {
        test(8);
    }

    @Test
    public void test16() throws InterruptedException, ExecutionException {
        test(16);
    }

    @Test
    public void test32() throws InterruptedException, ExecutionException {
        test(32);
    }
}

Note: I used Oracle Java 1.6.0_24 (64-bit) on Windows 7 (64-bit).


