嘘~ 正在从服务器偷取页面 . . .

【项目实战】大模型微调-情绪对话模型


一、项目介绍

背景

在看抖音时经常刷到一个小智聊天机器人,发现它在回答时具有非常强烈和个性化的情绪表达,使用就准备以此为目标,自己尝试做一个这样的情绪对话大模型。
加上在人工智能技术快速发展的今天,对话系统已广泛应用于客服、心理辅导、社交娱乐等领域。例如,在心理咨询场景中,模型需识别用户的焦虑或抑郁情绪并给出共情回应;在电商客服中,需对用户的不满情绪进行安抚。
因此,构建具备情绪感知能力的对话模型也是一个非常有价值的事情。

项目目标

本项目旨在通过实战微调一个情绪对话模型,覆盖数据收集、模型训练、评估优化到部署的全流程,重点解决以下问题:

  • 如何从多来源数据中构建高质量的情绪对话数据集?

  • 如何选择合适的预训练模型并针对性优化其情感生成能力?

  • 如何平衡生成内容的流畅性与情绪准确性?

  • 如何将模型高效部署到实际应用场景?

二、数据收集与处理

2.1 数据来源

本次项目数据来源主要有两点:

  1. 人工制定
  2. 基于现有开源数据,让AI实现情绪数据集制作。

注意:如果让AI来帮助处理数据,尽可能选择效果较好的API接口,不要使用本地的大模型来处理。

2.2 数据标注工具与方法

本项目数据主要是根据公开数据集然后使用Python代码+AI大模型来进行生成,没有用到其他标注工具和方法。

2.3 数据清洗与预处理

通过Python代码指定数据生成模板,然后基于收集的问题来让AI进行自动生成,并设置规则,让AI生成符合我们需求的高质量的数据。

数据清洗和预处理流程如下:

  • 设置数据风格模板(主要修正消息格式)

    • 让AI大模型按照我们想要的数据模板进行生成
  • 生成函数(主要修正消息的结构)

    • 调用大模型API来准备生成数据并返回
  • 设置数据质量过滤规则(添加空值检查)

    • 规则1:回复长度检查
    • 规则2:风格关键词检查
    • 规则3:语义相似度检查
  • 执行生成(添加容错)

2.4 数据划分

训练集、验证集、测试集划分

三、数据收集与处理(实现篇)

3.1 数据收集整理

使用的是公开的CDial-GPT数据集
还有魔塔社区的 LCCC 数据集

数据集下载后,在里面挑选 1000~3000 条数据即可

3.2 配置文件创建

主要用来存放我们的 API 密匙、模型本地路径等敏感配置以及风格模板配置

打开 Pycharm 新建一个emotion_dialogue_tuner项目,在项目下新建 config 目录,然后新建config/settings.yamlconfig/style_config.json文件,分别用来存放API密钥等敏感配置风格模板配置

settings.yaml文件内容如下:

API:
  ZHIPU_API_KEY: "your_api_key_here"  # 敏感信息隔离
  MODEL_NAME: "glm-3-turbo"

PATHS:
  EMBEDDING_MODEL: "embedding_model/thomas/text2vec-base-chinese"  # 你本地embedding模型路径

我embedding模型选择的是thomas/text2vec-base-chinese,可以在魔塔社区下载

style_config.json文件内容如下:

{  
    "温柔":{  
        "system_prompt":"你是一个温柔体贴的聊天助手,说话时总是充满关怀,使用以下特征:\n1. 包含'呢、呀、啦'等语气词\n2. 使用🌸💖😊等温暖表情\n3. 主动询问用户感受",  
        "examples": [  
            {"role": "user", "content": "今天好累啊"},  
            {"role": "assistant", "content": "辛苦啦~ 要给自己泡杯热茶放松一下吗?🌸"},  
            {"role": "user", "content": "考试没考好..."},  
            {"role": "assistant", "content": "没关系的呀~ 下次一定会更好!需要我陪你聊聊吗?😊"}  
        ],     "temperature": 0.3  
    },  
    "毒舌":{  
        "system_prompt":"你是一个喜欢用犀利吐槽表达关心的朋友,需满足:\n1. 使用网络流行语(如'栓Q''退退退')\n2. 包含夸张比喻('你这速度堪比树懒')\n3. 结尾隐藏关心",  
        "examples": [  
            {"role": "user", "content": "又胖了5斤!"},  
            {"role": "assistant", "content": "好家伙!你这是要把体重秤压成分子料理?🏋️"},  
            {"role": "user", "content": "游戏又输了"},  
            {"role": "assistant", "content": "菜就多练练!需要给你推荐《从零开始的电竞之路》吗?🎮"}  
        ],     "temperature": 0.7  
    }  
}

我这里风格类型主要配置了两种:温柔和毒舌两种相反风格,然后为它们设置了相应的系统提示词和示例,让大模型明确自己的角色以及学习我们需要的数据模板类型,以供后续生成。

3.3 配置文件加载模块

新建emotion_dialogue_tuner/src/utils/config_loader.py文件

这个文件主要负责统一管理API密钥、模型路径等敏感信息和风格配置

# config_loader.py  
"""配置文件加载模块,负责统一管理API密钥、模型路径等敏感信息和风格配置"""  
  
import yaml  
import json  
from pathlib import Path  
  
  
class ConfigLoader:  
    """配置加载器,封装配置文件的读取操作"""  
  
    def __init__(self):  
        """初始化时自动定位项目根目录"""  
        self.root_path = Path(__file__).resolve().parent.parent.parent  # 根据实际层级调整  
  
    def load_settings(self) -> dict:  
        """加载YAML格式的全局设置  
        Returns:            dict: 包含API密钥、模型路径等配置的字典  
        """        with open(self.root_path / "config/settings.yaml", "r", encoding="utf-8") as f:  
            return yaml.safe_load(f)  
  
    def load_style_config(self) -> dict:  
        """加载JSON格式的风格配置  
        Returns:            dict: 包含不同对话风格的模板配置  
        """        with open(self.root_path / "config/style_config.json", "r", encoding="utf-8") as f:  
            return json.load(f)

3.4 数据生成核心模块

新建emotion_dialogue_tuner/src/data_generator.py文件

这个文件主要负责调用API生成指定风格的对话数据

# data_generator.py  
"""数据生成核心模块,负责调用API生成指定风格的对话数据"""  
  
from zhipuai import ZhipuAI  
import random  
import time  
  
class StyleDataGenerator:  
    """对话数据生成器,根据配置生成特定风格的对话数据"""  
  
    def __init__(self, api_key: str, style_config: dict):  
        """  
        Args:            api_key (str): 智普API访问密钥  
            style_config (dict): 风格配置字典  
        """        self.client = ZhipuAI(api_key=api_key)  
        self.style_config = style_config  
  
    def _build_messages(self, style_name: str) -> list:  
        """构建符合API要求的消息格式  
        Args:            style_name (str): 目标风格名称(如'温柔')  
        Returns:            list: 包含系统提示和示例对话的消息列表  
        """        config = self.style_config[style_name]  
        return [  
            {"role": "system", "content": config["system_prompt"]},  
            *config["examples"]  # 展开示例对话  
        ]  
  
    def generate_style_data(self, style_name: str, num_samples: int = 50) -> list:  
        """生成指定风格的对话数据  
        Args:            style_name (str): 目标风格名称  
            num_samples (int): 需要生成的样本数量  
        Returns:            list: 生成的对话数据列表,每个元素包含用户输入、助手回复和风格标签  
        """        data = []  
        messages = self._build_messages(style_name)  
  
        # 从本地文件加载用户输入  
        user_inputs = []  
        with open("data/cleaned_output.txt", 'r', encoding='utf-8') as f:  # 修改为清理后的文件路径  
            for line in f:  
                # 直接读取每行内容并去除换行符  
                cleaned_line = line.rstrip('\n')  # 或使用 line.strip()                if cleaned_line:  # 空行过滤(冗余保护)  
                    user_inputs.append(cleaned_line)  
  
        # 添加空值检查  
        if not user_inputs:  
            raise ValueError("文件内容为空或未成功加载数据,请检查:"  
                             "1. 文件路径是否正确 2. 文件是否包含有效内容")  
  
        # 初始化顺序索引  
        current_index = 0  # 添加索引计数器  
        for _ in range(num_samples):  
            try:  
  
                # 按顺序选择用户输入(修改核心部分)  
                user_msg = user_inputs[current_index]  
                current_index = (current_index + 1) % len(user_inputs)  # 循环计数  
  
                # 添加当前用户消息  
                current_messages = messages + [{"role": "user", "content": user_msg}]  
  
                # 调用大模型API生成回复  
                response = self.client.chat.completions.create(  
                    model="glm-3-turbo",  
                    messages=current_messages,  
                    temperature=self.style_config[style_name]["temperature"],  
                    max_tokens=100  
                )  
                reply = response.choices[0].message.content  
  
                # 保存通过质量检查的数据  
                if self._validate_reply(style_name, user_msg, reply):  
                    data.append({  
                        "user": user_msg,  
                        "assistant": reply,  
                        "style": style_name  
                    })  
  
                time.sleep(1.5)  # API调用频率限制保护  
  
            except Exception as e:  
                print(f"生成失败: {str(e)}")  
  
        return data  
  
    def _validate_reply(self, style: str, user_msg: str, reply: str) -> bool:  
        """内部方法:验证回复质量(实际实现应调用Validator类)"""  
        # 简化的验证逻辑,实际应使用独立的Validator类  
        return bool(reply)  # 示例代码

