欢迎 review 代码，指出错误。
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

/**
 * A reusable variant of {@link CountDownLatch}.
 *
 * <ul>
 *   <li>Adds {@code reset}: once the count has reached zero the latch can be re-armed with its
 *       initial count and used again (optionally automatically, see the two-argument
 *       constructor).</li>
 *   <li>Adds a version number: every reset moves the latch to a new version, and callers may
 *       pass an explicit version to {@code await} to wait for a specific cycle.</li>
 * </ul>
 *
 * <p>Note: {@code reset} only changes the count and the version; it does not by itself wake
 * threads that are already parked in {@code await}. They are re-evaluated the next time a
 * {@link #countDown()} brings the count to zero.
 */
public class ReusableCountDownLatch {

    private final Sync sync;

    /**
     * Version the calling thread is currently waiting on. It is only meaningful while that
     * thread is inside one of the {@code await} methods and is cleared when the call returns.
     */
    private final ThreadLocal<Long> threadVersion = new ThreadLocal<>();

    /** Latest version of this latch; changed by every reset. */
    private final AtomicLong latchVersion = new AtomicLong(0);

    /** Synchronizer: the AQS state holds the remaining count. */
    private final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        /** Initial count, restored on every reset. */
        private final int count;

        /** Whether the latch re-arms itself as soon as the count reaches zero. */
        private final boolean autoReset;

        Sync(int count) {
            this(count, false);
        }

        Sync(int count, boolean autoReset) {
            this.count = count;
            this.autoReset = autoReset;
            setState(count);
        }

        /** Moves to the next version and restores the initial count. */
        protected void reset() {
            latchVersion.getAndIncrement();
            setState(count);
        }

        /** Sets the latch version to {@code version} and restores the initial count. */
        protected void reset(long version) {
            latchVersion.set(version);
            setState(count);
        }

        int getCount() {
            return getState();
        }

        /**
         * Decides whether the calling thread may pass the latch.
         *
         * <p>AQS invokes this on the waiting thread itself, which is why the waiter's version
         * can be kept in a {@link ThreadLocal}.
         *
         * @param acquires ignored
         * @return {@code 1} if the thread may proceed, {@code -1} if it has to keep waiting
         */
        @Override
        protected int tryAcquireShared(int acquires) {
            Long waiterVersion = threadVersion.get();
            long currentVersion = latchVersion.get();
            if (waiterVersion != null) {
                if (currentVersion > waiterVersion) {
                    // The cycle this thread waited for is over (the latch has been reset since).
                    threadVersion.remove();
                    return 1;
                }
                if (currentVersion < waiterVersion) {
                    // The requested version has not been reached yet.
                    return -1;
                }
            }
            if (getState() == 0) {
                return 1;
            }
            // Remember the cycle we are waiting on so that an auto-reset, which makes the
            // state non-zero again before we are woken up, is still seen as completion.
            threadVersion.set(currentVersion);
            return -1;
        }

        /**
         * Decrements the count and reports whether waiters have to be woken up.
         *
         * @param releases ignored
         * @return {@code true} if this call moved the count to zero
         */
        @Override
        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                int current = getState();
                if (current == 0) {
                    return false;
                }
                int next = current - 1;
                if (compareAndSetState(current, next)) {
                    boolean reachedZero = next == 0;
                    if (reachedZero && autoReset) {
                        // Reset first; AQS wakes the waiters only after this method returns.
                        reset();
                    }
                    return reachedZero;
                }
            }
        }
    }

    /**
     * Creates a latch that has to be reset manually.
     *
     * @param count number of {@link #countDown()} calls needed per cycle
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public ReusableCountDownLatch(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        this.sync = new Sync(count);
    }

    /**
     * Creates a latch.
     *
     * @param count number of {@link #countDown()} calls needed per cycle
     * @param autoReset if {@code true} the latch re-arms itself whenever the count reaches zero
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public ReusableCountDownLatch(int count, boolean autoReset) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        this.sync = new Sync(count, autoReset);
    }

    /**
     * Waits until the current cycle completes (the count reaches zero).
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void await() throws InterruptedException {
        // Start from a clean slate: a version left over from an earlier call on this thread
        // must not make this call return immediately after a reset.
        threadVersion.remove();
        try {
            sync.acquireSharedInterruptibly(1);
        } finally {
            threadVersion.remove();
        }
    }

    /**
     * Waits on the given version: returns once the latch version is greater than
     * {@code version}, or the latch is at {@code version} and the count is zero.
     *
     * @param version the latch version to wait on
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void await(long version) throws InterruptedException {
        threadVersion.set(version);
        try {
            sync.acquireSharedInterruptibly(1);
        } finally {
            threadVersion.remove();
        }
    }

    /**
     * Timed variant of {@link #await()}.
     *
     * @param timeout maximum time to wait
     * @param unit unit of {@code timeout}
     * @return {@code true} if the cycle completed, {@code false} if the timeout elapsed first
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
        threadVersion.remove();
        try {
            return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
        } finally {
            // Also runs on timeout or interruption, where the version would otherwise linger.
            threadVersion.remove();
        }
    }

    /**
     * Timed variant of {@link #await(long)}.
     *
     * @param timeout maximum time to wait
     * @param unit unit of {@code timeout}
     * @param version the latch version to wait on
     * @return {@code true} if the wait succeeded, {@code false} if the timeout elapsed first
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public boolean await(long timeout, TimeUnit unit, long version) throws InterruptedException {
        threadVersion.set(version);
        try {
            return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
        } finally {
            threadVersion.remove();
        }
    }

    /** Decrements the count; waiting threads are released when it reaches zero. */
    public void countDown() {
        sync.releaseShared(1);
    }

    /** Returns the current count. */
    public long getCount() {
        return sync.getCount();
    }

    /** Restores the initial count and moves the latch to the next version. */
    public void reset() {
        sync.reset();
    }

    /** Restores the initial count and sets the latch version to {@code version}. */
    public void reset(long version) {
        sync.reset(version);
    }

    @Override
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }

    /** Starts {@code workers} threads that each sleep for a second and then count down. */
    private static void startWorkers(ReusableCountDownLatch latch, int workers) {
        for (int i = 0; i < workers; i++) {
            new Thread(() -> {
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                latch.countDown();
                System.out.println("Thread finished");
            }).start();
        }
    }

    /** Small demo (the original test case was written with ChatGPT's help). */
    public static void main(String[] args) throws InterruptedException {
        System.out.println("start");
        // Latch with auto-reset enabled.
        ReusableCountDownLatch latch = new ReusableCountDownLatch(3, true);

        startWorkers(latch, 3);
        System.out.println("All threads await");
        latch.await();
        System.out.println("All threads finished");

        // A latch created without auto-reset would need a manual latch.reset() here.
        startWorkers(latch, 3);
        latch.await();
        System.out.println("All threads finished again");
    }
}
标签:count,reset,return,int,重用,sync,CountDownLatch,工具,public From: https://www.cnblogs.com/wsss/p/17385407.html