背景

因为我自己有一个tonic的微服务项目,想要在这个项目中加入分布式追踪功能。但又不想引入更多的第三方系统,所以就做了一个简单的分布式追踪系统。

设计目标

只采集出现错误的请求,不采集正常的请求,也就是后采样(尾部采样,tail-based sampling):先把一次请求的全部日志缓存起来,等请求结束后再决定是否保留。

上代码

依赖

[dependencies]
# Kafka client; the consumer below uses its async `StreamConsumer`.
rdkafka = { version = "0.36", features = ["tokio"] }
# Async runtime: task spawning and the timer used by the cleanup loop.
tokio = { version = "1", features = ["full"] }
# Derive macros used to (de)serialize `LogEvent`.
serde = { version = "1", features = ["derive"] }
# JSON decoding of the Kafka message payloads.
serde_json = "1"
# Concurrent map that holds the per-trace buffers.
dashmap = "5"

数据结构

use serde::{Deserialize, Serialize};

/// A single log record, decoded from the JSON payload of a Kafka message.
///
/// `Debug` is derived so that events can be printed with `{:?}` while
/// diagnosing the consumer.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct LogEvent {
    /// Identifier shared by every log record that belongs to the same trace.
    trace_id: String,
    /// Event time. NOTE(review): presumably epoch milliseconds — confirm with the producer.
    timestamp: i64,
    /// Severity of the record; the value `"ERROR"` marks the whole trace as failed.
    level: String,
    /// Log text; a message containing `"Exception"` also marks the trace as failed.
    message: String,
    /// `true` on the last record of a trace, which triggers an immediate flush.
    end: bool,
}

TraceBuffer

/// In-memory accumulator for the log records of one trace.
struct TraceBuffer {
    /// Records received so far, in arrival order.
    logs: Vec<LogEvent>,
    /// Value of `now()` at creation or at the most recent `append`.
    last_update: i64,
    /// Becomes `true` once any appended record looks like an error.
    has_error: bool,
}

impl TraceBuffer {
    /// Creates an empty buffer whose idle clock starts at the current time.
    fn new() -> Self {
        let created_at = now();
        TraceBuffer {
            has_error: false,
            last_update: created_at,
            logs: Vec::new(),
        }
    }

    /// Stores `log` and refreshes the idle clock.
    ///
    /// The buffer is marked as failed when the record's level is `"ERROR"`
    /// or its message contains `"Exception"`; the mark is never cleared.
    fn append(&mut self, log: LogEvent) {
        let looks_like_error = log.level == "ERROR" || log.message.contains("Exception");
        self.has_error |= looks_like_error;
        self.logs.push(log);
        self.last_update = now();
    }

    /// Returns `true` when the buffer has been idle for strictly more than
    /// `timeout_ms`, measured from `last_update` up to the supplied `now`.
    fn should_flush(&self, now: i64, timeout_ms: i64) -> bool {
        let idle_ms = now - self.last_update;
        idle_ms > timeout_ms
    }
}

/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Only the standard library is used, so no date/time crate has to be added
/// to the dependency list. A clock set before the epoch yields a negative
/// value, and out-of-range values saturate instead of wrapping.
fn now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX),
        // The system clock is set before 1970: report the distance as negative.
        Err(before_epoch) => i64::try_from(before_epoch.duration().as_millis())
            .map(|ms| -ms)
            .unwrap_or(i64::MIN),
    }
}

Consumer(核心)

use rdkafka::consumer::{Consumer, StreamConsumer};
use rdkafka::message::Message;
use dashmap::DashMap;
use std::sync::Arc;
use tokio::time::{sleep, Duration};

type TraceMap = Arc<DashMap<String, TraceBuffer>>;

