CountDownLatch及CyclicBarrier源码分析

之前写的一篇博客JUC包下的线程协作计数CountDownLatch及CyclicBarrier只是介绍了一下这两个工具类的用法,并没有深入探究源码,然而实现方法也比较简单,所以合为一篇来写了,可以借鉴一下设计思想。

CountDownLatch 源码分析

CountDownLatch 的目的是阻塞等待其他线程执行完成,可能是为了满足前置需求,不使用CountDownLatch的时候可以使用join方法来完成这项任务,但CountDownLatch提供了一种更优雅的实现方式。

CountDownLatch的实现是非常简单的,内部就一个继承了AQS的内部类:

private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }
    // 只有state为0的时候才能获取资源
    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }
    // 释放共享资源,仅是将state通过CAS减1
    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

Sync比较特殊的地方在于,一般使用AQS的类都将state抽象为资源的数量,acquire则减一,release则加一。而这里CountDownLatch中的Sync确是调用tryReleaseShared将state减一,调用tryAcquireShared不修改state值,但是如果是0的话就返回1。

那么我们通常使用构造方法构造一个CountDownLatch对象,作用是:

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

就是将这个值传入Sync的构造方法,然后通过上面Sync的源码可以知道就是将state设置为了count这个值。

构造完CountDownLatch对象之后,我们在线程中要做的事一般就是调用一次countDown(),然后调用await()进行等待。

countDown方法仅是调用了一下releaseShared,将state减一:

public void countDown() {
    sync.releaseShared(1);
}

Sync类中可以看到,tryReleaseShared这个方法只要不是将资源减为0了,都返回false,因此并不是每次释放资源都会通知共享节点。

而如果减为0了,根据AQS的逻辑就会进行doReleaseShared

public final boolean releaseShared(int arg) {
    // 如果减为0了,返回true
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

doReleaseShared方法就是从头节点开始向后传播,唤醒共享节点,所以这里就会把所有等待的线程唤醒。

其实分析到这里就已经看出核心逻辑了,await方法也只是调用了一下acquireSharedInterruptibly

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

这个方法就是AQS的实现,如果没有获取到资源就将线程挂起,而只要state不为0是获取不到资源的。

不过通过源码分析可以知道的是,CountDownLatch是不能复用的,因为在tryReleaseShared方法中,减到0就不能继续往下减了,而调用await方法时,会尝试获取资源,此时state就是0,能获取到资源,所以线程就不会挂起了。

CyclicBarrier 源码分析

使用 CyclicBarrier 的目的是,线程必须等待参与协作的线程达到某个个数(这里称为barrier值),才一起开始工作(形象地称为冲破屏障)。

先看一下它的内部类和字段:

public class CyclicBarrier {
    //内部类Generation,用来维护屏障是否打破的信息
    private static class Generation {
        boolean broken = false;
    }

    // 用来互斥进入屏障
    private final ReentrantLock lock = new ReentrantLock();
    // 用来挂起线程,直到到达屏障被冲破
    private final Condition trip = lock.newCondition();
    // 屏障值,构造之后就不会变
    private final int parties;
    // 当冲破屏障时需要做的任务
    private final Runnable barrierCommand;
    // 当前的Generation
    private Generation generation = new Generation();

    // 当前值,从屏障值开始递减,减为0则冲破屏障,然后又恢复屏障值,等待复用
    private int count;
    //...
}
  • parties是构造时赋予的屏障值,之后不会改变,只要等待的线程到达这个个数就能冲破屏障。
  • count,我这里称为当前值,一开始就是屏障值,每等待一个线程就会减一,可以理解为还差多少个线程可以冲破屏障。

看一下构造方法验证一下:

public CyclicBarrier(int parties) {
    this(parties, null);
}

public CyclicBarrier(int parties, Runnable barrierAction) {
    if (parties <= 0) throw new IllegalArgumentException();
    this.parties = parties;
    this.count = parties;
    this.barrierCommand = barrierAction;
}

后面这个版本就是还需要传入一个任务,冲破屏障时会执行这个任务。

然后直接看await方法,因为CyclicBarrier除了构造方法基本上就只会用到这个方法:

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}

