Reciprocal Rank Fusion (RRF)简单实现
文章目录
什么是 RRF?
倒数排名融合 (RRF) 是一种简单而有效的方法,用于合并多个排序列表。它广泛应用于信息检索、搜索引擎和推荐系统,用来融合来自不同排序算法的结果。
RRF 通过基于文档在每个排序列表中的位置分配分数,然后组合这些分数来创建最终排名。公式为:
RRF_score(document) = ∑ 1/(k + rank_i)
其中 k 是一个常数(通常为 60),用于减轻单个列表中高排名的影响。
java实现
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 当你有来自多个检索器/召回器的有序结果列表时(例如 BM25、向量召回、规则检索等),RRF 用一个非常简单但稳定的公式把它们融合成一个最终排序
*
* @author p_x_c
*/
public class RRF {
/**
* 进行 RRF 融合:输入多个有序列表(每个列表第一名、第二名……)
*
* @param rankedLists 每个列表是按相关度从高到低的文档ID(或任意唯一键)
* @param k 平滑参数,常用 60
* @param <T> 文档ID类型(String/Long/自定义ID)
* @return 融合后按分数从高到低排序的条目(文档ID -> RRF分数)
*/
public static <T> LinkedHashMap<T, Double> fuse(List<List<T>> rankedLists, int k) {
Map<T, Double> score = new HashMap<>();
for (List<T> list : rankedLists) {
int rank = 1;
for (T doc : list) {
// 公式:1 / (k + rank)
score.merge(doc, 1.0 / (k + rank), Double::sum);
rank++;
}
}
// 按分数降序输出稳定有序的 LinkedHashMap
return score.entrySet()
.stream()
.sorted((a, b) -> {
int c = Double.compare(b.getValue(), a.getValue());
if (c != 0) {
return c;
}
// 可选:分数相同按 key 的字典序稳定一下(或不做)
return a.getKey()
.toString()
.compareTo(b.getKey()
.toString());
})
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (x, y) -> x, LinkedHashMap::new));
}
/**
* 带列表权重的 RRF:给更可信的列表更高权重
*
* @param rankedLists 多个列表
* @param weights 与 rankedLists 一一对应的权重,长度相同
* @param k 平滑参数
*/
public static <T> LinkedHashMap<T, Double> fuseWeighted(List<List<T>> rankedLists, List<Double> weights, int k) {
if (rankedLists.size() != weights.size()) {
throw new IllegalArgumentException("rankedLists 和 weights 长度必须一致");
}
Map<T, Double> score = new HashMap<>();
for (int i = 0; i < rankedLists.size(); i++) {
List<T> list = rankedLists.get(i);
double w = weights.get(i);
int rank = 1;
for (T doc : list) {
score.merge(doc, w * (1.0 / (k + rank)), Double::sum);
rank++;
}
}
return score.entrySet()
.stream()
.sorted((a, b) -> {
int c = Double.compare(b.getValue(), a.getValue());
if (c != 0) {
return c;
}
return a.getKey()
.toString()
.compareTo(b.getKey()
.toString());
})
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (x, y) -> x, LinkedHashMap::new));
}
// 一个简单的演示
public static void main(String[] args) {
List<String> bm25 = Arrays.asList("D3", "D1", "D2", "D5");
List<String> vector = Arrays.asList("D2", "D4", "D1");
List<String> rules = Arrays.asList("D5", "D2", "D6");
LinkedHashMap<String, Double> fused = fuse(Arrays.asList(bm25, vector, rules), 60);
fused.forEach((id, s) -> System.out.printf("%s -> %.6f%n", id, s));
// 带权版本:更信任向量检索(2.0),BM25(1.0),规则(0.5)
LinkedHashMap<String, Double> fusedW = fuseWeighted(Arrays.asList(bm25, vector, rules), Arrays.asList(1.0, 2.0, 0.5), 60);
System.out.println("---- weighted ----");
fusedW.forEach((id, s) -> System.out.printf("%s -> %.6f%n", id, s));
}
}
rust版本
相似文章
文章作者 pengxiaochao
上次更新 2025-09-08
许可协议 不允许任何形式转载。