async fn run_consumer() {
    let consumer: StreamConsumer = rdkafka::config::ClientConfig::new()
        .set("group.id", "log-consumer")
        .set("bootstrap.servers", "localhost:9092")
        .create()
        .unwrap();

    consumer.subscribe(&["logs-topic"]).unwrap();

    let trace_map: TraceMap = Arc::new(DashMap::new());
    start_cleanup_task(trace_map.clone()).await;
    while let Some(msg) = consumer.recv().await.ok() {
        if let Some(payload) = msg.payload() {
            let log: LogEvent = serde_json::from_slice(payload).unwrap();
            process(log, &trace_map);
        }
    }
}
/// Spawns the background task that expires idle traces.
///
/// Every `CLEAN_INTERVAL_MS` the task collects the ids of all buffers that
/// have been idle for more than `TRACE_TIMEOUT_MS` and flushes each of them.
/// This function only spawns the task and returns immediately; the spawned
/// loop itself never finishes.
async fn start_cleanup_task(trace_map: TraceMap) {
    /// Pause between two scans of the map, in milliseconds.
    const CLEAN_INTERVAL_MS: u64 = 1_000;
    /// Idle time after which a trace is treated as finished, in milliseconds.
    const TRACE_TIMEOUT_MS: i64 = 5_000;

    tokio::spawn(async move {
        loop {
            sleep(Duration::from_millis(CLEAN_INTERVAL_MS)).await;

            let current = now();
            // Collect the ids first: the iterator's guards must be gone before
            // `flush` removes entries from the same map.
            let expired: Vec<String> = trace_map
                .iter()
                .filter(|entry| entry.value().should_flush(current, TRACE_TIMEOUT_MS))
                .map(|entry| entry.key().clone())
                .collect();

            for trace_id in expired {
                flush(&trace_id, &trace_map);
            }
        }
    });
}

处理逻辑

/// Appends `log` to the buffer of its trace, creating the buffer on first use,
/// and flushes the trace right away when the record carries the end marker.
///
/// The map guard returned by `entry()` keeps its shard locked, so it is
/// confined to an inner scope and released before `flush` removes the same
/// key; removing while the guard is still alive would deadlock.
fn process(log: LogEvent, trace_map: &TraceMap) {
    // Only an end record needs its id after `log` has been moved into the buffer.
    let finished_trace = if log.end {
        Some(log.trace_id.clone())
    } else {
        None
    };

    {
        let mut buffer = trace_map
            .entry(log.trace_id.clone())
            .or_insert_with(TraceBuffer::new);
        buffer.append(log);
    } // guard dropped here, shard unlocked

    if let Some(trace_id) = finished_trace {
        flush(&trace_id, trace_map);
    }
}

flush + 条件落盘

/// Removes the buffer of `trace_id` from the map and persists its records
/// only when the trace was marked as failed.
///
/// An unknown `trace_id` is a no-op; a buffer without errors is discarded.
fn flush(trace_id: &str, trace_map: &TraceMap) {
    match trace_map.remove(trace_id) {
        Some((_, buffer)) if buffer.has_error => persist(buffer.logs),
        _ => {}
    }
}

/// Sink for the records of a failed trace.
///
/// Currently it only prints how many records were handed over.
fn persist(logs: Vec<LogEvent>) {
    let count = logs.len();
    println!("persist {} logs", count);
}

设计思路和缺陷

思路

  1. 通过Kafka消费日志数据,使用DashMap存储每个trace_id对应的日志缓冲区(生产者以trace_id作为Kafka消息的key,保证同一个trace_id的日志落在同一个分区,从而被同一个消费者处理)。
  2. 每当有新的日志事件到来时,更新对应的TraceBuffer,并检查是否有错误日志。
  3. 定期检查TraceBuffer:如果超过一定时间没有更新,或者收到了标记为结束的日志事件,就把该TraceBuffer从内存中移除,并根据是否包含错误日志决定是否持久化。

缺陷

  1. 内存占用:如果有大量的trace_id且没有及时结束,可能会导致内存占用过高。
  2. 数据丢失:如果系统崩溃或者重启,未持久化的日志数据将会丢失。
  3. 内存爆炸:Map中堆积大量条目(或者单个trace的日志条数过多)时可能会导致内存爆炸,尤其是在高并发场景下。