3.5 生成数据质量验证模块

新建emotion_dialogue_tuner/src/utils/validator.py文件

这个文件主要负责回复质量验证,确保生成数据符合质量标准

# validator.py  
"""回复质量验证模块,确保生成数据符合质量标准"""  
  
import numpy as np  
from sentence_transformers import SentenceTransformer  
  
  
class ReplyValidator:  
    """回复验证器,执行多维度质量检查"""  
  
    def __init__(self, model_path: str):  
        """  
        Args:            model_path (str): 本地嵌入模型文件路径  
        """        self.style_model = SentenceTransformer(model_path)  
  
    def validate(self, style: str, user_msg: str, reply: str, ref_text: str) -> bool:  
        """执行完整的质量验证流程  
        Args:            style (str): 目标风格名称  
            user_msg (str): 用户输入文本  
            reply (str): 待验证的回复文本  
            ref_text (str): 参考文本(用于相似度计算)  
        Returns:            bool: 是否通过所有验证规则  
        """        # 基础格式检查  
        if not self._basic_checks(reply):  
            return False  
  
        # 风格关键词匹配检查  
        if not self._style_keyword_check(style, reply):  
            return False  
  
        # 语义相似度验证  
        return self._semantic_similarity_check(ref_text, reply)  
  
    def _basic_checks(self, reply: str) -> bool:  
        """执行基础格式检查  
        1. 非空检查  
        2. 长度限制检查  
        """        return bool(reply) and (5 <= len(reply) <= 150)  
  
    def _style_keyword_check(self, style: str, reply: str) -> bool:  
        """检查是否包含风格特征关键词"""  
        keyword_map = {  
            "温柔": ["呢", "呀", "😊", "🌸"],  
            "毒舌": ["好家伙", "栓Q", "!", "🏋️"]  
        }        return any(kw in reply for kw in keyword_map.get(style, []))  
  
    def _semantic_similarity_check(self, ref_text: str, reply: str) -> bool:  
        """计算与参考文本的语义相似度  
        使用余弦相似度判断,阈值设为0.65  
        """        ref_vec = self.style_model.encode(ref_text)  
        reply_vec = self.style_model.encode(reply)  
        similarity = np.dot(ref_vec, reply_vec)  
        return similarity > 0.65

3.6 主函数 main.py 实现

新建emotion_dialogue_tuner/main.py 文件

# main.py  
"""主执行入口,协调各模块完成数据生成任务"""  
  
from src.utils.config_loader import ConfigLoader  
from src.utils.validator import ReplyValidator  
from src.data_generator import StyleDataGenerator  
import json  
import os  
  
  
def main():  
    # 初始化配置加载器  
    config_loader = ConfigLoader()  
  
    # 加载配置信息  
    try:  
        settings = config_loader.load_settings()  
        style_config = config_loader.load_style_config()  
    except FileNotFoundError as e:  
        print(f"配置文件缺失:{str(e)}")  
        return  
  
    # 初始化核心组件  
    generator = StyleDataGenerator(  
        api_key=settings["API"]["ZHIPU_API_KEY"],  
        style_config=style_config  
    )  
    validator = ReplyValidator(  
        model_path=settings["PATHS"]["EMBEDDING_MODEL"]  
    )  
    # 执行数据生成流程  
    all_data = []  
    try:  
        print("正在生成温柔风格数据...")  
        gentle_data = generator.generate_style_data("温柔", 20)  
        all_data.extend(gentle_data)  
  
        print("正在生成毒舌风格数据...")  
        sarcastic_data = generator.generate_style_data("毒舌", 20)  
        all_data.extend(sarcastic_data)  
  
    except KeyboardInterrupt:  
        print("\n用户中断操作,正在保存已生成数据...")  
    finally:  
        # 确保输出目录存在  
        output_dir = "outputs"  
        os.makedirs(output_dir, exist_ok=True)  
  
        # 持久化保存数据  
        output_path = os.path.join(output_dir, "style_chat_data.json")  
        with open(output_path, "w", encoding="utf-8") as f:  
            json.dump(all_data, f, ensure_ascii=False, indent=2)  
        print(f"数据保存完成,有效样本数:{len(all_data)}")  
  
  
if __name__ == "__main__":  
    main()

四、模型选型与设计

1.模型选型
根据当前的任务特点,选择合适的评测数据以及预选的候选模型

一般来讲,做什么样的任务就选什么样的模型,我们要做的是一个情感对话模型,所以我们在选择模型时应选择对中文文本理解能力较强的模型。

2.模型的大小选择

  • 服务器配置
  • 任务复杂度

可以根据任务选择对应的评测数据,对期望模型客观评测

当前任务为日常聊天对话模型,主要要求模型的中文理解能力,所以我们这里可以用 CLUE(中文理解)数据进行评测:

我们需要准备 opencompass 环境,参考[[【AI大模型应用学习笔记】OpenCompass模型评估框架的使用教程]]

进入OpenCompass根目录下,执行命令

#输出数据集清单
python tools/list_configs.py clue

执行后输出结果如下:

