从头训练中文错别字纠错模型 (二)
文章目录
代码解释
训练数据处理(dataset.py)
主要流程就是从文本文件(训练数据)中逐行加载数据(每行包含一个错误文本和一个正确文本,使用制表符分隔)
然后通过传入的tokenizer将这些文本转换为模型可以接受的张量格式(对应代码中56和66行)返回。
注意:填充-100的逻辑主要是为了在计算损失时忽略填充位置的标签,因为这些位置不应该对模型的训练产生影响。PyTorch 的交叉熵损失函数会自动忽略值为 -100 的标签,这样就不会对模型的训练产生干扰。
labels[labels == self.tokenizer.pad_token_id] = -100也可以写作:
for i in range(len(labels)):
if labels[i] == pad_token_id:
labels[i] = -100
# 中文拼写纠错数据集类
# 用于加载和预处理训练数据,将文本转换为模型可接受的张量格式
import torch # 导入 PyTorch 框架
from torch.utils.data import Dataset # 导入 Dataset 基类,用于创建自定义数据集
class CSCDataset(Dataset):
"""
中文拼写纠错数据集类
继承自 PyTorch 的 Dataset 基类,用于加载 "错误文本\t正确文本" 格式的训练数据
"""
def __init__(self, file_path, tokenizer, max_len=128):
"""
初始化数据集
参数:
file_path: 训练数据文件路径,每行格式为 "错误文本\t正确文本"
tokenizer: BERT 分词器实例
max_len: 序列最大长度,默认 128
"""
self.samples = [] # 存储所有样本的列表
self.tokenizer = tokenizer # 保存分词器引用
self.max_len = max_len # 保存最大序列长度
# 读取训练数据文件
with open(file_path, encoding="utf-8") as f:
for line in f: # 逐行读取
# 按制表符分割,得到源文本(错误)和目标文本(正确)
src, tgt = line.strip().split("\t")
# 将文本转换为字符列表并添加到样本列表(中文按字符分割)
self.samples.append((list(src), list(tgt)))
def __len__(self):
"""
返回数据集中的样本总数
PyTorch DataLoader 需要此方法来确定数据集大小
"""
return len(self.samples)
def __getitem__(self, idx):
"""
根据索引获取单个样本
参数:
idx: 样本索引
返回:
包含 input_ids、attention_mask 和 labels 的字典
"""
# 获取指定索引的源文本和目标文本
src, tgt = self.samples[idx]
# 对源文本(包含错误的文本)进行编码
encoding = self.tokenizer(
src, # 输入的字符列表
is_split_into_words=True, # 表示输入已按词/字符分割
truncation=True, # 超过最大长度时截断
padding="max_length", # 填充到最大长度
max_length=self.max_len, # 最大序列长度
return_tensors="pt" # 返回 PyTorch 张量
)
# 对目标文本(正确的文本)进行编码,获取标签
labels = self.tokenizer(
tgt, # 目标字符列表
is_split_into_words=True, # 表示输入已按词/字符分割
truncation=True, # 超过最大长度时截断
padding="max_length", # 填充到最大长度
max_length=self.max_len, # 最大序列长度
return_tensors="pt" # 返回 PyTorch 张量
)["input_ids"] # 只需要 input_ids 作为标签
# 将填充位置的标签设为 -100,这样计算损失时会忽略这些位置
# PyTorch 的交叉熵损失函数会自动忽略值为 -100 的标签
labels[labels == self.tokenizer.pad_token_id] = -100
# 返回训练所需的三个张量
return {
"input_ids": encoding["input_ids"].squeeze(0), # 输入 token ID,去除批次维度
"attention_mask": encoding["attention_mask"].squeeze(0), # 注意力掩码,去除批次维度
"labels": labels.squeeze(0) # 目标标签,去除批次维度
}
训练模型(train.py)
15行:tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)用来加载预训练模型“bert-base-chinese”的分词器。这个分词器会将输入文本转换为模型可以理解的 token ID。
参数规模:
12 layers
768 hidden
110M 参数
特点:
字粒度 tokenization(非常适合中文纠错)
不需要 word segmentation
16行:dataset = CSCDataset("train.txt", tokenizer) 调用dataset.py实例化你自定义数据集,将文本转换为模型可接受的张量格式。
17行:loader = DataLoader(dataset, batch_size=16, shuffle=True) 创建一个数据加载器,指定批次大小为16,并启用随机打乱数据,防止模型过拟合。
18行:model = BertForMaskedLM.from_pretrained(MODEL_NAME).to(device) 加载预训练的 BERT 模型,并将其移动到指定的设备(CPU 或 GPU)。
19行:optimizer = AdamW(model.parameters(), lr=5e-5) 使用 AdamW 优化器,设置学习率为 5e-5,这是微调预训练模型的常用学习率(常见1e-5 ~ 5e-5)。
21行:model.train() 将模型设置为训练模式,启用 dropout 和其他训练特定的行为。否则就变成eval推理模式了。
22行:for epoch in range(5): 开始训练循环,设置训练轮数为5。
26行:for batch in tqdm(loader, desc=f"Epoch {epoch+1}"): 使用 tqdm 显示训练进度条,迭代数据加载器中的每个批次。
27行:optimizer.zero_grad()清空梯度,避免梯度累加。
32-37行:前向传播计算损失。因为我们提供了标签,BERT 模型会自动计算交叉熵损失。
loss = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
).loss
39行:loss.backward() 反向传播计算梯度。
40行:optimizer.step() 更新模型参数。
41行:total_loss += loss.item() 统计 loss。
45行:avg_loss = total_loss / num_batches 计算平均 loss。
49行:if avg_loss < 0.5: 根据平均 loss 判断是否提前停止训练(早停),避免过拟合。
53、54行:model.save_pretrained("csc_model") 和 tokenizer.save_pretrained("csc_model") 保存训练好的模型和分词器到指定目录,方便后续加载使用。
代码使用
相似文章
文章作者 pengxiaochao
上次更新 2026-02-13
许可协议 不允许任何形式转载。