什么是 RRF?

倒数排名融合 (RRF) 是一种简单而有效的方法,用于合并多个排序列表。它广泛应用于信息检索、搜索引擎和推荐系统,用来融合来自不同排序算法的结果。

RRF 通过基于文档在每个排序列表中的位置分配分数,然后组合这些分数来创建最终排名。公式为:

RRF_score(document) = ∑ 1/(k + rank_i)

其中 k 是一个常数(通常为 60),用于减轻单个列表中高排名的影响。

java实现

import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * 当你有来自多个检索器/召回器的有序结果列表时(例如 BM25、向量召回、规则检索等),RRF 用一个非常简单但稳定的公式把它们融合成一个最终排序
 *
 * @author p_x_c
 */
public class RRF {

    /**
     * 进行 RRF 融合:输入多个有序列表(每个列表第一名、第二名……)
     *
     * @param rankedLists 每个列表是按相关度从高到低的文档ID(或任意唯一键)
     * @param k           平滑参数,常用 60
     * @param <T>         文档ID类型(String/Long/自定义ID)
     * @return 融合后按分数从高到低排序的条目(文档ID -> RRF分数)
     */
    public static <T> LinkedHashMap<T, Double> fuse(List<List<T>> rankedLists, int k) {
        Map<T, Double> score = new HashMap<>();

        for (List<T> list : rankedLists) {
            int rank = 1;
            for (T doc : list) {
                // 公式:1 / (k + rank)
                score.merge(doc, 1.0 / (k + rank), Double::sum);
                rank++;
            }
        }
        // 按分数降序输出稳定有序的 LinkedHashMap
        return score.entrySet()
                    .stream()
                    .sorted((a, b) -> {
                        int c = Double.compare(b.getValue(), a.getValue());
                        if (c != 0) {
                            return c;
                        }
                        // 可选:分数相同按 key 的字典序稳定一下(或不做)
                        return a.getKey()
                                .toString()
                                .compareTo(b.getKey()
                                            .toString());
                    })
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (x, y) -> x, LinkedHashMap::new));
    }

    /**
     * 带列表权重的 RRF:给更可信的列表更高权重
     *
     * @param rankedLists 多个列表
     * @param weights     与 rankedLists 一一对应的权重,长度相同
     * @param k           平滑参数
     */
    public static <T> LinkedHashMap<T, Double> fuseWeighted(List<List<T>> rankedLists, List<Double> weights, int k) {
        if (rankedLists.size() != weights.size()) {
            throw new IllegalArgumentException("rankedLists 和 weights 长度必须一致");
        }
        Map<T, Double> score = new HashMap<>();

        for (int i = 0; i < rankedLists.size(); i++) {
            List<T> list = rankedLists.get(i);
            double w = weights.get(i);
            int rank = 1;
            for (T doc : list) {
                score.merge(doc, w * (1.0 / (k + rank)), Double::sum);
                rank++;
            }
        }
        return score.entrySet()
                    .stream()
                    .sorted((a, b) -> {
                        int c = Double.compare(b.getValue(), a.getValue());
                        if (c != 0) {
                            return c;
                        }
                        return a.getKey()
                                .toString()
                                .compareTo(b.getKey()
                                            .toString());
                    })
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (x, y) -> x, LinkedHashMap::new));
    }

    // 一个简单的演示
    public static void main(String[] args) {
        List<String> bm25 = Arrays.asList("D3", "D1", "D2", "D5");
        List<String> vector = Arrays.asList("D2", "D4", "D1");
        List<String> rules = Arrays.asList("D5", "D2", "D6");

        LinkedHashMap<String, Double> fused = fuse(Arrays.asList(bm25, vector, rules), 60);
        fused.forEach((id, s) -> System.out.printf("%s -> %.6f%n", id, s));

        // 带权版本:更信任向量检索(2.0),BM25(1.0),规则(0.5)
        LinkedHashMap<String, Double> fusedW = fuseWeighted(Arrays.asList(bm25, vector, rules), Arrays.asList(1.0, 2.0, 0.5), 60);
        System.out.println("---- weighted ----");
        fusedW.forEach((id, s) -> System.out.printf("%s -> %.6f%n", id, s));
    }
}

rust版本

rrf 0.1.0