+-----------------------------+------------------------------------------------------------------------------+
| Dataset                     | Config Path                                                                  |
|-----------------------------+------------------------------------------------------------------------------|
| CLUE_C3_gen                 | opencompass/configs/datasets/CLUE_C3/CLUE_C3_gen.py                          |
| CLUE_C3_gen_8c358f          | opencompass/configs/datasets/CLUE_C3/CLUE_C3_gen_8c358f.py                   |
| CLUE_C3_ppl                 | opencompass/configs/datasets/CLUE_C3/CLUE_C3_ppl.py                          |
| CLUE_C3_ppl_56b537          | opencompass/configs/datasets/CLUE_C3/CLUE_C3_ppl_56b537.py                   |
| CLUE_C3_ppl_e24a31          | opencompass/configs/datasets/CLUE_C3/CLUE_C3_ppl_e24a31.py                   |
| CLUE_CMRC_gen               | opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py                      |
| CLUE_CMRC_gen_1bd3c8        | opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_1bd3c8.py               |
| CLUE_CMRC_gen_3749cd        | opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_3749cd.py               |
| CLUE_CMRC_gen_8484b9        | opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_8484b9.py               |
| CLUE_CMRC_gen_941108        | opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_941108.py               |
| CLUE_DRCD_gen               | opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen.py                      |
| CLUE_DRCD_gen_1bd3c8        | opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_1bd3c8.py               |
| CLUE_DRCD_gen_3749cd        | opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_3749cd.py               |
| CLUE_DRCD_gen_8484b9        | opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_8484b9.py               |
| CLUE_DRCD_gen_941108        | opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_941108.py               |
| CLUE_afqmc_gen              | opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_gen.py                    |
| CLUE_afqmc_gen_901306       | opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_gen_901306.py             |
| CLUE_afqmc_ppl              | opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl.py                    |
| CLUE_afqmc_ppl_378c5b       | opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_378c5b.py             |
| CLUE_afqmc_ppl_6507d7       | opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_6507d7.py             |
| CLUE_afqmc_ppl_7b0c1e       | opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_7b0c1e.py             |
| CLUE_cmnli_gen              | opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_gen.py                    |
| CLUE_cmnli_gen_1abf97       | opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_gen_1abf97.py             |
| CLUE_cmnli_gen_51e956       | opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_gen_51e956.py             |
| CLUE_cmnli_ppl              | opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl.py                    |
| CLUE_cmnli_ppl_98dd6e       | opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_98dd6e.py             |
| CLUE_cmnli_ppl_ef69e7       | opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_ef69e7.py             |
| CLUE_cmnli_ppl_fdc6de       | opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_fdc6de.py             |
| CLUE_ocnli_gen              | opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_gen.py                    |
| CLUE_ocnli_gen_51e956       | opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_gen_51e956.py             |
| CLUE_ocnli_gen_c4cb6c       | opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_gen_c4cb6c.py             |
| CLUE_ocnli_ppl              | opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl.py                    |
| CLUE_ocnli_ppl_98dd6e       | opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl_98dd6e.py             |
| CLUE_ocnli_ppl_ef69e7       | opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl_ef69e7.py             |
| CLUE_ocnli_ppl_fdc6de       | opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl_fdc6de.py             |
| FewCLUE_bustm_gen           | opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_gen.py              |
| FewCLUE_bustm_gen_634f41    | opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_gen_634f41.py       |
| FewCLUE_bustm_ppl           | opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl.py              |
| FewCLUE_bustm_ppl_4b16c0    | opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_4b16c0.py       |
| FewCLUE_bustm_ppl_9ef540    | opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_9ef540.py       |
| FewCLUE_bustm_ppl_e53034    | opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_e53034.py       |
| FewCLUE_chid_gen            | opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_gen.py                |
| FewCLUE_chid_gen_0a29a2     | opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_gen_0a29a2.py         |
| FewCLUE_chid_ppl            | opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_ppl.py                |
| FewCLUE_chid_ppl_8f2872     | opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_ppl_8f2872.py         |
| FewCLUE_chid_ppl_acccb5     | opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_ppl_acccb5.py         |
| FewCLUE_cluewsc_gen         | opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_gen.py          |
| FewCLUE_cluewsc_gen_c68933  | opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_gen_c68933.py   |
| FewCLUE_cluewsc_ppl         | opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl.py          |
| FewCLUE_cluewsc_ppl_12e4e0  | opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl_12e4e0.py   |
| FewCLUE_cluewsc_ppl_4284a0  | opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl_4284a0.py   |
| FewCLUE_cluewsc_ppl_868415  | opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl_868415.py   |
| FewCLUE_csl_gen             | opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_gen.py                  |
| FewCLUE_csl_gen_28b223      | opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_gen_28b223.py           |
| FewCLUE_csl_gen_87f4a8      | opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_gen_87f4a8.py           |
| FewCLUE_csl_ppl             | opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_ppl.py                  |
| FewCLUE_csl_ppl_769f8d      | opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_ppl_769f8d.py           |
| FewCLUE_csl_ppl_841b62      | opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_ppl_841b62.py           |
| FewCLUE_eprstmt_gen         | opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_gen.py          |
| FewCLUE_eprstmt_gen_740ea0  | opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_gen_740ea0.py   |
| FewCLUE_eprstmt_ppl         | opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py          |
| FewCLUE_eprstmt_ppl_1ce587  | opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl_1ce587.py   |
| FewCLUE_eprstmt_ppl_f1e631  | opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl_f1e631.py   |
| FewCLUE_ocnli_fc_gen        | opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_gen.py        |
| FewCLUE_ocnli_fc_gen_f97a97 | opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_gen_f97a97.py |
| FewCLUE_ocnli_fc_ppl        | opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_ppl.py        |
| FewCLUE_ocnli_fc_ppl_9e8b3d | opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_ppl_9e8b3d.py |
| FewCLUE_ocnli_fc_ppl_c08300 | opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_ppl_c08300.py |
| FewCLUE_tnews_gen           | opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_gen.py              |
| FewCLUE_tnews_gen_b90e4a    | opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_gen_b90e4a.py       |
| FewCLUE_tnews_ppl           | opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl.py              |
| FewCLUE_tnews_ppl_7d1c07    | opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl_7d1c07.py       |
| FewCLUE_tnews_ppl_d10e8a    | opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl_d10e8a.py       |
| FewCLUE_tnews_ppl_fff486    | opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl_fff486.py       |
+-----------------------------+------------------------------------------------------------------------------+

gen:生成任务
ppl:分类任务

我们本次任务大多是短语对话,可以选择 FewCLUE_bustm_gen(短文本分类)、FewCLUE_ocnli_fc_gen(自然语言推理)对预期模型进行评估。

我们这里选择Qwen1.5 的 0.5b、1.8b两个个模型来进行比较,这里记得修改opencompass 下对应模型文件中的模型路径。

运行下面命令开始进行模型评估比较:

python run.py \
--models hf_qwen1_5_0_5b_chat hf_qwen1_5_1_8b_chat \
--datasets FewCLUE_bustm_gen FewCLUE_ocnli_fc_gen \
--debug

根据评估结果,选择最终模型。

评估结果如下:

tabulate format
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
dataset        version    metric    mode      qwen1.5-0.5b-chat-hf    qwen1.5-1.8b-chat-hf
-------------  ---------  --------  ------  ----------------------  ----------------------
bustm-dev      5cc669     accuracy  gen                      48.75                   48.75
bustm-test     5cc669     accuracy  gen                      50.00                   50.11
ocnli_fc-dev   51e956     accuracy  gen                      35.62                   46.25
ocnli_fc-test  51e956     accuracy  gen                      35.20                   50.63
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$

其实这里我们选的是同一个模型的不同参数版本,那么必然是参数量大的那个评估效果要好。

五、模型训练与评估

5.1训练环境配置

使用环境

  • 软件环境:Windows11 + WSL2-Linux-Ubuntu22.04子系统
  • 硬件环境:GeForce RTX 4060 Ti 16GB
  • 框架与工具(LLamaFactory/Xtuner)

因为当前任务的结果更偏向于主观评测,xtener就提供了在训练过程中的主观评测,因此选择xtuner

5.2 训练参数设置

安装好xtuner环境

创建微调训练相关的配置文件在左侧的文件列表,xtuner 的文件夹里,打开xtuner/xtuner/configs/internlm/internlm2_chat_1_8b/internlm2_chat_1_8b_qlora_alpaca_e3.py,复制一份到其他目录。
打开这个文件,然后修改预训练模型地址,数据文件地址等。
我配置修改的地方如下:

### 在 PART 1  Settings 中
# 我们预训练模型存放路径
pretrained_model_name_or_path = "/home/moyuai/moyuai/llm/Qwen/Qwen1___5-1___8B-Chat"

# 微调数据存放路径
alpaca_en_path = "/home/moyuai/moyuai/data/output.json"

# 训练中最大的文本长度
max_length = 512

# 每一批训练样本的大小,根据自己的硬件调整
batch_size = 4 

# 最大训练轮数
max_epochs = 3000 

# 验证数据
evaluation_inputs = [
    "男朋友给女主播刷火箭,算精神出轨吗?",
    "喝红酒养生,结果喝到头晕…",
    "闺蜜和我前任互关小红书,取关拉黑三连击!",
    "体检说胆固醇高,要戒炸鸡了吗?",
    "剧本杀遇读本玩家,直接摔门离场!",
    "领导周末发60秒语音矩阵,装没看见行吗?",
    ]

### 在 PART 2  Model & Tokenizer 中
# 将 qlora4 位关闭,开启八位
load_in_4bit=False,
load_in_8bit=True,
### 注释下面这些内容
# bnb_4bit_compute_dtype=torch.float16,
# bnb_4bit_use_double_quant=True,
# bnb_4bit_quant_type="nf4",

r=32,
lora_alpha=64, # 一般是r的两倍

### 在 PART 3  Dataset & Dataloader 中
dataset=dict(type=load_dataset, path="json",data_files=data_files),

dataset_map_fn=None,

参数设置完成后,我们进入我们复制的 qwen1_5_1_8b_chat_qlora_alpaca_e3.py 文件目录下,在当前目录下,输入以下命令启动微调脚本:

# 后台终端运行
nohup xtuner train qwen1_5_1_8b_chat_qlora_alpaca_e3.py > train_05_10_2.log 2>&1 &

#单机单卡
xtuner train qwen1_5_1_8b_chat_qlora_alpaca_e3.py
#单机多卡
NPROC_PER_NODE=${GPU_NUM} xtuner train qwen1_5_1_8b_chat_qlora_alpaca_e3 --deepspeed deepspeed_zero2

我们训练到20000轮后查看日志,发现模型效果还行,不管是生成的答案还有loss值都是不错的,20000轮训练的日志如下:

/home/moyuai/anaconda3/envs/xtuner-env/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization

  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

05/10 08:08:26 - mmengine - INFO - Iter(train) [ 19510/480000]  lr: 1.9994e-04  eta: 7 days, 19:52:24  time: 5.9691  data_time: 4.4355  memory: 9843  loss: 0.0042  grad_norm: 0.1220

