源头

最早是通过京东开源的、基于 CompletableFuture 的并行编排工具库 asyncTool 了解到这个概念。

但关注了很久之后，依然觉得在项目中使用 asyncTool 是一个比较重的操作，所以想自己实现一个轻量级的并行框架。

实现

阅读了 asyncTool 的实现后，发现实际业务中并不需要那么多功能，于是自己实现了一个轻量级的并行框架。

当前框架的已知缺点：应尽量避免在 Promise 中嵌套 Promise，以免发生死锁。

import com.au92.common.util.exception.MustAlertException;
import com.au92.common.util.exception.UnAlertException;
import com.google.common.collect.Lists;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.SneakyThrows;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import lombok.val;

/**
 * <pre>A lightweight fan-out helper built on {@link CompletableFuture}.
 *
 * Takes a context object; tasks added with {@link #combine(Consumer)} run concurrently
 * and receive that context, so they can write results into it. {@link #join()} waits
 * for the tasks submitted so far (later tasks can then read what earlier ones wrote),
 * and {@link #finish()} waits for the rest and rethrows the first recorded failure.
 * The final result can be read from the context.
 * </pre>
 *
 * @author p_x_c
 */
@Slf4j
@SuppressWarnings({"all"})
public class Promise<T> {

    private final static ThreadFactory factory = Thread.ofVirtual()
                                                       .factory();

    // Maximum pool size
    private final static int MAX_POOL_SIZE = Runtime.getRuntime()
                                                    .availableProcessors() * 2;

    // Unbounded submission can cause back-to-back full GCs or even an OOM exit, so concurrency is capped.
    // SynchronousQueue + CallerRunsPolicy: once all MAX_POOL_SIZE threads are busy, the submitting
    // thread runs the task itself.
    private final static ExecutorService virtualThreadExecutor = new ThreadPoolExecutor(MAX_POOL_SIZE, MAX_POOL_SIZE, 10L, TimeUnit.MINUTES, new SynchronousQueue<>(), factory, new ThreadPoolExecutor.CallerRunsPolicy());

    // Futures submitted since the last join()/finish(); guarded by its own monitor.
    private final List<CompletableFuture<?>> futureList = Collections.synchronizedList(Lists.newLinkedList());

    /**
     * Failures recorded by tasks and by {@link #join()}, in recording order.
     * <p>
     * Synchronized because worker threads append to it concurrently; callers that iterate the
     * returned list must hold its monitor ({@code synchronized (list) { ... }}).
     */
    @Getter
    private final List<Exception> exceptions = Collections.synchronizedList(Lists.newArrayList());

    /**
     * Context shared with every task.
     */
    @Setter
    @Getter
    private T context;

    /**
     * Whether execution has been stopped because of a failure.
     * <p>
     * volatile: written by worker threads, read by the thread driving the Promise.
     */
    @Setter
    private volatile boolean broken = false;

    /**
     * Per-future wait timeout used by {@link #join()}, in seconds.
     */
    @Setter
    private int timeout = 10;

    /**
     * Creates a Promise around the given context.
     *
     * @param context object handed to every task added via {@link #combine(Consumer)}
     * @return a new Promise holding {@code context}
     */
    public static <T> Promise<T> of(T context) {
        val promise = new Promise<T>();
        promise.setContext(context);
        return promise;
    }

    /**
     * Waits (without a timeout) for every outstanding future, then rethrows the first recorded
     * failure if the Promise is broken.
     * <p>
     * The failure is logged at warn level for {@link UnAlertException} / {@link MustAlertException}
     * and at error level otherwise, before being rethrown unchanged (via {@code @SneakyThrows}).
     */
    @SneakyThrows
    public void finish() {
        // Wait outside the list's monitor so a running task that calls combine() on this same
        // Promise is not blocked on that monitor while we wait for it.
        for (CompletableFuture<?> future : drainFutures()) {
            future.join();
        }
        if (broken && !exceptions.isEmpty()) {
            val exception = exceptions.get(0);
            val logMsg = "并发执行框架 error:{}";
            if (exception instanceof UnAlertException || exception instanceof MustAlertException) {
                log.warn(logMsg, context, exception);
            } else {
                log.error(logMsg, context, exception);
            }
            throw exception;
        }
    }

    /**
     * Waits for every task submitted so far, each for at most {@link #timeout} seconds.
     * <p>
     * A timeout, interruption or task failure is recorded and marks the Promise broken. Once the
     * Promise is broken, every remaining future is cancelled instead of awaited. Note that
     * {@link CompletableFuture#cancel(boolean)} does not interrupt a task that is already running.
     *
     * @return this Promise, for chaining
     */
    public Promise<T> join() {
        for (CompletableFuture<?> future : drainFutures()) {
            if (future == null) {
                continue;
            }
            if (broken) {
                future.cancel(true);
                continue;
            }
            try {
                if (!future.isDone()) {
                    future.get(timeout, TimeUnit.SECONDS);
                }
            } catch (InterruptedException e) {
                // Restore the flag so callers can still observe the interruption.
                Thread.currentThread().interrupt();
                recordFailure(e);
            } catch (Exception e) {
                // Timed out, failed, or cancelled: we are giving up on this future.
                future.cancel(true);
                recordFailure(e);
            }
        }
        return this;
    }

