之前写的一篇博客JUC包下的线程协作计数CountDownLatch及CyclicBarrier只是介绍了一下这两个工具类的用法,并没有深入探究源码,然而实现方法也比较简单,所以合为一篇来写了,可以借鉴一下设计思想。
CountDownLatch 源码分析
CountDownLatch 的目的是阻塞等待其他线程执行完成,可能是为了满足前置需求,不使用CountDownLatch的时候可以使用join方法来完成这项任务,但CountDownLatch提供了一种更优雅的实现方式。
CountDownLatch的实现是非常简单的,内部就一个继承了AQS的内部类:
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
// 只有state为0的时候才能获取资源
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
// 释放共享资源,仅是将state通过CAS减1
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
Sync比较特殊的地方在于,一般使用AQS的类都将state抽象为资源的数量,acquire则减一,release则加一。而这里CountDownLatch
中的Sync
确是调用tryReleaseShared
将state减一,调用tryAcquireShared
不修改state值,但是如果是0的话就返回1。
那么我们通常使用构造方法构造一个CountDownLatch
对象,作用是:
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
就是将这个值传入Sync
的构造方法,然后通过上面Sync
的源码可以知道就是将state设置为了count
这个值。
构造完CountDownLatch
对象之后,我们在线程中要做的事一般就是调用一次countDown()
,然后调用await()
进行等待。
而countDown
方法仅是调用了一下releaseShared
,将state减一:
public void countDown() {
sync.releaseShared(1);
}
在Sync
类中可以看到,tryReleaseShared
这个方法只要不是将资源减为0了,都返回false
,因此并不是每次释放资源都会通知共享节点。
而如果减为0了,根据AQS的逻辑就会进行doReleaseShared
:
public final boolean releaseShared(int arg) {
// 如果减为0了,返回true
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
而doReleaseShared
方法就是从头节点开始向后传播,唤醒共享节点,所以这里就会把所有等待的线程唤醒。
其实分析到这里就已经看出核心逻辑了,await
方法也只是调用了一下acquireSharedInterruptibly
:
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
这个方法就是AQS的实现,如果没有获取到资源就将线程挂起,而只要state不为0是获取不到资源的。
不过通过源码分析可以知道的是,CountDownLatch
是不能复用的,因为在tryReleaseShared
方法中,减到0就不能继续往下减了,而调用await
方法时,会尝试获取资源,此时state就是0,能获取到资源,所以线程就不会挂起了。
CyclicBarrier 源码分析
使用 CyclicBarrier 的目的是,线程必须等待参与协作的线程达到某个个数(这里称为barrier值),才一起开始工作(形象地称为冲破屏障)。
先看一下它的内部类和字段:
public class CyclicBarrier {
//内部类Generation,用来维护屏障是否打破的信息
private static class Generation {
boolean broken = false;
}
// 用来互斥进入屏障
private final ReentrantLock lock = new ReentrantLock();
// 用来挂起线程,直到到达屏障被冲破
private final Condition trip = lock.newCondition();
// 屏障值,构造之后就不会变
private final int parties;
// 当冲破屏障时需要做的任务
private final Runnable barrierCommand;
// 当前的Generation
private Generation generation = new Generation();
// 当前值,从屏障值开始递减,减为0则冲破屏障,然后又恢复屏障值,等待复用
private int count;
//...
}
parties
是构造时赋予的屏障值,之后不会改变,只要等待的线程到达这个个数就能冲破屏障。count
,我这里称为当前值,一开始就是屏障值,每等待一个线程就会减一,可以理解为还差多少个线程可以冲破屏障。
看一下构造方法验证一下:
public CyclicBarrier(int parties) {
this(parties, null);
}
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;
}
后面这个版本就是还需要传入一个任务,冲破屏障时会执行这个任务。
然后直接看await
方法,因为CyclicBarrier
除了构造方法基本上就只会用到这个方法:
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
await
方法还有一个带超时时间的版本,不过最终都是会调用dowait
:
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
// 锁上,进入屏障
lock.lock();
try {
// 获取当前的Generation,看看屏障有没有问题
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}
// 将count值减一
int index = --count;
// 如果count值减为0,可以冲破屏障了
if (index == 0) {
boolean ranAction = false;
try {
// 如果有command,执行这个任务
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
// 产生一个新的屏障,并唤醒所有等待的线程
nextGeneration();
return 0;
} finally {
// 如果任务执行错误,需要打破屏障,禁止使用
if (!ranAction)
breakBarrier();
}
}
// 如果count值没有减为0
for (;;) {
try {
// 如果不带超时
if (!timed)
// 使用条件变量挂起
trip.await();
else if (nanos > 0L)
// 否则带超时的挂起
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
// We're about to finish waiting even if we had not
// been interrupted, so this interrupt is deemed to
// "belong" to subsequent execution.
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}
如果count值减为0,在nextGeneration
这个方法中线程会调用条件变量的signalAll
唤醒所有等待的线程:
private void nextGeneration() {
// signal completion of last generation
trip.signalAll();
// set up next generation
count = parties;
generation = new Generation();
}
这就是CyclicBarrier
的原理。
并且在dowait
这个方法中可以注意到,Generation
在每次打破屏障后都会产生一个新的实例替换原有的实例,也就是说,通常获取到的Generation
对象中的broken
都是false。
而什么情况下会导致Generation
对象的broken
变为true而不会被替换呢?通过dowait
方法发现,若执行command任务期间发生异常,是不会生成新的Generation
对象的,也就是说屏障被打破且不会更新。
来个例子验证一下:
public static void main(String[] args) throws InterruptedException {
CyclicBarrier cb = new CyclicBarrier(2,()->{int i = 1/0;}); // 个数为2时才会继续执行
Runnable runnable = ()->{
try {
cb.await();
System.out.println("线程"+Thread.currentThread().getName()+"开始执行……");
} catch (InterruptedException | BrokenBarrierException e) {
e.printStackTrace();
}
};
new Thread(runnable).start();
Thread.sleep(3000);
new Thread(runnable).start();
Thread.sleep(3000);
new Thread(runnable).start();
}
运行这个方法,在3秒之后有两个线程启动,完成了CyclicBarrier
的目标,但是由于要求执行的任务会产生一个运行时异常,导致Barrier被打破且不会被还原,产生BrokenBarrierException
异常。
输出:
Exception in thread "Thread-1" java.lang.ArithmeticException: / by zero
at com.rhett.thread.TestBarrier.lambdamain0(TestBarrier.java:9)
at java.util.concurrent.CyclicBarrier.dowait(CyclicBarrier.java:220)
at java.util.concurrent.CyclicBarrier.dowait(CyclicBarrier.java:220)
at java.util.concurrent.CyclicBarrier.await(CyclicBarrier.java:362)
at java.lang.Thread.run(Thread.java:748)
java.util.concurrent.BrokenBarrierException
at java.util.concurrent.CyclicBarrier.dowait(CyclicBarrier.java:250)
at java.util.concurrent.CyclicBarrier.await(CyclicBarrier.java:362)
at com.rhett.thread.TestBarrier.lambdamain1(TestBarrier.java:12)
at java.lang.Thread.run(Thread.java:748)
java.util.concurrent.BrokenBarrierException
at java.util.concurrent.CyclicBarrier.dowait(CyclicBarrier.java:207)
at java.util.concurrent.CyclicBarrier.await(CyclicBarrier.java:362)
at com.rhett.thread.TestBarrier.lambdamain1(TestBarrier.java:12)
at java.lang.Thread.run(Thread.java:748)
原创文章,作者:彭晨涛,如若转载,请注明出处:https://www.codetool.top/article/countdownlatch%e5%8f%8acyclicbarrier%e6%ba%90%e7%a0%81%e5%88%86%e6%9e%90/