05/10 08:08:38 - mmengine - INFO - Iter(train) [ 19520/480000]  lr: 1.9994e-04  eta: 7 days, 19:51:14  time: 1.2358  data_time: 0.0058  memory: 9844  loss: 0.0059  grad_norm: 0.1261

05/10 08:08:56 - mmengine - INFO - Iter(train) [ 19530/480000]  lr: 1.9994e-04  eta: 7 days, 19:52:00  time: 1.7250  data_time: 0.2061  memory: 9844  loss: 0.0046  grad_norm: 0.1261

05/10 08:09:09 - mmengine - INFO - Iter(train) [ 19540/480000]  lr: 1.9994e-04  eta: 7 days, 19:51:21  time: 1.3681  data_time: 0.0061  memory: 9844  loss: 0.0027  grad_norm: 0.1290

05/10 08:09:23 - mmengine - INFO - Iter(train) [ 19550/480000]  lr: 1.9994e-04  eta: 7 days, 19:50:37  time: 1.3440  data_time: 0.0062  memory: 9846  loss: 0.0046  grad_norm: 0.1290

05/10 08:09:38 - mmengine - INFO - Iter(train) [ 19560/480000]  lr: 1.9994e-04  eta: 7 days, 19:50:29  time: 1.4959  data_time: 0.0061  memory: 9846  loss: 0.0027  grad_norm: 0.1263

05/10 08:09:51 - mmengine - INFO - Iter(train) [ 19570/480000]  lr: 1.9994e-04  eta: 7 days, 19:49:43  time: 1.3353  data_time: 0.0060  memory: 9848  loss: 0.0030  grad_norm: 0.1242

05/10 08:10:06 - mmengine - INFO - Iter(train) [ 19580/480000]  lr: 1.9994e-04  eta: 7 days, 19:49:40  time: 1.5177  data_time: 0.0060  memory: 9843  loss: 0.0058  grad_norm: 0.1242

05/10 08:10:20 - mmengine - INFO - Iter(train) [ 19590/480000]  lr: 1.9994e-04  eta: 7 days, 19:48:59  time: 1.3554  data_time: 0.0060  memory: 9843  loss: 0.0034  grad_norm: 0.1254

05/10 08:10:33 - mmengine - INFO - Iter(train) [ 19600/480000]  lr: 1.9994e-04  eta: 7 days, 19:48:17  time: 1.3532  data_time: 0.0062  memory: 9842  loss: 0.0044  grad_norm: 0.1286

05/10 08:10:48 - mmengine - INFO - Iter(train) [ 19610/480000]  lr: 1.9994e-04  eta: 7 days, 19:48:07  time: 1.4891  data_time: 0.0060  memory: 9842  loss: 0.0037  grad_norm: 0.1286

05/10 08:11:01 - mmengine - INFO - Iter(train) [ 19620/480000]  lr: 1.9994e-04  eta: 7 days, 19:47:21  time: 1.3355  data_time: 0.0062  memory: 9844  loss: 0.0047  grad_norm: 0.1350

05/10 08:11:17 - mmengine - INFO - Iter(train) [ 19630/480000]  lr: 1.9994e-04  eta: 7 days, 19:47:17  time: 1.5144  data_time: 0.0061  memory: 9844  loss: 0.0040  grad_norm: 0.1350

05/10 08:11:30 - mmengine - INFO - Iter(train) [ 19640/480000]  lr: 1.9994e-04  eta: 7 days, 19:46:30  time: 1.3288  data_time: 0.0060  memory: 9844  loss: 0.0051  grad_norm: 0.1311

05/10 08:11:45 - mmengine - INFO - Iter(train) [ 19650/480000]  lr: 1.9994e-04  eta: 7 days, 19:46:25  time: 1.5094  data_time: 0.0061  memory: 9842  loss: 0.0069  grad_norm: 0.1381

05/10 08:11:58 - mmengine - INFO - Iter(train) [ 19660/480000]  lr: 1.9994e-04  eta: 7 days, 19:45:37  time: 1.3286  data_time: 0.0059  memory: 9845  loss: 0.0057  grad_norm: 0.1381

05/10 08:12:12 - mmengine - INFO - Iter(train) [ 19670/480000]  lr: 1.9994e-04  eta: 7 days, 19:44:57  time: 1.3595  data_time: 0.0059  memory: 9846  loss: 0.0056  grad_norm: 0.1427

05/10 08:12:26 - mmengine - INFO - Iter(train) [ 19680/480000]  lr: 1.9994e-04  eta: 7 days, 19:44:35  time: 1.4349  data_time: 0.0060  memory: 9843  loss: 0.0050  grad_norm: 0.1426

05/10 08:12:41 - mmengine - INFO - Iter(train) [ 19690/480000]  lr: 1.9994e-04  eta: 7 days, 19:44:29  time: 1.5082  data_time: 0.2061  memory: 9838  loss: 0.0034  grad_norm: 0.1426

05/10 08:12:57 - mmengine - INFO - Iter(train) [ 19700/480000]  lr: 1.9994e-04  eta: 7 days, 19:44:32  time: 1.5429  data_time: 0.0065  memory: 9844  loss: 0.0045  grad_norm: 0.1489

05/10 08:13:10 - mmengine - INFO - Iter(train) [ 19710/480000]  lr: 1.9994e-04  eta: 7 days, 19:43:45  time: 1.3281  data_time: 0.0060  memory: 9846  loss: 0.0038  grad_norm: 0.1489

05/10 08:13:23 - mmengine - INFO - Iter(train) [ 19720/480000]  lr: 1.9994e-04  eta: 7 days, 19:43:00  time: 1.3418  data_time: 0.0062  memory: 9844  loss: 0.0042  grad_norm: 0.1497

05/10 08:13:38 - mmengine - INFO - Iter(train) [ 19730/480000]  lr: 1.9994e-04  eta: 7 days, 19:42:53  time: 1.4978  data_time: 0.0059  memory: 9842  loss: 0.0038  grad_norm: 0.1507

05/10 08:13:52 - mmengine - INFO - Iter(train) [ 19740/480000]  lr: 1.9994e-04  eta: 7 days, 19:42:08  time: 1.3417  data_time: 0.0061  memory: 9842  loss: 0.0042  grad_norm: 0.1507

05/10 08:14:07 - mmengine - INFO - Iter(train) [ 19750/480000]  lr: 1.9993e-04  eta: 7 days, 19:42:03  time: 1.5066  data_time: 0.0060  memory: 9844  loss: 0.0052  grad_norm: 0.1487

05/10 08:14:20 - mmengine - INFO - Iter(train) [ 19760/480000]  lr: 1.9993e-04  eta: 7 days, 19:41:21  time: 1.3511  data_time: 0.0059  memory: 9844  loss: 0.0040  grad_norm: 0.1451

05/10 08:14:36 - mmengine - INFO - Iter(train) [ 19770/480000]  lr: 1.9993e-04  eta: 7 days, 19:41:16  time: 1.5123  data_time: 0.0060  memory: 9845  loss: 0.0062  grad_norm: 0.1451

05/10 08:14:49 - mmengine - INFO - Iter(train) [ 19780/480000]  lr: 1.9993e-04  eta: 7 days, 19:40:25  time: 1.3094  data_time: 0.0059  memory: 9845  loss: 0.0051  grad_norm: 0.1468

05/10 08:15:02 - mmengine - INFO - Iter(train) [ 19790/480000]  lr: 1.9993e-04  eta: 7 days, 19:39:45  time: 1.3595  data_time: 0.0061  memory: 9839  loss: 0.0049  grad_norm: 0.1468

05/10 08:15:18 - mmengine - INFO - Iter(train) [ 19800/480000]  lr: 1.9993e-04  eta: 7 days, 19:39:59  time: 1.5899  data_time: 0.0062  memory: 9843  loss: 0.0040  grad_norm: 0.1471

05/10 08:15:32 - mmengine - INFO - Iter(train) [ 19810/480000]  lr: 1.9993e-04  eta: 7 days, 19:39:20  time: 1.3668  data_time: 0.0060  memory: 9844  loss: 0.0057  grad_norm: 0.1426

05/10 08:15:47 - mmengine - INFO - Iter(train) [ 19820/480000]  lr: 1.9993e-04  eta: 7 days, 19:39:18  time: 1.5212  data_time: 0.0060  memory: 9840  loss: 0.0058  grad_norm: 0.1426

05/10 08:16:01 - mmengine - INFO - Iter(train) [ 19830/480000]  lr: 1.9993e-04  eta: 7 days, 19:38:37  time: 1.3542  data_time: 0.0060  memory: 9843  loss: 0.0073  grad_norm: 0.1395