解决方案(AI给的,但我没用,因为我的数据量很小)

// Upper bound on the number of traces kept in memory at the same time.
const MAX_TRACE: usize = 100_000;

// Once the map grows past the bound, flush a batch of up to 1000 traces.
// The batch is simply whatever the map iterator yields first.
if trace_map.len() > MAX_TRACE {
    // Gather the ids before flushing so that no iterator guard is alive
    // while entries are being removed.
    let victims = trace_map
        .iter()
        .map(|entry| entry.key().clone())
        .take(1000)
        .collect::<Vec<String>>();

    victims
        .iter()
        .for_each(|trace_id| flush(trace_id, &trace_map));
}
// Upper bound on the number of records buffered for a single trace.
const MAX_LOGS_PER_TRACE: usize = 1000;

// NOTE(review): presumably meant to run inside `process`, right after the
// record has been appended — confirm the placement.
if entry.logs.len() > MAX_LOGS_PER_TRACE {
    // `entry` is a map guard that keeps its shard locked; release it before
    // `flush` removes the same key, otherwise the removal deadlocks.
    drop(entry);
    flush(&log.trace_id, trace_map);
}

对应的Java版本

/**
 * One log record of a trace; the value type of the Kafka records produced
 * and consumed below. Accessors are generated by Lombok's {@code @Data}.
 */
@Data
public class LogEvent {
    /** Identifier shared by all records of the same trace; also used as the Kafka record key. */
    private String traceId;
    /** Event time. NOTE(review): presumably epoch milliseconds — confirm with the producer. */
    private long timestamp;
    private String level; // INFO/WARN/ERROR
    private String message;
    private boolean end; // end-of-trace marker
}
/**
 * Publishes log records to the {@code logs-topic} Kafka topic.
 *
 * <p>Each record is keyed by its trace id, so all records of one trace go to
 * the same partition. The class is {@link AutoCloseable} so that the
 * underlying producer can be closed, which also flushes buffered records.
 */
public class LogProducer implements AutoCloseable {

    private final KafkaProducer<String, LogEvent> producer;

    /**
     * @param props Kafka producer configuration, passed through unchanged
     */
    public LogProducer(Properties props) {
        this.producer = new KafkaProducer<>(props);
    }

    /**
     * Sends {@code log} asynchronously. A failed send is reported on
     * {@code System.err} instead of being dropped silently.
     *
     * @param log the record to publish; its trace id becomes the record key
     */
    public void send(LogEvent log) {
        ProducerRecord<String, LogEvent> record =
                new ProducerRecord<>("logs-topic", log.getTraceId(), log);
        producer.send(record, (metadata, exception) -> {
            if (exception != null) {
                System.err.println(
                        "Failed to send log for trace " + log.getTraceId() + ": " + exception);
            }
        });
    }

    /** Closes the underlying Kafka producer. */
    @Override
    public void close() {
        producer.close();
    }
}
/**
 * In-memory accumulator for the log records of one trace.
 */
public class TraceBuffer {

    private final List<LogEvent> logs = new ArrayList<>();
    // Start the idle clock at creation time; with the default of 0 a buffer
    // that has not received a record yet would already look expired.
    private long lastUpdateTime = System.currentTimeMillis();
    private boolean hasError = false;

    /**
     * Stores {@code log}, refreshes the idle clock and marks the trace as
     * failed when the record looks like an error.
     *
     * @param log the record to buffer
     */
    public void append(LogEvent log) {
        logs.add(log);
        lastUpdateTime = System.currentTimeMillis();

        if (looksLikeError(log)) {
            hasError = true;
        }
    }

    /** A record counts as an error when its level is ERROR or its message mentions an exception. */
    private static boolean looksLikeError(LogEvent log) {
        if ("ERROR".equals(log.getLevel())) {
            return true;
        }
        // The message may be absent; treat that as "no exception mentioned".
        String message = log.getMessage();
        return message != null && message.contains("Exception");
    }