await方法还有一个带超时时间的版本,不过最终都是会调用dowait:

private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
            TimeoutException {
    final ReentrantLock lock = this.lock;
    // 锁上,进入屏障
    lock.lock();
    try {
        // 获取当前的Generation,看看屏障有没有问题
        final Generation g = generation;

        if (g.broken)
            throw new BrokenBarrierException();

        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }
        // 将count值减一
        int index = --count;
        // 如果count值减为0,可以冲破屏障了
        if (index == 0) {  
            boolean ranAction = false;
            try {
                // 如果有command,执行这个任务
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                // 产生一个新的屏障,并唤醒所有等待的线程
                nextGeneration();
                return 0;
            } finally {
                // 如果任务执行错误,需要打破屏障,禁止使用
                if (!ranAction)
                    breakBarrier();
            }
        }

        // 如果count值没有减为0
        for (;;) {
            try {
                // 如果不带超时
                if (!timed)
                    // 使用条件变量挂起
                    trip.await();
                else if (nanos > 0L)
                    // 否则带超时的挂起
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

如果count值减为0,在nextGeneration这个方法中线程会调用条件变量的signalAll唤醒所有等待的线程:

private void nextGeneration() {
    // signal completion of last generation
    trip.signalAll();
    // set up next generation
    count = parties;
    generation = new Generation();
}

这就是CyclicBarrier的原理。

并且在dowait这个方法中可以注意到,Generation在每次打破屏障后都会产生一个新的实例替换原有的实例,也就是说,通常获取到的Generation对象中的broken都是false。

而什么情况下会导致Generation对象的broken变为true而不会被替换呢?通过dowait方法发现,若执行command任务期间发生异常,是不会生成新的Generation对象的,也就是说屏障被打破且不会更新。

来个例子验证一下:

public static void main(String[] args) throws InterruptedException {
    CyclicBarrier cb = new CyclicBarrier(2,()->{int i = 1/0;}); // 个数为2时才会继续执行
    Runnable runnable = ()->{
        try {
            cb.await();
            System.out.println("线程"+Thread.currentThread().getName()+"开始执行……");
        } catch (InterruptedException | BrokenBarrierException e) {
            e.printStackTrace();
        }
    };
    new Thread(runnable).start();
    Thread.sleep(3000);
    new Thread(runnable).start();
    Thread.sleep(3000);
    new Thread(runnable).start();
}

运行这个方法,在3秒之后有两个线程启动,完成了CyclicBarrier的目标,但是由于要求执行的任务会产生一个运行时异常,导致Barrier被打破且不会被还原,产生BrokenBarrierException异常。

输出:

Exception in thread "Thread-1" java.lang.ArithmeticException: / by zero
    at com.rhett.thread.TestBarrier.lambdamain0(TestBarrier.java:9)
    at java.util.concurrent.CyclicBarrier.dowait(CyclicBarrier.java:220)
    at java.util.concurrent.CyclicBarrier.dowait(CyclicBarrier.java:220)
    at java.util.concurrent.CyclicBarrier.await(CyclicBarrier.java:362)
    at java.lang.Thread.run(Thread.java:748)
java.util.concurrent.BrokenBarrierException
    at java.util.concurrent.CyclicBarrier.dowait(CyclicBarrier.java:250)
    at java.util.concurrent.CyclicBarrier.await(CyclicBarrier.java:362)
    at com.rhett.thread.TestBarrier.lambdamain1(TestBarrier.java:12)
    at java.lang.Thread.run(Thread.java:748)
java.util.concurrent.BrokenBarrierException
    at java.util.concurrent.CyclicBarrier.dowait(CyclicBarrier.java:207)
    at java.util.concurrent.CyclicBarrier.await(CyclicBarrier.java:362)
    at com.rhett.thread.TestBarrier.lambdamain1(TestBarrier.java:12)
    at java.lang.Thread.run(Thread.java:748)

原创文章,作者:彭晨涛,如若转载,请注明出处:https://www.codetool.top/article/countdownlatch%e5%8f%8acyclicbarrier%e6%ba%90%e7%a0%81%e5%88%86%e6%9e%90/

发表评论

电子邮件地址不会被公开。