05/10 08:16:13 - mmengine - INFO - Iter(train) [ 19840/480000]  lr: 1.9993e-04  eta: 7 days, 19:37:37  time: 1.2729  data_time: 0.0059  memory: 9841  loss: 0.0033  grad_norm: 0.1366

05/10 08:16:31 - mmengine - INFO - Iter(train) [ 19850/480000]  lr: 1.9993e-04  eta: 7 days, 19:38:28  time: 1.7519  data_time: 0.4193  memory: 9845  loss: 0.0041  grad_norm: 0.1366

05/10 08:16:44 - mmengine - INFO - Iter(train) [ 19860/480000]  lr: 1.9993e-04  eta: 7 days, 19:37:29  time: 1.2766  data_time: 0.0057  memory: 9845  loss: 0.0034  grad_norm: 0.1269

05/10 08:17:00 - mmengine - INFO - Iter(train) [ 19870/480000]  lr: 1.9993e-04  eta: 7 days, 19:37:56  time: 1.6473  data_time: 0.0062  memory: 9841  loss: 0.0069  grad_norm: 0.1269

05/10 08:17:13 - mmengine - INFO - Iter(train) [ 19880/480000]  lr: 1.9993e-04  eta: 7 days, 19:36:54  time: 1.2646  data_time: 0.0058  memory: 9848  loss: 0.0044  grad_norm: 0.1299

05/10 08:17:28 - mmengine - INFO - Iter(train) [ 19890/480000]  lr: 1.9993e-04  eta: 7 days, 19:36:57  time: 1.5415  data_time: 0.0059  memory: 9847  loss: 0.0063  grad_norm: 0.1367

05/10 08:17:41 - mmengine - INFO - Iter(train) [ 19900/480000]  lr: 1.9993e-04  eta: 7 days, 19:36:03  time: 1.2998  data_time: 0.0059  memory: 9846  loss: 0.0054  grad_norm: 0.1367

05/10 08:17:54 - mmengine - INFO - Iter(train) [ 19910/480000]  lr: 1.9993e-04  eta: 7 days, 19:35:12  time: 1.3098  data_time: 0.0059  memory: 9846  loss: 0.0035  grad_norm: 0.1363

05/10 08:18:11 - mmengine - INFO - Iter(train) [ 19920/480000]  lr: 1.9993e-04  eta: 7 days, 19:35:38  time: 1.6464  data_time: 0.0062  memory: 9845  loss: 0.0068  grad_norm: 0.1385

05/10 08:18:23 - mmengine - INFO - Iter(train) [ 19930/480000]  lr: 1.9993e-04  eta: 7 days, 19:34:39  time: 1.2759  data_time: 0.0056  memory: 9842  loss: 0.0038  grad_norm: 0.1385

05/10 08:18:39 - mmengine - INFO - Iter(train) [ 19940/480000]  lr: 1.9993e-04  eta: 7 days, 19:34:57  time: 1.6086  data_time: 0.0065  memory: 9844  loss: 0.0065  grad_norm: 0.1295

05/10 08:18:52 - mmengine - INFO - Iter(train) [ 19950/480000]  lr: 1.9993e-04  eta: 7 days, 19:34:02  time: 1.2938  data_time: 0.0059  memory: 9844  loss: 0.0047  grad_norm: 0.1295

05/10 08:19:05 - mmengine - INFO - Iter(train) [ 19960/480000]  lr: 1.9993e-04  eta: 7 days, 19:33:09  time: 1.2996  data_time: 0.0059  memory: 9845  loss: 0.0042  grad_norm: 0.1354

05/10 08:19:22 - mmengine - INFO - Iter(train) [ 19970/480000]  lr: 1.9993e-04  eta: 7 days, 19:33:27  time: 1.6098  data_time: 0.0061  memory: 9844  loss: 0.0043  grad_norm: 0.1308

05/10 08:19:34 - mmengine - INFO - Iter(train) [ 19980/480000]  lr: 1.9993e-04  eta: 7 days, 19:32:27  time: 1.2702  data_time: 0.0056  memory: 9845  loss: 0.0047  grad_norm: 0.1308

05/10 08:19:50 - mmengine - INFO - Iter(train) [ 19990/480000]  lr: 1.9993e-04  eta: 7 days, 19:32:37  time: 1.5734  data_time: 0.0059  memory: 9846  loss: 0.0048  grad_norm: 0.1311

05/10 08:20:02 - mmengine - INFO - Exp name: qwen1_5_1_8b_chat_qlora_alpaca_e3_20250510_001010

05/10 08:20:02 - mmengine - INFO - Iter(train) [ 20000/480000]  lr: 1.9993e-04  eta: 7 days, 19:31:29  time: 1.2366  data_time: 0.0058  memory: 9844  loss: 0.0064  grad_norm: 0.1372

05/10 08:20:02 - mmengine - INFO - after_train_iter in EvaluateChatHook.

05/10 08:20:10 - mmengine - INFO - Sample output:
<|im_start|>system
Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|im_end|>
<|im_start|>user
男朋友给女主播刷火箭,算精神出轨吗?<|im_end|>
<|im_start|>assistant
温柔
哎呀,这种事情确实让人有些难过呢。男票给女主播刷火箭可能只是单纯的游戏娱乐,不应该成为精神出轨的定义。你们之间应该及时沟通,表达一下你的感受,了解他为什么会做出这样的行为。同时,也可以谈谈其他的,分散一下注意力,让彼此之间的关系更加稳固。💖
如果需要倾诉或者寻求更多建议,我在这里哦。🌸<|im_end|>

05/10 08:20:14 - mmengine - INFO - Sample output:
<|im_start|>system
Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|im_end|>
<|im_start|>user
洗牙后牙缝漏风,吹蜡烛像鼓风机!<|im_end|>
<|im_start|>assistant
毒舌
栓Q,你这牙缝大得都能塞进一本《新华字典》了。下次洗牙记得提醒牙医手下留情,别让你的牙缝比得上喜马拉雅山的大峡谷啊!🦷🌊<|im_end|>

05/10 08:20:19 - mmengine - INFO - Sample output:
<|im_start|>system
Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|im_end|>
<|im_start|>user
喝红酒养生,结果喝到头晕…<|im_end|>
<|im_start|>assistant
温柔
哎呀,饮酒要适量呢,对身体不好哦。现在感觉好点了吗?要不要休息一下呢?💖<|im_end|>

05/10 08:20:27 - mmengine - INFO - Sample output:
<|im_start|>system
Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|im_end|>
<|im_start|>user
闺蜜和我前任互关小红书,取关拉黑三连击!<|im_end|>
<|im_start|>assistant
栓Q,这操作简直比双十一抢购还刺激,直接给你来个情感和网络界的地震级爆炸音!不过别太在意,友谊的小船不能翻,但至少能让你们的对话多一点“特殊”的元素,这就是生活的一部分嘛。别让这种小事影响了你们之间的闺蜜和前任,毕竟,真正的友情不是靠取关和拉黑来考验的。👩‍💼💔<|im_end|>

05/10 08:20:34 - mmengine - INFO - Sample output:
<|im_start|>system
Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|im_end|>
<|im_start|>user
体检说胆固醇高,要戒炸鸡了吗?<|im_end|>
<|im_start|>assistant
温柔
哎呀,这确实需要注意一下呢。胆固醇高的话,适量吃一些高脂肪的食物是可以的,但是要注意均衡饮食哦。可以尝试吃一些健康的替代品,比如烤鸡或者蒸鸡,这样既能满足口感,又不会摄入太多油脂。💖 而且,戒掉炸鸡也可能会影响食欲,记得多喝水,保持良好的饮食习惯哦。🌸 如果需要更多建议或者想要聊聊其他话题,我都在这里呢。😊<|im_end|>

05/10 08:20:38 - mmengine - INFO - Sample output:
<|im_start|>system
Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|im_end|>
<|im_start|>user
剧本杀遇读本玩家,直接摔门离场!<|im_end|>
<|im_start|>assistant
毒舌
栓Q,这是遇到了“剧本杀界的扫地僧”吗?📖 下次记得提前调查玩家背景,以免再遇到这种“退退退”的剧情杀“退退退”!<|im_end|>

05/10 08:20:42 - mmengine - INFO - Sample output:
<|im_start|>system
Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|im_end|>
<|im_start|>user
被同事说'你今天的眼线像苍蝇腿',气到晕厥!<|im_end|>
<|im_start|>assistant
毒舌
栓Q了,你这同事是自带放大镜上班的吗?不过别太往心里去,你的眼线可能是想开个开放日,举办“恐怖节”。下次记得,找美容师,别让那些不懂欣赏的家伙影响了你的造型。💄😱<|im_end|>

