CountDownLatch的用法
CountDownLatch 类位于java.util.concurrent包下,利用它可以实现一个对多线程任务的计数功能。如有Thread1、Thread2、Thread3,我们需要统计三个线程执行完成的总耗时或在这三个线程执行完成之后在执行Thread4,此时就可以基于CountDownLatch来实现这种功能。
首先,我们来看下它里面提供的函数:
可以看到它只提供了一个构造参数,该构造函数传入一个计量总数,该函数功能是针对计量总数来递减计数1
2
3
4
5
6
7
8
9
10
11/**
* Constructs a {@code CountDownLatch} initialized with the given count.
*
* @param count the number of times {@link #countDown} must be invoked
* before threads can pass through {@link #await}
* @throws IllegalArgumentException if {@code count} is negative
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
然后下面这3个方法是CountDownLatch类中最重要的方法1
2
3
4
5
6//调用await()方法的线程会被挂起,它会等待直到count值为0才继续执行
public void await() throws InterruptedException { };
//和await()类似,只不过等待一定的时间后count值还没变为0的话就会继续执行
public boolean await(long timeout, TimeUnit unit) throws InterruptedException { };
//将count值减1
public void countDown() { };
接下来我例举我项目中的一段代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29public static void httpUtils(List<String> list) throws Exception {
//计数总量初始化
CountDownLatch countDownLatch = new CountDownLatch(list.size());
//开始时间
long start = System.currentTimeMillis();
for (int i = 0; i < list.size(); i++) {
String symbol = list.get(i);
String url = server_url + order_url + symbol;
Proxy proxy = IPPool.getProxy(i);
//开启线程池去请求
executorService.execute(new Runnable() {
public void run() {
try {
//运行URL请求
HttpUtils.get(url, proxy);
}catch (Exception e){
e.printStackTrace();
}finally {
//运行完成计数减1
countDownLatch.countDown();
}
}
});
}
//等待计数完成
countDownLatch.await();
System.out.printf("耗时完成:"+(System.currentTimeMillis() - start));
}
上述方法的功能是,计算一批URL的请求完成耗时。await 方法会阻塞主线程,直到所有的 countDown 完成count值为0,或者使用awit 的有参重载,在count>0 的时候,timeout超时会结束阻塞。
实现原理
我们是上面看到了在CountDownLatch中有一个自定义的内部类 Sync 继承 AbstractQueuedSynchronizer 以下简称为 AQS,AQS是一个用来构建锁和同步工具的框架,除了我这里说的CountDownLatch外还包括常用的ReentrantLock、、Semaphore等。
AQS没有锁之类的概念,它有个state变量,是个int类型,在不同场合有着不同含义。对于CountDownLatch来说,则表示计数值的大小。
AQS围绕state提供两种基本操作“获取”和“释放”,有条双向队列存放阻塞的等待线程,并提供一系列判断和处理方法,简单说几点:
- state是独占的,还是共享的;
- state被获取后,其他线程需要等待;
- state被释放后,唤醒等待线程;
- 线程等不及时,如何退出等待。
至于线程是否可以获得state,如何释放state,就不是AQS关心的了,要由子类具体实现。我们看下CountDownLatch的实现:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31/**
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
Sync中重载了tryAcquireShared 和 tryReleaseShared 方法,其中实现的方法是主要针对 state 的值是否为0进行判断。
接着来看await方法,直接调用了AQS的acquireSharedInterruptibly。1
2
3public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
继续跟进1
2
3
4
5
6
7public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
首先尝试获取共享锁,这里主要是由CountDownLatch实现判断逻辑来判断state 是否为<0。1
2
3protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
返回1代表获取成功,返回-1代表获取失败。如果获取失败,需要调用doAcquireSharedInterruptibly:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
可以看到上面有个循环阻塞线程的过程,至于为什么用for来做死循环,听说是在汇编码上,for比while要快一些
再来看下释放操作1
2
3public void countDown() {
sync.releaseShared(1);
}
countDown操作实际可以理解为就是释放锁的操作,每调用一次,计数值state减少1:1
2
3
4
5
6
7public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
同样是首先尝试释放锁,具体实现在CountDownLatch中,对state进行-1:1
2
3
4
5
6
7
8
9
10
11protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
同样是死循环加上cas的方式保证state的减1操作,当计数值等于0,代表所有子线程都执行完毕,被await阻塞的线程可以唤醒了,下一步调用doReleaseShared。
总结
CountDownLatch 使用的是减数计数方式,当计算为0的时候释放所有等待的线程,计数为0时无法再进行重置,不可重复利用