Java 并行框架
文章目录
源头
最早了解到这个概念是通过京东开源的CompletableFuture工具库 asyncTool。
但关注了很久之后,我依然觉得使用asyncTool是一个非常重的操作,所以就想着自己实现一个轻量级的并行框架。
实现
我阅读了asyncTool的实现,发现实际业务场景中并不需要那么多功能,于是自己实现了一个轻量级的并行框架。
当前并行框架的已知缺点:应尽量避免在一个Promise中嵌套另一个Promise,否则可能发生死锁。
import com.au92.common.util.exception.MustAlertException;
import com.au92.common.util.exception.UnAlertException;
import com.google.common.collect.Lists;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.SneakyThrows;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
/**
 * A lightweight helper for running several tasks concurrently against one shared context.
 *
 * <pre>
 * 1. Wrap a context object with {@link #of(Object)}.
 * 2. Register tasks with {@link #combine(Consumer)}; each task receives the context and is
 *    submitted asynchronously to a shared, bounded executor backed by virtual threads.
 * 3. Call {@link #join()} to wait for the tasks registered so far; tasks registered after a
 *    join() can therefore read whatever earlier tasks wrote into the context.
 * 4. Call {@link #finish()} to wait for any remaining tasks and rethrow the first recorded
 *    failure. Results are read back from the context.
 * </pre>
 *
 * <p>Known limitation: avoid nesting one Promise inside another, as it may deadlock.
 *
 * @author p_x_c
 */
@Slf4j
@SuppressWarnings({"all"})
public class Promise<T> {
    // Creates the virtual threads used by the shared executor.
    private final static ThreadFactory factory = Thread.ofVirtual()
            .factory();
    // Maximum number of executor threads (and therefore of tasks running at once).
    private final static int MAX_POOL_SIZE = Runtime.getRuntime()
            .availableProcessors() * 2;
    // Unbounded submission could cause continuous full GCs or even an OOM exit, so the executor is
    // bounded: SynchronousQueue hands each task straight to a thread, and once MAX_POOL_SIZE threads
    // are busy, CallerRunsPolicy runs the task on the submitting thread instead.
    private final static ExecutorService virtualThreadExecutor = new ThreadPoolExecutor(MAX_POOL_SIZE, MAX_POOL_SIZE, 10L, TimeUnit.MINUTES, new SynchronousQueue<>(), factory, new ThreadPoolExecutor.CallerRunsPolicy());
    /**
     * Futures of the tasks registered since the last {@link #join()} / {@link #finish()}.
     */
    private final List<CompletableFuture<?>> futureList = Collections.synchronizedList(Lists.newLinkedList());
    /**
     * Failures recorded by tasks and by {@link #join()}. Tasks append to it from worker threads, so
     * it is a synchronized list; callers iterating it should synchronize on the list.
     */
    @Getter
    private final List<Exception> exceptions = Collections.synchronizedList(Lists.newArrayList());
    /**
     * The shared context handed to every task.
     */
    @Setter
    @Getter
    private T context;
    /**
     * Set once any task or wait has failed; new tasks are then skipped. Written from worker
     * threads, hence volatile.
     */
    @Setter
    private volatile boolean broken = false;
    /**
     * How long, in seconds, {@link #join()} waits for each individual future.
     */
    @Setter
    private int timeout = 10;

    /**
     * Creates a Promise around the given context.
     *
     * @param context the shared context passed to every task
     * @return a new Promise holding {@code context}
     */
    public static <T> Promise<T> of(T context) {
        val promise = new Promise<T>();
        promise.setContext(context);
        return promise;
    }

    /**
     * Waits (without a timeout) for any tasks still registered, then, if a failure was recorded,
     * logs the first one together with the context and rethrows it.
     * {@link UnAlertException} and {@link MustAlertException} are logged at WARN, anything else
     * at ERROR.
     */
    @SneakyThrows
    public void finish() {
        for (CompletableFuture<?> future : drainFutures()) {
            future.join();
        }
        if (broken && !exceptions.isEmpty()) {
            val exception = exceptions.get(0);
            val logMsg = "并发执行框架 error:{}";
            if (exception instanceof UnAlertException || exception instanceof MustAlertException) {
                log.warn(logMsg, context, exception);
            } else {
                log.error(logMsg, context, exception);
            }
            throw exception;
        }
    }

