提升检索精度的秘密武器:微调重排序模型实战指南

在我之前的文章中,我们探讨了微调嵌入模型的世界——这是改进检索系统的一个关键步骤。
Qwen微调干货!对话、指令、Function Call、思考链数据集构造全流程揭秘!
今天,我们要更进一步,深入探讨微调重排序模型。虽然嵌入可以帮助我们检索相关文档,但重排序器可以进一步优化这些结果,以确保最准确和上下文相关的匹配。
在这篇文章中,我将向您介绍我的方法,重点关注准备数据和微调自定义重排序器。

Cross Encoders  跨编码器是一种主要用于自然语言处理(NLP)的神经网络架构,主要用于理解两段文本之间的关系,例如句子对。它们在语义相似度、问答和自然语言推理等任务中特别有效。

Re-ranking  重新排序,在检索增强生成(RAG)中,重新排序是提高检索到的文档或段落质量的关键步骤,在它们被用于生成最终答案之前。RAG 结合了基于检索的方法(从大型语料库中检索相关文档)和生成模型(根据检索到的内容生成答案)。重新排序有助于确保在生成步骤中优先考虑最相关和高质量的文档。


重排序的需求源于初始检索阶段的局限性。检索器,如 BM25 这样的稀疏检索器或双编码器这样的密集检索器,可能会返回大量候选文档,这些文档按照与查询的相关性排序并不完美。


重排序通过使用更复杂的模型,如交叉编码器,来细化检索文档的顺序,从而更好地评估每个文档与查询的相关性。通过将最相关的文档输入到生成模型中,最终的输出,无论是答案还是摘要,都变得更加准确和符合上下文。

微调重排序模型对于优化特定任务或领域的性能至关重要。虽然像 BERT 或 RoBERTa 这样的预训练模型对语言有一般理解,但它们可能无法适应像根据查询对文档进行相关性排序这样的任务的细微差别。
微调使模型适应特定任务,提高准确性和相关性评分。它帮助模型学习特定于任务的关联、特定于领域的知识以及隐含的上下文联系,这对于准确的重新排序至关重要。
微调过程也使模型与目标数据分布保持一致。这个过程增强了模型对噪声和输入数据(如错别字或改写查询)变化的鲁棒性。
对于特定领域(例如医学或法律)以及有标注数据的情况,微调尤为重要。通过微调,重排序模型可以更好地细化检索到的文档,确保为下游任务(如 RAG 系统中的问答)提供高质量的输入。
在微调重排序模型中,常用的两种数据集格式是连续分数基础和不同类别基础。
持续的基于分数的格式涉及具有连续相关性分数(例如,介于 0 和 1 之间)的句子或文本对。例如,使用 sentence-transformers 库中的 InputExample 类,您可以定义具有标签 0.3 或 0.8 的对 ["sentence1", "sentence2"] ,表示两个文本之间的相关性或相似度。这种格式非常适合需要评分相关性的任务,例如语义文本相似度或文档排序。
train_samples = [      InputExample(texts=["sentence1""sentence2"], label=0.3),      InputExample(texts=["Another""pair"], label=0.8),]
另一方面,基于特定类别的格式使用预定义的类别来表示句子对之间的关系。例如,在自然语言推理(NLI)任务中,对被标记为“矛盾”、“蕴涵”或“中性”,这些映射到整数值(例如,0、1、2)。这种格式对于需要分类分类的任务很有用,例如确定句子之间的逻辑关系。这两种格式在微调重排序模型中都被广泛使用,具体选择取决于特定任务和数据性质。
为了生成用于微调的合成数据集,我们首先需要从文档集合中创建一个问答数据集。 create_qa_dataset 函数处理此过程。它接受一个文档目录,将它们分割成更小的块(例如,256 个标记),并使用语言模型(LLM)为每个块生成问题。每个块的问题数量可自定义,允许您控制数据集的密度。
要开始微调,我们初始化一个预训练的交叉编码器模型。在这种情况下,我们使用 BAAI/bge-reranker-base model 作为基础模型。我们之前已经讨论过这一点。

评估器使用验证数据集( val_dataloader )定期评估模型的性能。

from sentence_transformers.evaluation import SentenceEvaluator
import torch
from torch.utils.data import DataLoader
import logging
from sentence_transformers.util import batch_to_device
import os
import csv
from sentence_transformers import CrossEncoder
from tqdm.autonotebook import tqdm

logger = logging.getLogger(__name__)

class MSEEval(SentenceEvaluator):
    """
    Evaluate a model based on its accuracy on a labeled dataset

    This requires a model with LossFunction.SOFTMAX

    The results are written in a CSV. If a CSV already exists, then values are appended.
    """

    def __init__(self, 
                 dataloader: DataLoader, 
                 name: str = "", 
                 show_progress_bar: bool =True,
                 write_csv: bool =True):
        """
        Constructs an evaluator for the given dataset

        :param dataloader:
            the data for the evaluation
        """
        self.dataloader = dataloader
        self.name = name
        self.show_progress_bar = show_progress_bar

        if name:
            name = "_"+name

        self.write_csv = write_csv
        self.csv_file = "accuracy_evaluation"+name+"_results.csv"
        self.csv_headers = ["epoch", "steps", "accuracy"]

    def __call__(self, model: CrossEncoder, output_path: str =None, epoch: int=-1, steps: int=-1->float:
        model.model.eval()
        total =0
        loss_total =0

        if epoch !=-1:
            if steps ==-1:
                out_txt = " after epoch {}:".format(epoch)
else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
else:
            out_txt = ":"

        loss_fnc = torch.nn.MSELoss()
        activation_fnc = torch.nn.Sigmoid()

        logger.info("Evaluation on the "+self.name+" dataset"+out_txt)
        self.dataloader.collate_fn = model.smart_batching_collate
for features, labels in tqdm(self.dataloader,  desc="Evaluation", smoothing=0.05, disable=not self.show_progress_bar):

with torch.no_grad():
                model_predictions = model.model(**features, return_dict=True)
                logits = activation_fnc(model_predictions.logits)
                if model.config.num_labels ==1:
                    logits = logits.view(-1)
                loss_value = loss_fnc(logits, labels)

            total +=1 # number of batches
            loss_total += loss_value.cpu().item()
        mse = loss_total/total

        logger.info("MSE: {:.4f} ({}/{})\n".format(mse, loss_total, total))

        if output_path isnotNoneand self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            if not os.path.isfile(csv_path):
withopen(csv_path, newline='', mode="w", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow(self.csv_headers)
                    writer.writerow([epoch, steps, mse])
else:
withopen(csv_path, newline='', mode="a", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow([epoch, steps, mse])

return mse

在处理嵌入后,微调 reranking 模型是自然的下一步,并且这是一种提高系统理解和优先处理信息能力的方法。从生成合成数据到训练和评估模型,每一步都带来了自己的挑战和回报。
无论您是在处理现实世界的数据还是创建自己的数据集,关键在于实验、迭代和改进。我希望这篇指南能激发您在自己的项目中探索微调的潜力。
参考:https://blog.gopenai.com/fine-tuning-re-ranking-models-a-beginners-guide-066b4b9c3ecf

(文:AI技术研习社)

发表评论

×

下载每时AI手机APP

 

和大家一起交流AI最新资讯!

立即前往