    /**
     * @param now       current time in milliseconds
     * @param timeoutMs maximum allowed idle time in milliseconds
     * @return {@code true} when the buffer has been idle for strictly more than {@code timeoutMs}
     */
    public boolean shouldFlush(long now, long timeoutMs) {
        return now - lastUpdateTime > timeoutMs;
    }

    /** @return {@code true} when at least one buffered record looked like an error */
    public boolean shouldPersist() {
        return hasError;
    }

    /** @return the buffered records in arrival order (the live internal list, not a copy) */
    public List<LogEvent> getLogs() {
        return logs;
    }
}
/**
 * Consumes log records from {@code logs-topic}, groups them per trace in
 * memory and persists only the traces that contain an error.
 */
public class LogConsumer {

    /** Idle time after which a trace is flushed: 5 seconds. */
    private static final long TRACE_TIMEOUT = 5000;
    /** Number of buffered traces above which a batch is evicted. */
    private static final int MAX_TRACE = 10000;

    private final KafkaConsumer<String, LogEvent> consumer;
    private final Map<String, TraceBuffer> traceMap = new ConcurrentHashMap<>();

    /**
     * @param props Kafka consumer configuration, passed through unchanged
     */
    public LogConsumer(Properties props) {
        this.consumer = new KafkaConsumer<>(props);
        consumer.subscribe(Collections.singletonList("logs-topic"));
    }

    /**
     * Polls forever, processing every record and expiring idle traces after
     * each poll. The consumer is closed when the loop is left by an exception.
     */
    public void run() {
        try {
            while (true) {
                ConsumerRecords<String, LogEvent> records = consumer.poll(Duration.ofMillis(100));

                for (ConsumerRecord<String, LogEvent> record : records) {
                    LogEvent log = record.value();
                    // A record without a value or without a trace id cannot be
                    // assigned to a buffer (the map rejects null keys).
                    if (log == null || log.getTraceId() == null) {
                        continue;
                    }
                    process(log);
                }

                cleanup();
            }
        } finally {
            consumer.close();
        }
    }

    /** Buffers {@code log} and flushes its trace immediately on the end marker. */
    private void process(LogEvent log) {
        TraceBuffer buffer = traceMap.computeIfAbsent(log.getTraceId(), k -> new TraceBuffer());

        buffer.append(log);

        if (log.isEnd()) {
            flush(log.getTraceId());
        }
    }

    /** Removes the trace from the map and persists it when it contains an error. */
    private void flush(String traceId) {
        TraceBuffer buffer = traceMap.remove(traceId);
        if (buffer == null) return;

        persistIfNeeded(buffer);
    }

    /** Persists the buffered records only when the trace was marked as failed. */
    private void persistIfNeeded(TraceBuffer buffer) {
        if (buffer.shouldPersist()) {
            persist(buffer.getLogs());
        }
    }

    /** Flushes traces that have been idle for too long, then enforces the size bound. */
    private void cleanup() {
        long now = System.currentTimeMillis();

        Iterator<Map.Entry<String, TraceBuffer>> it = traceMap.entrySet().iterator();
        while (it.hasNext()) {
            TraceBuffer buffer = it.next().getValue();

            if (buffer.shouldFlush(now, TRACE_TIMEOUT)) {
                // Remove exactly once, through the iterator, then persist.
                it.remove();
                persistIfNeeded(buffer);
            }
        }

        // Guard against unbounded memory growth.
        if (traceMap.size() > MAX_TRACE) {
            traceMap.keySet().stream().limit(1000).forEach(this::flush);
        }
    }

    /** Sink for the records of a failed trace. */
    private void persist(List<LogEvent> logs) {
        // Could write to ES / ClickHouse / a database here.
        System.out.println("Persist logs: " + logs.size());
    }
}