代码解释

训练数据处理(dataset.py)

主要流程就是从文本文件(训练数据)中逐行加载数据(每行包含一个错误文本和一个正确文本,使用制表符分隔)

然后通过传入的tokenizer将这些文本转换为模型可以接受的张量格式(对应代码中56和66行)返回。

注意:填充-100的逻辑主要是为了在计算损失时忽略填充位置的标签,因为这些位置不应该对模型的训练产生影响。PyTorch 的交叉熵损失函数会自动忽略值为 -100 的标签,这样就不会对模型的训练产生干扰。

labels[labels == self.tokenizer.pad_token_id] = -100也可以写作:

for i in range(len(labels)):
    if labels[i] == pad_token_id:
        labels[i] = -100
# 中文拼写纠错数据集类
# 用于加载和预处理训练数据,将文本转换为模型可接受的张量格式

import torch  # 导入 PyTorch 框架
from torch.utils.data import Dataset  # 导入 Dataset 基类,用于创建自定义数据集


class CSCDataset(Dataset):
    """
    中文拼写纠错数据集类
    继承自 PyTorch 的 Dataset 基类,用于加载 "错误文本\t正确文本" 格式的训练数据
    """
    
    def __init__(self, file_path, tokenizer, max_len=128):
        """
        初始化数据集
        
        参数:
            file_path: 训练数据文件路径,每行格式为 "错误文本\t正确文本"
            tokenizer: BERT 分词器实例
            max_len: 序列最大长度,默认 128
        """
        self.samples = []  # 存储所有样本的列表
        self.tokenizer = tokenizer  # 保存分词器引用
        self.max_len = max_len  # 保存最大序列长度

        # 读取训练数据文件
        with open(file_path, encoding="utf-8") as f:
            for line in f:  # 逐行读取
                # 按制表符分割,得到源文本(错误)和目标文本(正确)
                src, tgt = line.strip().split("\t")
                # 将文本转换为字符列表并添加到样本列表(中文按字符分割)
                self.samples.append((list(src), list(tgt)))

    def __len__(self):
        """
        返回数据集中的样本总数
        PyTorch DataLoader 需要此方法来确定数据集大小
        """
        return len(self.samples)

    def __getitem__(self, idx):
        """
        根据索引获取单个样本
        
        参数:
            idx: 样本索引
            
        返回:
            包含 input_ids、attention_mask 和 labels 的字典
        """
        # 获取指定索引的源文本和目标文本
        src, tgt = self.samples[idx]

        # 对源文本(包含错误的文本)进行编码
        encoding = self.tokenizer(
            src,  # 输入的字符列表
            is_split_into_words=True,  # 表示输入已按词/字符分割
            truncation=True,  # 超过最大长度时截断
            padding="max_length",  # 填充到最大长度
            max_length=self.max_len,  # 最大序列长度
            return_tensors="pt"  # 返回 PyTorch 张量
        )

        # 对目标文本(正确的文本)进行编码,获取标签
        labels = self.tokenizer(
            tgt,  # 目标字符列表
            is_split_into_words=True,  # 表示输入已按词/字符分割
            truncation=True,  # 超过最大长度时截断
            padding="max_length",  # 填充到最大长度
            max_length=self.max_len,  # 最大序列长度
            return_tensors="pt"  # 返回 PyTorch 张量
        )["input_ids"]  # 只需要 input_ids 作为标签

        # 将填充位置的标签设为 -100,这样计算损失时会忽略这些位置
        # PyTorch 的交叉熵损失函数会自动忽略值为 -100 的标签
        labels[labels == self.tokenizer.pad_token_id] = -100

        # 返回训练所需的三个张量
        return {
            "input_ids": encoding["input_ids"].squeeze(0),  # 输入 token ID,去除批次维度
            "attention_mask": encoding["attention_mask"].squeeze(0),  # 注意力掩码,去除批次维度
            "labels": labels.squeeze(0)  # 目标标签,去除批次维度
        }

训练模型(train.py)

15行:tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)用来加载预训练模型“bert-base-chinese”的分词器。这个分词器会将输入文本转换为模型可以理解的 token ID。 参数规模:

12 layers
768 hidden
110M 参数

特点:

字粒度 tokenization(非常适合中文纠错)
不需要 word segmentation

16行:dataset = CSCDataset("train.txt", tokenizer) 调用dataset.py实例化你自定义数据集,将文本转换为模型可接受的张量格式。

17行:loader = DataLoader(dataset, batch_size=16, shuffle=True) 创建一个数据加载器,指定批次大小为16,并启用随机打乱数据,防止模型过拟合。

18行:model = BertForMaskedLM.from_pretrained(MODEL_NAME).to(device) 加载预训练的 BERT 模型,并将其移动到指定的设备(CPU 或 GPU)。

19行:optimizer = AdamW(model.parameters(), lr=5e-5) 使用 AdamW 优化器,设置学习率为 5e-5,这是微调预训练模型的常用学习率(常见1e-5 ~ 5e-5)。

21行:model.train() 将模型设置为训练模式,启用 dropout 和其他训练特定的行为。否则就变成eval推理模式了。

22行:for epoch in range(5): 开始训练循环,设置训练轮数为5。

26行:for batch in tqdm(loader, desc=f"Epoch {epoch+1}"): 使用 tqdm 显示训练进度条,迭代数据加载器中的每个批次。

27行:optimizer.zero_grad()清空梯度,避免梯度累加。

32-37行:前向传播计算损失。因为我们提供了标签,BERT 模型会自动计算交叉熵损失。

loss = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            ).loss

39行:loss.backward() 反向传播计算梯度。

40行:optimizer.step() 更新模型参数。

41行:total_loss += loss.item() 统计 loss。

45行:avg_loss = total_loss / num_batches 计算平均 loss。

49行:if avg_loss < 0.5: 根据平均 loss 判断是否提前停止训练(早停),避免过拟合。

53、54行:model.save_pretrained("csc_model")tokenizer.save_pretrained("csc_model") 保存训练好的模型和分词器到指定目录,方便后续加载使用。

代码使用

参考从头训练中文错别字纠错模型 (一)