    /**
     * Waits for every task registered so far, each for at most {@link #timeout} seconds.
     * <p>
     * A failed or timed-out wait is recorded in {@link #exceptions} and marks this Promise broken.
     * Once broken, the remaining futures are cancelled instead of awaited. Note that
     * {@code CompletableFuture.cancel} does not interrupt a task that is already running.
     *
     * @return this Promise, for chaining
     */
    public Promise<T> join() {
        // Work on a snapshot so the list's lock is not held while blocking; tasks may call combine().
        List<CompletableFuture<?>> pending = drainFutures();
        for (int i = 0; i < pending.size(); i++) {
            if (broken) {
                // A failure was already recorded: stop waiting and cancel everything not yet awaited.
                for (CompletableFuture<?> rest : pending.subList(i, pending.size())) {
                    rest.cancel(true);
                }
                break;
            }
            CompletableFuture<?> future = pending.get(i);
            if (future.isDone()) {
                continue;
            }
            try {
                future.get(timeout, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                // Preserve the caller's interrupt status instead of swallowing it.
                Thread.currentThread().interrupt();
                exceptions.add(e);
                broken = true;
            } catch (Exception e) {
                exceptions.add(e);
                broken = true;
            }
        }
        return this;
    }

    /**
     * Registers a task that runs concurrently with the other registered tasks (no ordering).
     * <p>
     * If ordering matters, call {@link #join()} before registering the dependent task with
     * {@link #combine(Consumer)}. If this Promise is already broken, the task is skipped and the
     * recorded failures are logged.
     *
     * @param consumer task that receives the context; any {@link Exception} it throws is recorded
     * @return this Promise, for chaining
     */
    public Promise<T> combine(Consumer<T> consumer) {
        if (broken) {
            log.error("error:{}", exceptions);
            return this;
        }
        T c = getContext();
        this.futureList.add(CompletableFuture.runAsync(
                new ContextAwareRunnable(() -> {
                    try {
                        consumer.accept(c);
                    } catch (Exception ex) {
                        // Record before flagging, so anyone who sees broken == true also sees the failure.
                        exceptions.add(ex);
                        broken = true;
                    }
                }),
                virtualThreadExecutor
        ));
        return this;
    }

    /**
     * Atomically copies the registered futures into a new list and empties {@link #futureList}.
     */
    private List<CompletableFuture<?>> drainFutures() {
        synchronized (futureList) {
            List<CompletableFuture<?>> snapshot = Lists.newArrayList(futureList);
            futureList.clear();
            return snapshot;
        }
    }

    /**
     * Runnable wrapper that carries {@link ParameterThreadLocal} values from the submitting thread
     * to the thread that executes the task.
     */
    private static class ContextAwareRunnable implements Runnable {
        private final Runnable runnable;
        /**
         * UID captured on the submitting thread. Any other ParameterThreadLocal value a task needs
         * must also be captured here and re-applied in {@link #run()}.
         */
        private final Long uid;
        /**
         * Thread that created this task; under CallerRunsPolicy it also executes the task.
         */
        private final Thread submitter;

        public ContextAwareRunnable(Runnable runnable) {
            this.runnable = runnable;
            // Capture the caller's ThreadLocal value so it can be propagated to the worker thread.
            this.uid = ParameterThreadLocal.UID.get();
            this.submitter = Thread.currentThread();
        }

        @Override
        public void run() {
            boolean runningInline = Thread.currentThread() == submitter;
            Long previous = ParameterThreadLocal.UID.get();
            ParameterThreadLocal.UID.set(uid);
            try {
                runnable.run();
            } finally {
                if (runningInline) {
                    // Executed on the caller's own thread (pool saturated): restore its value rather than wiping it.
                    ParameterThreadLocal.UID.set(previous);
                } else {
                    ParameterThreadLocal.removeAll();
                }
            }
        }
    }

    /**
     * Simple chainable holder usable as a Promise context: a name plus a result value.
     *
     * @param <R> type of the result
     */
    @Data
    @Accessors(chain = true)
    public static class Context<R> {
        private String name;
        private R result;
    }
}
/**
 * Holder for per-thread request parameters; {@code Promise} copies its values onto the threads
 * that run tasks.
 *
 * @author p_x_c
 */
public class ParameterThreadLocal {
    /**
     * Current user ID. Being inheritable, a newly created thread starts with its creator's value.
     */
    public static final ThreadLocal<Long> UID = new InheritableThreadLocal<>();

    /**
     * Clears every value this class holds for the current thread.
     */
    public static void removeAll() {
        UID.remove();
    }
}
使用示例
/**
 * Usage example: registers four concurrent tasks against a shared context, then joins and finishes.
 * <p>
 * The second task throws. {@code combine()} records that failure, so {@code finish()} rethrows it,
 * and the statements after {@code finish()} run only when no task fails.
 *
 * @param args unused
 */
public static void main(String[] args) {
    val promise = Promise.of(new Context<String>().setName("ddd")
            .setResult("llll"));
    // Mutated by several tasks concurrently; the final value depends on scheduling.
    var str = new Context<String>();
    promise.combine(context -> {
                System.out.println(context.getName());
                try {
                    TimeUnit.SECONDS.sleep(3L);
                } catch (InterruptedException e) {
                    // Restore the interrupt flag instead of swallowing it.
                    Thread.currentThread().interrupt();
                    log.error("Interrupted", e);
                }
                System.out.println("name: " + promise.getContext()
                        .getName());
                str.setName("b");
            })
            .combine(x -> {
                str.setName("c");
                System.out.println("ljlsjlf");
                log.info("xxxxxx");
                throw new RuntimeException();
            })
            .combine(context -> System.out.println("name2: " + promise.getContext()
                    .getName()))
            .combine(context -> str.setName("d"))
            .join()
            .finish();
    if (promise.getExceptions()
            .isEmpty()) {
        System.out.println("ok");
    }
    try {
        TimeUnit.SECONDS.sleep(3L);
    } catch (InterruptedException e) {
        // Restore the interrupt flag instead of swallowing it.
        Thread.currentThread().interrupt();
        log.error("Interrupted", e);
    }
    System.out.println(str.getName());
}
相似文章
文章作者 pengxiaochao
上次更新 2025-10-08
许可协议 不允许任何形式转载。