深入理解CyclicBarrier

2023-12-14 22:59:04

1. 概念

CyclicBarrier 字面意思回环栅栏(循环屏障),通过它可以实现让一组线程等待至某个状态(屏障点)之后再全部同时执行。叫做回环是因为当所有等待线程都被释放以后,CyclicBarrier可以被重用。CyclicBarrier 作用是让一组线程相互等待,当达到一个共同点时,所有之前等待的线程再继续执行,且 CyclicBarrier 功能可重复使用。
在这里插入图片描述

2. CylicBarier使用简单案例

public class Main {
    public static void main(String[] args) throws  InterruptedException{
      CyclicBarrier cyclicBarrier=new CyclicBarrier(3);
        for (int i = 0; i < 5; i++) {
            new Thread(()->{
                try{
                    System.out.println(Thread.currentThread().getName()+"开始等待其它线程");
                    //阻塞直到指定方法的数量调用这个方法就会停止阻塞
                    cyclicBarrier.await();
                    System.out.println(Thread.currentThread().getName()+"开始执行");
                    Thread.sleep(5000);
                    System.out.println(Thread.currentThread().getName()+"执行完毕");

                } catch (Exception e) {
                    e.printStackTrace();
                }
            }).start();
        }
    }
}

在这里插入图片描述

可以发现只有3个线程继续执行,剩余两个线程被阻塞

3. 源码

  • 构造方法
//这个构造方法有两个参数,分别是parties和一个任务,parties代表着屏障拦截的线程数量,每个线程调用 await 方法告诉 CyclicBarrier 我已经到达了屏障,然后当前线程被阻塞。当阻塞的线程达到parties的数量时,就会执行barrieAction这个任务
public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        //使用两个变量存储parties,这也是parties可以复用的根本原因
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

 public CyclicBarrier(int parties) {
        this(parties, null);
    }
  • 重要方法
 public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }

    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }

源码分析要点

1. 一组现场在触发屏障之前互相等待,最后一个线程到达屏障后唤醒逻辑是如何实现的
2. 栅栏循环是如何实现的
3. 条件队列到同步队列的转换实现逻辑

await()方法

   public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }

发现里面实际逻辑调用的是dowait(false, 0L)方法

private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
               //定义了一个ReentrantLock
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            final Generation g = generation;

            if (g.broken)
                throw new BrokenBarrierException();

            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }
            //更新count方法
            int index = --count;
            if (index == 0) {  // tripped
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    ranAction = true;
                    nextGeneration();
                    return 0;
                } finally {
                    if (!ranAction)
                        breakBarrier();
                }
            }

            // loop until tripped, broken, interrupted, or timed out
            for (;;) {
                try {
                    if (!timed)
                    //进入条件队列trip进行阻塞
                        trip.await();
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                     Thread.currentThread().interrupt();
                    }
                }

                if (g.broken)
                    throw new BrokenBarrierException();

                if (g != generation)
                    return index;

                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }

上面方法最核心的就是更新count,然后判断count是否为0,如果为0就开始执行唤醒逻辑(这里先不考虑),如果不为0就会进入trip这个条件队列进行阻塞,下面分析线程是如何进行条件队列阻塞的。

//这是AQS类的一个方法
  public final void await() throws InterruptedException {
            if (Thread.interrupted())
                throw new InterruptedException();
               
            Node node = addConditionWaiter();
            int savedState = fullyRelease(node);
            int interruptMode = 0;
            //判断当亲线程是不是同步队列,不是直接调用park进行阻塞
            while (!isOnSyncQueue(node)) {
                LockSupport.park(this);
                if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
                    break;
            }
            if (acquireQueued(node, savedState) && interruptMode != THROW_IE)
                interruptMode = REINTERRUPT;
            if (node.nextWaiter != null) // clean up if cancelled
                unlinkCancelledWaiters();
            if (interruptMode != 0)
                reportInterruptAfterWait(interruptMode);
        }

往条件等待队列中添加节点就是下面这句代码

 Node node = addConditionWaiter();
 private Node addConditionWaiter() {
           //获得条件队列的最后一个结点
            Node t = lastWaiter;
            if (t != null && t.waitStatus != Node.CONDITION) {
                unlinkCancelledWaiters();
                t = lastWaiter;
            }
            //如果为空就新创建一个节点
            Node node = new Node(Thread.currentThread(), Node.CONDITION);
            if (t == null)
            //如果当前单向队列为空,直接让新创建的节点成为头节点
                firstWaiter = node;
            else
            //否则就放到尾节点的后面
                t.nextWaiter = node;
             //让尾指针指向当前节点
            lastWaiter = node;
            //返回当前节点
            return node;
        }

addConditionWaiter实际是AQS的内部类ConditionObject中实现的