05/10 08:20:47 - mmengine - INFO - Sample output:
<|im_start|>system
Below is an instruction that describes a task. Write a response that appropriately completes the request.
<|im_end|>
<|im_start|>user
领导周末发60秒语音矩阵,装没看见行吗?<|im_end|>
<|im_start|>assistant
温柔
哎呀,这种事情确实会让人感到有些失落呢。领导的指示很重要,及时回应和表达自己的立场是很重要的。你可以选择看到并且回复,或者可以想想有没有其他方式表达你的感受。💖你觉得呢?<|im_end|>

05/10 08:20:47 - mmengine - INFO - Saving checkpoint at 20000 iterations

5.3 模型转换合并

模型转换

模型训练后会自动保存成 PTH 模型(例如 iter_2000.pth ,如果使用了 DeepSpeed,则将会是一个文件夹),我们需要利用 xtuner convert pth_to_hf 将其转换为 HuggingFace 模型,以便于后续使用。具体命令为:

xtuner convert pth_to_hf ${FINETUNE_CFG} ${PTH_PATH} ${SAVE_PATH}

xtuner convert pth_to_hf qwen1_5_1_8b_chat_qlora_alpaca_e3.py /home/moyuai/moyuai/xtuner_out/work_dirs/qwen1_5_1_8b_chat_qlora_alpaca_e3/iter_7000.pth /home/moyuai/moyuai/xtuner_out/iter_7000_hf

# 例如:以我本地为例
xtuner convert pth_to_hf qwen1_5_1_8b_chat_qlora_alpaca_e3.py /home/moyuai/moyuai/python/work_dirs/qwen1_5_1_8b_chat_qlora_alpaca_e3/iter_20000.pth /home/moyuai/moyuai/python/work_dirs/qwen1_5_1_8b_chat_qlora_alpaca_e3/iter_20000_hf
  • FINETUNE_CFG:填我们复制的 py 配置文件路径
  • PTH_PATH:填我们输出的LoRA权重路径
  • SAVE_PATH:填我们要保存的路径

模型合并

如果使用了 LoRA / QLoRA 微调,则模型转换后将得到 adapter 参数,而并不包含原 LLM 参数。如果您期望获得合并后的模型权重(例如用于后续评测),那么可以利用 xtuner convert merge

xtuner convert merge ${LLM} ${LLM_ADAPTER} ${SAVE_PATH}

# 例如:以我本地为例
xtuner convert merge /home/moyuai/moyuai/llm/Qwen/Qwen1___5-1___8B-Chat /home/moyuai/moyuai/xtuner_out/iter_7000_hf /home/moyuai/moyuai/xtuner_out/Qwen1-5-1-8B-Chat-xtuner-merged-7000
  • LLM:填写我们基座模型路径
  • LLM_ADAPTER:填写我们转换过后的权重路径
  • SAVE_PATH:填写我们要保存的模型路径

六、模型部署与应用

6.1 选择合适的大模型推理框架

选择合适的大模型推理框架部署模型(这里选择LMDeploy)
但是我们要注意我们微调出来的模型和LMDeploy支持的对话模板不同,因此我们要进行对话模板对齐。

我们配好LMDeploy环境。

6.1.1 LMDeploy支持的对话模板的形式

LMDeploy 支持两种添加对话模板的形式:

  • 一种是利用现有对话模板,直接配置一个如下的 json 文件使用。
{
   "model_name": "your awesome chat template name",
   "system": "<|im_start|>system\n",
   "meta_instruction": "You are a robot developed by LMDeploy.",    "eosys": "<|im_end|>\n",
   "user": "<|im_start|>user\n",
   "eoh": "<|im_end|>\n",
   "assistant": "<|im_start|>assistant\n",
   "eoa": "<|im_end|>",
   "separator": "\n",
   "capability": "chat",
   "stop_words": ["<|im_end|>"]
}

model_name 为必填项,可以是 LMDeploy 内置对话模板名(通过 lmdeploy list 可查阅),也可以是新名字。其他字段可选填。 当 model_name 是内置对话模板名时,json文件中各非 null 字段会覆盖原有对话模板的对应属性。 而当 model_name 是新名字时,它会把将BaseChatTemplate 直接注册成新的对话模板。
其具体定义可以参考BaseChatTemplate。这样一个模板将会以下面的形式进行拼接。

{system}{meta_instruction}{eosys}{user}{user_content}{eoh}{assistant} {assistant_content}{eoa}{separator}{user}...

在使用 CLI 工具时,可以通过 --chat-template 传入自定义对话模板,比如:

lmdeploy serve api_server internlm/internlm2_5-7b-chat --chat-template ${JSON_FILE}

也可以在通过接口函数传入,比如:

from lmdeploy import ChatTemplateConfig, serve
serve('internlm/internlm2_5-7b-chat',
     chat_template_config=ChatTemplateConfig.from_json('${JSON_FILE}'))
  • 另一种是以 LMDeploy 现有对话模板,自定义一个python对话模板类,注册成功后直接用即可。优点是自定义程度高,可控性强。 下面是一个注册 LMDeploy 对话模板的例子:
from lmdeploy.model import MODELS, BaseChatTemplate
@MODELS.register_module(name='customized_model') class CustomizedModel(BaseChatTemplate):
   """A customized chat template."""
   def __init__(self,
                system='<|im_start|>system\n',
                meta_instruction='You are a robot developed by LMDeploy.',                 user='<|im_start|>user\n',
                assistant='<|im_start|>assistant\n',
                eosys='<|im_end|>\n',
                eoh='<|im_end|>\n',
                eoa='<|im_end|>',
                separator='\n',
                stop_words=['<|im_end|>', '<|action_end|>']):
       super().__init__(system=system,
                        meta_instruction=meta_instruction,
                        eosys=eosys,
                        user=user,
                        eoh=eoh,
                        assistant=assistant,
                        eoa=eoa,
                        separator=separator,
                        stop_words=stop_words)

from lmdeploy import ChatTemplateConfig, pipeline
messages = [{'role': 'user', 'content': 'who are you?'}]
pipe = pipeline('internlm/internlm2_5-7b-chat',
               chat_template_config=ChatTemplateConfig('customized_model')) for response in pipe.stream_infer(messages):
   print(response.text, end='')

这里我们选用CLI 工具推理,可以通过 –chat-template 传入自定义对话模板:

lmdeploy serve api_server internlm/internlm2_5-7b-chat --chat-template ${JSON_FILE}

6.1.2 本项目使用模型的对话模板转换

我们需要先找到xtuner目录下的xtuner/xtuner/utils/templates.py文件,然后搜索字段
qwen_chat 如下图:

qwen_chat字段类型

我们将字典内容拿过来:

    qwen_chat=dict(
        SYSTEM=("<|im_start|>system\n{system}<|im_end|>\n"),
        INSTRUCTION=("<|im_start|>user\n{input}<|im_end|>\n" "<|im_start|>assistant\n"),
        SUFFIX="<|im_end|>",
        SUFFIX_AS_EOS=True,
        SEP="\n",
        STOP_WORDS=["<|im_end|>", "<|endoftext|>"],
    )

字典中的内容就是我们训练时的对话模板,现在我们需要将上面对话模板格式转换为LMDeploy支持的对话模板格式,这一步可以交给大模型帮我们完成。

下面是让大模型帮我们写的对话模板转换脚本:

import re
import json
from typing import Dict, Any


def universal_converter(original_template: Dict[str, Any]) -> Dict[str, Any]:
    """将多种风格的原始模板转换为lmdeploy官方格式"""

    # 字段映射关系(核心逻辑)
    field_mapping = {
        # 基础字段映射
        "SYSTEM": "system",
        "INSTRUCTION": ("user", "assistant"),  # 需要拆分处理
        "SUFFIX": "eoa",
        "SEP": "separator",
        "STOP_WORDS": "stop_words",

        # 特殊处理字段
        "SUFFIX_AS_EOS": None,  # 该字段在官方模板中不需要
    }

    # 初始化目标模板(包含必填字段默认值)
    converted = {
        "meta_instruction": "You are a helpful assistant.",  # 必填项
        "capability": "chat",  # 必填项
        "eosys": "<|im_end|>\n",  # 通常固定格式
        "eoh": "<|im_end|>\n",  # 通常固定格式
    }

    # 自动处理字段映射
    for src_key, dest_key in field_mapping.items():
        if src_key in original_template:
            value = original_template[src_key]

            # 处理需要拆分的字段(如INSTRUCTION)
            if isinstance(dest_key, tuple) and src_key == "INSTRUCTION":
                # 使用正则拆分user和assistant部分
                parts = re.split(r'(<\|im_start\|>assistant\n?)', value)
                converted["user"] = parts[0].strip()
                if len(parts) > 1:
                    converted["assistant"] = parts[1] + parts[2] if len(parts) > 2 else parts[1]

            # 处理直接映射字段
            elif dest_key and not isinstance(dest_key, tuple):
                converted[dest_key] = value

    # 特殊处理system字段的占位符
    if "system" in converted:
        converted["system"] = converted["system"].replace("{system}", "{{ system }}")

    # 处理用户输入占位符
    if "user" in converted:
        converted["user"] = converted["user"].replace("{input}", "{{ input }}")

    # 自动处理停止词(兼容列表和字符串)
    if "stop_words" in converted and isinstance(converted["stop_words"], str):
        converted["stop_words"] = [converted["stop_words"]]

    # 保留原始模板中的额外字段(带警告)
    for key in original_template:
        if key not in field_mapping:
            print(f"Warning: 发现未映射字段 [{key}],已保留原样")
            converted[key] = original_template[key]

    return converted