    /**
     * Submits a task (tasks run concurrently, in no particular order).
     * <p>
     * If ordering is needed, call {@link #join()} before the next {@link #combine(Consumer)}.
     * The context is captured at submission time. Any exception thrown by {@code consumer} is
     * recorded and marks the Promise broken. If the Promise is already broken, the recorded
     * exceptions are logged and the task is not submitted.
     *
     * @param consumer task to run against the context
     * @return this Promise, for chaining
     */
    public Promise<T> combine(Consumer<T> consumer) {
        if (broken) {
            log.error("error:{}", exceptions);
            return this;
        }
        T c = getContext();
        this.futureList.add(CompletableFuture.runAsync(
                new ContextAwareRunnable(() -> {
                    try {
                        consumer.accept(c);
                    } catch (Exception ex) {
                        recordFailure(ex);
                    }
                }),
                virtualThreadExecutor
        ));
        return this;
    }

    /**
     * Atomically takes a snapshot of the pending futures and clears the list.
     */
    private List<CompletableFuture<?>> drainFutures() {
        // Collections.synchronizedList uses the wrapper itself as its mutex.
        synchronized (futureList) {
            List<CompletableFuture<?>> pending = Lists.newArrayList(futureList);
            futureList.clear();
            return pending;
        }
    }

    /**
     * Records a failure and marks the Promise broken.
     */
    private void recordFailure(Exception e) {
        // Add first, then set the volatile flag, so any thread that sees broken == true also sees the exception.
        exceptions.add(e);
        broken = true;
    }

    /**
     * Runnable wrapper that carries the submitting thread's ThreadLocal values onto the worker thread.
     */
    private static class ContextAwareRunnable implements Runnable {
        private final Runnable runnable;

        /**
         * ThreadLocal value to propagate. Anything else read through ParameterThreadLocal must also be captured here.
         */
        private final Long uid;

        /**
         * Thread that created this wrapper. Under CallerRunsPolicy the task may run on it directly.
         */
        private final Thread origin;

        public ContextAwareRunnable(Runnable runnable) {
            this.runnable = runnable;
            // Copy the submitting thread's ThreadLocal value into the wrapper.
            this.uid = ParameterThreadLocal.UID.get();
            this.origin = Thread.currentThread();
        }

        @Override
        public void run() {
            boolean callerRuns = Thread.currentThread() == origin;
            Long previous = ParameterThreadLocal.UID.get();
            ParameterThreadLocal.UID.set(uid);
            try {
                runnable.run();
            } finally {
                if (callerRuns) {
                    // Running on the submitting thread: restore its own value instead of wiping it.
                    if (previous == null) {
                        ParameterThreadLocal.UID.remove();
                    } else {
                        ParameterThreadLocal.UID.set(previous);
                    }
                } else {
                    ParameterThreadLocal.removeAll();
                }
            }
        }
    }

    /**
     * Simple name/result holder that can be used as a Promise context (chainable setters).
     */
    @Data
    @Accessors(chain = true)
    public static class Context<R> {
        private String name;
        private R result;
    }
}

/**
 * Static holder for per-request values that must be carried along when work moves to another thread.
 *
 * @author p_x_c
 */
public class ParameterThreadLocal {

    /**
     * ID of the current user.
     * <p>
     * An {@link InheritableThreadLocal}: a thread created while a value is set starts with a copy of it.
     */
    public static final ThreadLocal<Long> UID = new InheritableThreadLocal<Long>();

    /**
     * Clears, for the calling thread, every value held by this class.
     */
    public static void removeAll() {
        UID.remove();
    }
}

使用示例

public static void main(String[] args) {
    val promise = Promise.of(new Context<String>().setName("ddd")
                                                  .setResult("llll"));
    var str = new Context<String>();
    try {
        promise.combine(context -> {
                   // The lambda parameter is already a Context<String>; no cast is needed.
                   System.out.println(context.getName());
                   try {
                       TimeUnit.SECONDS.sleep(3L);
                   } catch (InterruptedException e) {
                       // Restore the interrupt flag instead of swallowing it.
                       Thread.currentThread().interrupt();
                       log.error("Interrupted", e);
                   }
                   System.out.println("name: " + promise.getContext()
                                                        .getName());
                   str.setName("b");
               })
               .combine(x -> {
                   str.setName("c");
                   System.out.println("ljlsjlf");
                   log.info("xxxxxx");
                   // Deliberate failure: demonstrates that finish() rethrows it.
                   throw new RuntimeException();
               })
               .combine(context -> System.out.println("name2: " + promise.getContext()
                                                                         .getName()))
               .combine(context -> str.setName("d"))
               .join()
               .finish();
    } catch (Exception e) {
        // finish() rethrows the first recorded failure. Without this catch the demo would
        // abort here and the code below would never run.
        log.warn("Promise failed", e);
    }
    if (promise.getExceptions()
               .isEmpty()) {
        System.out.println("ok");
    }
    try {
        // join() cancels but does not interrupt running tasks; give them time to finish.
        TimeUnit.SECONDS.sleep(3L);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        log.error("Interrupted", e);
    }
    System.out.println(str.getName());
}