public class ConditionObject implements Condition, java.io.Serializable {
        private static final long serialVersionUID = 1173984872572414699L;
        //条件队列的第一个节点
        private transient Node firstWaiter;
        //条件队列的最后一个节点
        private transient Node lastWaiter;
        public ConditionObject() { }
        private Node addConditionWaiter() {
            Node t = lastWaiter;
            if (t != null && t.waitStatus != Node.CONDITION) {
                unlinkCancelledWaiters();
                t = lastWaiter;
            }
            //如果条件队列为空,创建一个新的节点
            Node node = new Node(Thread.currentThread(), Node.CONDITION);
            if (t == null)
            //让新创建的节点成为头节点和尾节点
                firstWaiter = node;
            else
                t.nextWaiter = node;
            lastWaiter = node;
            return node;
        }
         private void doSignal(Node first) {
            do {
                if ( (firstWaiter = first.nextWaiter) == null)
                    lastWaiter = null;
                first.nextWaiter = null;
            } while (!transferForSignal(first) &&
                     (first = firstWaiter) != null);
        }

        /**
         * Removes and transfers all nodes.
         * @param first (non-null) the first node on condition queue
         */
        private void doSignalAll(Node first) {
            lastWaiter = firstWaiter = null;
            do {
                Node next = first.nextWaiter;
                first.nextWaiter = null;
                transferForSignal(first);
                first = next;
            } while (first != null);
        }
        private void unlinkCancelledWaiters() {
            Node t = firstWaiter;
            Node trail = null;
            while (t != null) {
                Node next = t.nextWaiter;
                if (t.waitStatus != Node.CONDITION) {
                    t.nextWaiter = null;
                    if (trail == null)
                        firstWaiter = next;
                    else
                        trail.nextWaiter = next;
                    if (next == null)
                        lastWaiter = trail;
                }
                else
                    trail = t;
                t = next;
            }
        }

节点入队后就继续执行 public final void await() throws InterruptedException方法,当调用await()方法,我们需要释放持有的锁,也就是执行下面这句代码:

int savedState = fullyRelease(node);
 final int fullyRelease(Node node) {
        boolean failed = true;
        try {
        //获取state标记(独占锁如果state从0-1表示释放锁,从1-0表示占用锁
            int savedState = getState();
            if (release(savedState)) {
                failed = false;
                return savedState;
            } else {
                throw new IllegalMonitorStateException();
            }
        } finally {
            if (failed)
                node.waitStatus = Node.CANCELLED;
        }
    }
  public final boolean release(int arg) {
        if (tryRelease(arg)) {
            Node h = head;
            if (h != null && h.waitStatus != 0)
                unparkSuccessor(h);
            return true;
        }
        return false;
    }

释放锁后回到await()方法,调用下面代码进行实际阻塞

 while (!isOnSyncQueue(node)) {
                LockSupport.park(this);
                if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
                    break;
            }

上面就队线程阻塞以及入队的原理分析,下面分析count减到0,后是如何执行线程唤醒的,核心代码是:

private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
    if (index == 0) {  // tripped
        boolean ranAction = false;
        try {
               final Runnable command = barrierCommand;
               if (command != null)
                    command.run();
                    ranAction = true;
                    //开始下一轮屏障
                    nextGeneration();
                    return 0;
                } finally {
                    if (!ranAction)
                        breakBarrier();
                }
            }

nextGeneration的代码如下:

    private void nextGeneration() {
        //唤醒条件队列的所有节点
        trip.signalAll();
        // 恢复count值
        count = parties;
        generation = new Generation();
    }

signalAll()唤醒条件队列中所有的节点

public class ConditionObject implements Condition, java.io.Serializable {
......
	private void doSignalAll(Node first) {
	//首尾节点置为null
            lastWaiter = firstWaiter = null;
            do {
            //获取首节点的下一个节点
                Node next = first.nextWaiter;
                //然后将first的nextWaiter指针置为空
                first.nextWaiter = null;
                //实现头部出队的节点怎么进入同步队列
                transferForSignal(first);
                //然后开始迭代处理下一个节点
                first = next;
            } while (first != null);
        }
......
}

下面分析头部出队的节点进入同步队列的逻辑

final boolean transferForSignal(Node node) {
		//使用CAS操作修改节点的状态
        if (!compareAndSetWaitStatus(node, Node.CONDITION, 0))
            return false;
        //节点入同步队列
        Node p = enq(node);
        int ws = p.waitStatus;
        if (ws > 0 || !compareAndSetWaitStatus(p, ws, Node.SIGNAL))
        //p节点的前驱节点置换为-1,这样就可以唤醒node节点,然后调用park进行阻塞
            LockSupport.unpark(node.thread);
        return true;
    }

文章来源:https://blog.csdn.net/qq_43456605/article/details/134868953
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。