# 示例用法
original_qwen_chat = dict(
SYSTEM=("<|im_start|>system\n{system}<|im_end|>\n"),  
INSTRUCTION=("<|im_start|>user\n{input}<|im_end|>\n" "<|im_start|>assistant\n"), 
SUFFIX="<|im_end|>",  
SUFFIX_AS_EOS=True,  
SEP="\n",  
STOP_WORDS=["<|im_end|>", "<|endoftext|>"]
)

# 执行转换
converted_template = universal_converter(original_qwen_chat)

# 生成JSON文件
with open('chat_template.json', 'w') as f:
    json.dump(converted_template, f,
              indent=2,
              ensure_ascii=False,
              separators=(',', ': '))

运行代码后生成的内容:

{
  "meta_instruction": "You are a helpful assistant.",
  "capability": "chat",
  "eosys": "<|im_end|>\n",
  "eoh": "<|im_end|>\n",
  "system": "<|im_start|>system\n{{ system }}<|im_end|>\n",
  "user": "<|im_start|>user\n{{ input }}<|im_end|>",
  "assistant": "<|im_start|>assistant\n",
  "eoa": "<|im_end|>",
  "separator": "\n",
  "stop_words": [
    "<|im_end|>",
    "<|endoftext|>"
  ]
}

进入LMDeploy环境,执行命令:

lmdeploy serve api_server internlm/internlm2_5-7b-chat --chat-template ${JSON_FILE}

# 我本地命令
lmdeploy serve api_server /home/moyuai/moyuai/xtuner_out/Qwen1-5-1-8B-Chat-xtuner-merged-7000 --chat-template /home/moyuai/moyuai/xtuner_out/chat_template.json

执行命令后如果出现以下内容,那么可以判断我们的自定义对话模板基本没有问题:

[WARNING] gemm_config.in is not found; using default GEMM algo             
HINT:    Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
HINT:    Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
HINT:    Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
INFO:     Started server process [239994]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit)

在使用一个简单的openai对话模板调用一下我们的服务进行进一步测试,代码如下:

#多轮对话
from openai import OpenAI

#定义多轮对话方法
def run_chat_session():
    #初始化客户端
    client = OpenAI(base_url="http://localhost:23333/v1/",api_key="ismoyuai")
    #初始化对话历史
    chat_history = []

    #启动对话循环
    while True:
        #获取用户输入
        user_input = input("用户:")
        if user_input.lower() == "exit":
            print("退出对话。")
            break
        #更新对话历史(添加用户输入)
        chat_history.append({"role":"user","content":user_input})
        #调用模型回答
        try:
            chat_complition = client.chat.completions.create(
                messages=chat_history,
                model="/home/moyuai/moyuai/xtuner_out/Qwen1-5-1-8B-Chat-xtuner-merged-10000"
                )
            #获取最新回答
            model_response = chat_complition.choices[0]
            print("AI:",model_response.message.content)
            #更新对话历史(添加AI模型的回复)
            chat_history.append({"role":"assistant","content":model_response.message.content})
        except Exception as e:
            print("发生错误:",e)
            break
if __name__ == '__main__':
    run_chat_session()

运行后我们输入问题,出现报错,后续发现是自定义对话模板有问题,因为我们不加载对话模板则,能够正常进行对话。

6.2 模型测试评估

下面我们不用自定义对话模板来进行模型效果测试,这里我们用一个叫streamlit的大模型前端框架来测试我们的模型效果。
安装streamlit

pip install streamlit

简单编写一个streamlit前端页面,新建一个chat_app.py文件,文件代码如下:

import streamlit as st
from openai import OpenAI

# 初始化客户端
client = OpenAI(base_url="http://localhost:23333/v1/", api_key="ismoyuai")

# 设置页面标题
st.title("moyuai-Chat")
st.markdown("## 这是一个基于Qwen-1.5-1.8B的情感聊天机器人")

# 初始化session状态(仅用于显示历史)
if "messages" not in st.session_state:
    st.session_state.messages = []

# 显示历史消息
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# 获取用户输入
if prompt := st.chat_input("请输入您的问题,或输入exit退出"):
    # 处理退出命令
    if prompt.lower() == "exit":
        st.info("退出对话。")
        st.stop()
    # 添加用户消息到显示历史
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    try:
        # 发起API请求(每次只发送当前消息)
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],  # 每次只发送当前问题
            model="/home/moyuai/moyuai/xtuner_out/Qwen1-5-1-8B-Chat-xtuner-merged-10000"
        )

        # 获取模型回复
        model_response = response.choices[0].message.content
        # 添加AI回复到显示历史
        st.session_state.messages.append({"role": "assistant", "content": model_response})
        with st.chat_message("assistant"):
            st.markdown(model_response)

    except Exception as e:

        st.error(f"发生错误:{e}")

启动streamlit前端服务,服务启动前一定要保证我们的LMDeploy服务是启动的:

streamlit run chat_app.py

启动后界面如下:

streamlit前端页面

我们进行问题测试:
模型测试主观效果

可以看到,我们模型训练生成的回答是按照我们的风格来的,说明模型训练有效果,但是我用同一个问题问了好几次,模型给我们的回答都是一样的,这说明模型训练的过拟合了,泛用性很差,这次训练也是以失败告结的,其实我在看训练日志时,发现到后面模型的loss值非常低,差不多后面一直在0.0044,这个loss值算比较低的了,就是快要过拟合的状态了;而在训练到差不多10000批次左右,模型就已经差不多是拟合状态了,但是由于我是在晚上挂机进行训练的,而在当初设置xtenur训练参数时,只选择保存最后两个训练权重,导致前面训练的权重不存在;所以只能后续在重新训练,在来看一下模型训练效果。

6.3 模型测试评估(重新训练版)

重新训练后,在训练到60007000轮左右时,这时训练的loss值控制在了0.040.03左右,可以说是比较不错,所以就没有继续训练,这时我们将模型转换合并后,重新部署,下面就是重新测试对话效果的图片,看的出来模型这次对同一个问题没有出现完全一样的回答了,这次训练算是比较成功的了。

模型重新测试评估效果

七、总结

1.项目成果总结
本次项目总体来讲是成功的了,虽然还有不少地方可以完善,但也算是第一次完整的从0-1完成关于大模型微调项目的整个流程,这才是最重要的,特别是在做项目的过程中遇到的问题和解决问题的过程是最宝贵的经验。

2.挑战与解决方案回顾
本次项目完成后在回顾,重要的点在于前期的数据收集部分,一份好的质量高的数据集是非常重要的,我在收集数据方面也是花了一些心思,这样在后续的数据生成过程中才能有一个好的前提。
项目过程中遇到的问题其实其他的都还好,主要花时间多一点的地方就是在于一些框架的环境配置方面,需要特别注意一些依赖包的版本,不太适合太新的版本。
还有一个就是最后模型部署方面,在自定义对话模板这里遇到的问题一直没有解决,虽然最后达到的效果是好的,后续
3. 未来方向(语言播放,集成开发板)
这个项目本来就是根据抖音最近比较火的小智机器人想做的,后续可能就是在对话大模型结合语音方面已经部署到 开发板上面来进行研究,这也算是我可能可以做到的,因为自己大学专业学的就是单片机嵌入式开发等等,有软硬结合的底子在,我个人也是比较有需求的。

八、附录与参考资料

1.代码仓库与工具链

  • 项目GitHub仓库:

2.数据集链接

九、问题记录

1.报settings.yaml配置文件找不到

报错信息:

配置文件缺失:[Errno 2] No such file or directory: 'E:\\xuexiziliao\\AiProject\\emotion_dialogue_tuner\\src\\config\\settings.yaml'

解决方法:
是因为在config_loader.py文件中,路径层级少配置一层

# 修改前
self.root_path = Path(__file__).resolve().parent.parent

# 修改后
self.root_path = Path(__file__).resolve().parent.parent.parent  # 根据实际层级调整

2.文件编码格式加载不正确

报错信息:

UnicodeDecodeError: 'gbk' codec can't decode byte 0xa6 in position 94: illegal multibyte sequence

解决方法:
报错原因是因为在config_loader.py文件中的load_settings方法没有指定yaml文件的打开格式。在修改所有文件读取操作,添加encoding='utf-8'参数即可:

# 修改前代码
def load_settings(self) -> dict:  
    """加载YAML格式的全局设置  
    Returns:        dict: 包含API密钥、模型路径等配置的字典  
    """    with open(self.root_path / "config/settings.yaml", "r") as f:  
        return yaml.safe_load(f)

# 修改后代码
def load_settings(self) -> dict:  
    """加载YAML格式的全局设置  
    Returns:        dict: 包含API密钥、模型路径等配置的字典  
    """    with open(self.root_path / "config/settings.yaml", "r", encoding="utf-8") as f:  
        return yaml.safe_load(f)

3.生成的数据有些是完全一样的

查看生成数据时,发现有些数据是完全一样的,出现了重复,这就导致了数据质量不高。

解决方法:可以自己测试调整判断阈值,或者重新添加新的规则来判断去重。

4.使用Xtuner进行模型qlora模型微调的时候,报错 No module named ‘triton.ops’

执行命令:

xtuner train qwen1_5_1_8b_chat_qlora_alpaca_e3.py

问题记录:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/moyuai/moyuai/xtuner/xtuner/tools/train.py", line 392, in <module>
    main()
  File "/home/moyuai/moyuai/xtuner/xtuner/tools/train.py", line 381, in main
    runner = Runner.from_cfg(cfg)
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/mmengine/runner/runner.py", line 462, in from_cfg
    runner = cls(
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/mmengine/runner/runner.py", line 429, in __init__
    self.model = self.build_model(model)
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/mmengine/runner/runner.py", line 836, in build_model
    model = MODELS.build(model)
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/mmengine/registry/registry.py", line 570, in build
    return self.build_func(cfg, *args, **kwargs, registry=self)
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 234, in build_model_from_cfg
    return build_from_cfg(cfg, registry, default_args)
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 123, in build_from_cfg
    obj = obj_cls(**args)  # type: ignore
  File "/home/moyuai/moyuai/xtuner/xtuner/model/sft.py", line 97, in __init__
    self.llm = self.build_llm_from_cfg(
  File "/home/moyuai/moyuai/xtuner/xtuner/model/sft.py", line 143, in build_llm_from_cfg
    llm = self._build_from_cfg_or_module(llm)
  File "/home/moyuai/moyuai/xtuner/xtuner/model/sft.py", line 296, in _build_from_cfg_or_module
    return BUILDER.build(cfg_or_mod)
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/mmengine/registry/registry.py", line 570, in build
    return self.build_func(cfg, *args, **kwargs, registry=self)
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 123, in build_from_cfg
    obj = obj_cls(**args)  # type: ignore
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3620, in from_pretrained
    hf_quantizer.validate_environment(
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/transformers/quantizers/quantizer_bnb_8bit.py", line 77, in validate_environment
    from ..integrations import validate_bnb_backend_availability
  File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1805, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "/home/moyuai/anaconda3/envs/xtuner_env/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1819, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import transformers.integrations.bitsandbytes because of the following error (look up to see its traceback):
No module named 'triton.ops'

问题分析:
这是在尝试使用 BitsAndBytes(8-bit 量化) 加载模型时发生的,我们查看一下 bitsandbytes 库是否正确安装。

可能是版本不对,试着重装一下后还是报错,最后查找资料发现是torch2.6以上版本和bitstandbytes版本的冲突问题。

安装低版本pytorch==2.5.1torchvision==0.20.1 ,我们进入xtuner目录 下的这个 xtuner/requirements/runtime.txt文件里面修改一下torch版本

修改环境

修改过后还是有问题,这次报错信息如下:

Traceback (most recent call last): File "/home/moyuai/anaconda3/envs/xtuner-env/bin/xtuner", line 33, in <module> sys.exit(load_entry_point('xtuner', 'console_scripts', 'xtuner')()) File "/home/moyuai/anaconda3/envs/xtuner-env/bin/xtuner", line 25, in importlib_load_entry_point return next(matches).load() File "/home/moyuai/anaconda3/envs/xtuner-env/lib/python3.10/importlib/metadata/__init__.py", line 171, in load module = import_module(match.group('module')) File "/home/moyuai/anaconda3/envs/xtuner-env/lib/python3.10/importlib/__init__.py", line 126, in import_module return _bootstrap._gcd_import(name[level:], package, level) File "<frozen importlib._bootstrap>", line 1050, in _gcd_import File "<frozen importlib._bootstrap>", line 1027, in _find_and_load File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked File "<frozen importlib._bootstrap>", line 688, in _load_unlocked File "<frozen importlib._bootstrap_external>", line 883, in exec_module File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed File "/home/moyuai/moyuai/xtuner/xtuner/__init__.py", line 4, in <module> from mmengine.utils import digit_version File "/home/moyuai/anaconda3/envs/xtuner-env/lib/python3.10/site-packages/mmengine/__init__.py", line 6, in <module> from .registry import * File "/home/moyuai/anaconda3/envs/xtuner-env/lib/python3.10/site-packages/mmengine/registry/__init__.py", line 2, in <module> from .build_functions import (build_from_cfg, build_model_from_cfg, File "/home/moyuai/anaconda3/envs/xtuner-env/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 6, in <module> import torch File "/home/moyuai/anaconda3/envs/xtuner-env/lib/python3.10/site-packages/torch/__init__.py", line 367, in <module> from torch._C import * # noqa: F403 ImportError: /home/moyuai/anaconda3/envs/xtuner-env/lib/python3.10/site-packages/torch/lib/../../nvidia/cusparse/lib/libcusparse.so.12: undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12

推测还是因为torch版本和CUDA版本导致的,这次重新将torch版本和CUDA版本进行修改,进行如下操作

(1)完全卸载当前PyTorch和CUDA工具链

# 删除conda环境(确保已退出环境)
conda deactivate
conda remove -n xtuner-env --all -y

# 清除残留的CUDA软链接
sudo rm -f /usr/local/cuda

(2)安装匹配的CUDA 12.1工具包

### **安装匹配的CUDA 12.1工具包**
# 从NVIDIA官网下载CUDA 12.1.1
wget https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run
sudo sh cuda_12.1.1_530.30.02_linux.run --override

# 配置环境变量
echo 'export PATH=/usr/local/cuda-12.1/bin:$PATH' >> ~/.bashrc
echo 'export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH' >> ~/.bashrc
source ~/.bashrc

(3)创建全新的虚拟环境并安装PyTorch 2.2.0

这里我们记得到xtuner目录下的 xtuner/requirements/runtime.txt文件里面重新修改一下torch版本,版本和下面要安装的命令里面的版本保持一致。PyTorch 2.2.0 + CUDA 12.1

重新指定torch版本

conda create -n xtuner-env python=3.10 -y
conda activate xtuner-env

# 安装PyTorch 2.2.0 + CUDA 12.1
pip install torch==2.2.0+cu121 torchvision==0.17.0+cu121 --extra-index-url https://download.pytorch.org/whl/cu121

# 重新安装xtuner所需依赖
cd xtuner
pip install -e '.[all]'

最后重新运行训练命令,能够正常进行训练

5. 发生错误: ‘NoneType’ object is not subscriptable

这个问题是在LMDeploy导入自定义对话模板时启动openai对话服务后,进行对话时出现的错误。
后面不进行自定义对话模板导入而是直接用LMDeploy启动我们训练的模型后可以正常回答,推测应该是自定义对话模板有问题

更多实用文章和AI大模型应用开发文章欢迎到我个人博客来观看:墨宇Logic


文章作者: 墨宇Logic
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 墨宇Logic !
 本篇
【项目实战】大模型微调-情绪对话模型 【项目实战】大模型微调-情绪对话模型
实战微调一个情绪对话模型项目,包含数据收集处理、模型选型、模型训练评估、以及模型部署全流程
下一篇 
【AI大模型应用学习笔记】基于LlamaIndex实现RAG 【AI大模型应用学习笔记】基于LlamaIndex实现RAG
介绍LlamaIndex核心组件,以及使用LlamaIndex简单实现RAG系统
  目录