Skip to content

模型训练的梳理

约 4418 字大约 15 分钟

机器学习模型训练AI

2025-07-11

一、模型训练的完整过程

机器学习模型的训练通常遵循固定流程,每个环节的质量都会直接影响最终模型的性能

以下是训练的核心步骤:

1. 数据准备与预处理

  • 数据收集:获取与任务相关的原始数据(如文本、图像、数值等),需保证数据的相关性和代表性。
  • 数据清洗:处理缺失值(填充、删除)、异常值(修正、剔除)、重复值(去重),避免噪声影响模型。
  • 特征工程
    • 特征选择:筛选与目标相关的特征(如使用方差分析、互信息),减少冗余。
    • 特征转换:标准化(如Z-score)、归一化(如Min-Max)、离散化(如将连续值分箱),使模型更易学习。
    • 特征构建:通过组合或转换现有特征生成新特征(如多项式特征、文本的词向量)。
  • 数据划分:将数据集分为训练集(70%-80%,用于模型学习)、验证集(10%-15%,用于调优超参数)、测试集(10%-15%,用于评估最终性能)。

2. 模型选择与初始化

  • 模型选择:根据任务类型(分类、回归、聚类等)和数据特点选择合适模型。例如:
    • 线性关系数据: 线性回归/逻辑回归;
    • 非线性复杂关系: 决策树、神经网络;
    • 高维稀疏数据: 支持向量机(SVM)、随机森林。
  • 参数初始化:为模型的可学习参数(如权重、偏置)赋予初始值。常见策略包括:
    • 随机初始化:如正态分布、均匀分布(适用于神经网络);
    • 常数初始化:如全零初始化(需避免对称权重问题);
    • 预训练初始化:使用预训练模型的参数(如迁移学习中的BERT、ResNet)。

3. 训练循环(核心过程)

模型训练通过迭代优化参数,使预测结果逐渐接近真实标签,核心流程如下:

  1. 前向传播:将训练数据输入模型,计算预测值(如神经网络的输出、决策树的分类结果)。
  2. 计算损失:通过损失函数对比预测值与真实标签,得到损失值(衡量预测误差的指标)。
  3. 反向传播:基于损失值,使用链式法则计算参数的梯度(反映参数对损失的影响程度)。
  4. 参数更新:通过优化器根据梯度调整模型参数,降低损失(如随机梯度下降法)。
  5. 迭代终止:当达到最大迭代次数、损失收敛(变化小于阈值)或验证集性能下降时,停止训练。

4. 模型评估与调优

  • 评估指标:根据任务类型选择指标(如分类任务用准确率、F1分数;回归任务用MSE、R²)。
  • 调优方向:若模型欠拟合(训练/验证误差均高),需增加模型复杂度;若过拟合(训练误差低但验证误差高),需加入正则化或简化模型。
  • 模型保存:保存训练好的模型参数(如.pth.h5文件),用于后续预测。

二、常见超参数及调优策略

超参数是训练前手动设置的参数(非模型学习所得),直接影响模型性能

以下是不同模型的核心超参数及调优方法:

1. 通用超参数

超参数作用常见范围调优策略
学习率(Learning Rate)控制参数更新幅度1e-5 ~ 1e-1初期用较大值快速收敛,后期减小(如学习率衰减)
迭代次数(Epoch)训练数据的遍历次数10 ~ 1000基于验证集性能,避免过拟合
批次大小(Batch Size)每次迭代输入的样本数16 ~ 256小批次收敛慢但泛化好,大批量需匹配显存

2. 模型专属超参数

  • 决策树
    • 最大深度(max_depth):限制树的复杂度,防止过拟合(范围:3~30)。
    • 叶子节点最小样本数(min_samples_leaf):避免生成过小叶子(范围:1~10)。
  • 随机森林
    • 树的数量(n_estimators):增加树数量可提升性能,但计算成本增加(范围:100~1000)。
    • 特征采样比例(max_features):控制每棵树使用的特征数(分类用平方根比例,回归用一半比例)。
  • 支持向量机(SVM)
    • 正则化参数(C):平衡分类间隔与错误率(小C代表强正则化,范围:0.1~100)。
    • 核函数参数(gamma):控制核函数的影响范围(线性核无需设置,RBF核范围:1e-3~10)。
  • 神经网络
    • 隐藏层数量/神经元数:增加层数可提升拟合能力(如2-5层,每层16-512个神经元)。
    • dropout率:随机丢弃神经元比例(防止过拟合,范围:0.2~0.5)。

3. 超参数调优策略

  • 网格搜索(Grid Search):穷举预设参数组合(如学习率取[0.01, 0.1],批次大小取[32, 64]),适合参数范围小的场景。
  • 随机搜索(Random Search):在参数范围内随机采样组合,效率高于网格搜索,适合大范围参数。
  • 贝叶斯优化:基于历史参数性能,概率性选择下次参数,适合高维参数空间(如Optuna工具)。
  • 经验调优:先粗调范围(如学习率1e-5~1e-1),再聚焦最优区间微调。

三、常用优化器(Optimizers)

优化器是用于更新模型参数的算法,核心目标是快速找到损失函数的最小值

以下是主流优化器的特点及适用场景:

1. 梯度下降类优化器

  • 随机梯度下降(SGD)

    • 参数更新方式:每次用单个样本的梯度调整参数
    • 特点:每次用单个样本更新,收敛过程波动大但计算速度快。
    • 适用场景:数据量大、简单模型(如线性回归)。
    • 改进:加入动量(Momentum),模拟物理惯性,加速收敛并抑制震荡
    • 动量更新逻辑:保留上一次的更新方向,结合当前梯度调整参数(动量系数通常取0.9)
  • 批量梯度下降(BGD)

    • 特点:用全部样本计算梯度后更新参数,收敛稳定但计算成本高,适用于小数据集。
  • 小批量梯度下降(Mini-batch GD)

    • 特点:平衡SGD和BGD,每次用一批样本(如32/64个)计算梯度并更新,是目前最常用的基础优化器。

2. 自适应学习率优化器

  • Adam(Adaptive Moment Estimation)

    • 核心逻辑:结合动量机制和自适应学习率,同时跟踪梯度的一阶矩(均值)和二阶矩(方差)
    • 特点:收敛快且稳定,适用场景广泛(如神经网络、深度学习)。
  • RMSprop

    • 特点:仅通过梯度的二阶矩自适应调整学习率,适合处理非平稳目标(如递归神经网络RNN)。
  • Adagrad

    • 特点:学习率随参数更新次数增加而自动减小,适合稀疏数据(如自然语言处理),但可能过早停止更新。

3. 优化器选择建议

  • 新手首选Adam,兼顾效率与稳定性;
  • 若模型收敛慢,尝试SGD+Momentum(需精细调整学习率);
  • 稀疏数据场景(如文本)可尝试AdagradRMSprop

四、常用损失函数(Loss Functions)

损失函数(Loss Function)量化模型预测与真实标签的差异,是参数更新的“指挥棒”。不同任务需匹配不同损失函数:

1. 分类任务损失函数

  • 交叉熵损失(Cross-Entropy Loss)

    • 二分类:通过预测概率与真实标签的对数差异计算损失
    • 多分类:对每个类别的预测概率与真实标签的对数差异求和
    • 特点:对错误预测的惩罚更显著,适用于逻辑回归、神经网络分类任务。
  • 铰链损失(Hinge Loss)

    • 核心逻辑:专注于分类边界,对正确分类且置信度高的样本惩罚小
    • 特点:适用于支持向量机(SVM),强调最大化分类间隔。

2. 回归任务损失函数

  • 均方误差(MSE)

    • 计算方式:预测值与真实值之差的平方的平均值
    • 特点:对异常值敏感(平方会放大误差),适用于误差呈正态分布的场景(如房价预测)。
  • 平均绝对误差(MAE)

    • 计算方式:预测值与真实值之差的绝对值的平均值
    • 特点:对异常值更稳健,适用于误差分布不对称的场景(如风速预测)。
  • Huber损失

    • 核心逻辑:误差较小时用平方误差(类似MSE),误差较大时用线性误差(类似MAE)
    • 特点:结合MSE和MAE的优势,适用于含少量异常值的回归任务。

3. 其他任务损失函数

  • Dice损失:常用于图像分割,解决类别不平衡问题;
  • 对比损失:用于度量学习,拉近相似样本距离,拉远异类样本距离。

五、模型训练的调节机制

训练过程中需通过调节机制避免过拟合、加速收敛,常见方法如下:

1. 正则化(防止过拟合)

  • L1正则化:在损失中加入参数绝对值之和,使部分参数为0,实现特征选择。
  • L2正则化(权重衰减):加入参数平方和,使参数值整体缩小,抑制过拟合。
  • Dropout:训练时随机让部分神经元失效(如50%概率),强制模型学习冗余特征,适用于神经网络。
  • 早停机制(Early Stopping):当验证集损失连续多轮上升时,停止训练,避免过拟合。

2. 学习率调度(Learning Rate Scheduling)

  • 分段衰减(Step Decay):每训练一定轮数,学习率乘以衰减因子(如乘以0.1)。
  • 指数衰减:学习率随迭代次数按指数规律减小,适合快速收敛场景。
  • 余弦退火:学习率随迭代次数按余弦曲线变化,先降后升,帮助跳出局部最优。

3. 数据增强(Data Augmentation)

通过对训练数据进行随机变换(如图像的旋转、裁剪,文本的同义词替换),增加数据多样性,抑制过拟合。适用于数据量小的场景(如深度学习图像任务)。

4. 批归一化(Batch Normalization)

对每批数据进行标准化(调整为均值0、方差1),稳定网络训练时的输入分布,加速收敛并允许更高学习率。广泛用于卷积神经网络(CNN)。

六、模型评估指标

模型评估是机器学习流程中至关重要的环节,用于衡量模型的性能表现,指导模型优化方向

不同类型的任务(分类、回归、聚类等)有不同的评估指标,以下是详细介绍:

分类任务评估指标

分类任务的目标是将样本划分到预定义的类别中,评估指标主要围绕预测类别与真实类别的匹配程度展开。

  1. 混淆矩阵(Confusion Matrix)

    是分类任务的基础评估工具,通过四个核心指标描述模型表现:

  • 真正例(True Positive, TP):实际为正类,模型预测为正类。
  • 假正例(False Positive, FP):实际为负类,模型预测为正类(误判)。
  • 真负例(True Negative, TN):实际为负类,模型预测为负类。
  • 假负例(False Negative, FN):实际为正类,模型预测为负类(漏判)。
  1. 准确率(Accuracy)
  • 定义:所有预测正确的样本占总样本的比例。
  • 适用场景:适用于样本类别分布均衡的情况,不适用于不平衡数据集(如疾病检测中,少数正例的漏判影响更大)。
  1. 精确率(Precision)
  • 定义:预测为正类的样本中,实际为正类的比例(关注“预测为正的可靠性”)。
  • 适用场景:注重“减少误判”的场景,如垃圾邮件识别(避免将正常邮件误判为垃圾邮件)。
  1. 召回率(Recall/Sensitivity/True Positive Rate, TPR)
  • 定义:实际为正类的样本中,被模型正确预测为正类的比例(关注“正例的覆盖能力”)。
  • 适用场景:注重“减少漏判”的场景,如癌症检测(尽可能找出所有患者,允许少量误判)。
  1. F1分数(F1-Score)
  • 定义:精确率和召回率的调和平均数,综合两者的表现,避免单一指标的片面性。
  • 适用场景:适用于类别不平衡或需要平衡精确率和召回率的场景(如文本分类)。
  1. ROC曲线与AUC(Area Under ROC Curve)
  • ROC曲线:以假正例率(FPR)为横轴,真正例率(TPR,即召回率)为纵轴,描述不同阈值下模型的区分能力。
  • AUC:ROC曲线下的面积,取值范围为[0,1]。AUC越接近1,模型区分正负类的能力越强;AUC=0.5时,模型性能与随机猜测相当。
  • 适用场景:尤其适用于不平衡数据集,对阈值不敏感,常用于二分类模型评估(如信用风险评估)。
  1. 宏平均(Macro-average)与微平均(Micro-average)
  • 用于多分类任务,综合多个类别的评估指标:
    • 宏平均:先计算每个类别的指标(如精确率),再取平均值,平等对待所有类别(适用于类别不平衡且关注小类别)。
    • 微平均:将所有类别的TP、FP、TN、FN汇总后计算指标,受样本量多的类别影响更大(适用于类别分布较均衡)。

回归任务评估指标

回归任务的目标是预测连续数值,评估指标主要衡量预测值与真实值的差异程度。

  1. 均方误差(Mean Squared Error, MSE)
  • 定义:预测值与真实值之差的平方的平均值。
  • 特点:对异常值敏感(平方会放大误差),单位是目标值单位的平方,常用于模型训练的损失函数。
  1. 均方根误差(Root Mean Squared Error, RMSE)
  • 定义:MSE的平方根,将误差转换为与目标值相同的单位。
  • 特点:同样对异常值敏感,更直观反映误差大小(如预测房价时,RMSE单位为“元”)。
  1. 平均绝对误差(Mean Absolute Error, MAE)
  • 定义:预测值与真实值之差的绝对值的平均值。
  • 特点:对异常值不敏感,适用于数据中存在较多极端值的场景(如收入预测)。
  1. 决定系数(Coefficient of Determination, R²)
  • 定义:衡量模型对数据变异的解释能力,取值范围为(-∞, 1]。
  • 特点:R²越接近1,模型拟合效果越好;R²=0表示模型效果等同于均值预测;R²<0表示模型效果差于均值预测。

聚类任务评估指标

聚类任务无真实标签时,评估目标是衡量 “簇内相似度高、簇间差异大” 的程度。以下是几种常用的无监督聚类评估指标,从不同角度量化聚类结果的质量:

  1. 轮廓系数(Silhouette Coefficient)
  • 核心思想:同时衡量样本与自身簇内其他样本的相似度(簇内紧密度),以及与最近邻簇中样本的相似度(簇间分离度),综合两者评估聚类合理性。
  • 取值范围:([-1, 1])。
    • 接近1:样本聚类合理,簇内紧密且簇间分离好。
    • 接近0:样本处于两个簇的边界,聚类模糊。
    • 接近-1:样本可能被分到错误的簇,聚类效果差。
  • 优点:无需真实标签,直观反映聚类的紧密度和分离度。
  • 缺点:对噪声和离群点敏感,在簇形状不规则(如非凸簇)时表现不佳。
  1. Calinski-Harabasz指数(CH指数)
  • 核心思想:通过簇间离散度与簇内离散度的比值评估聚类质量,比值越大说明聚类越好。
  • 取值特点:值越大,说明簇间差异越大、簇内越集中,聚类效果越好。
  • 优点:计算效率高,对凸形簇(如球形簇)效果较好。
  • 缺点:对非凸簇或大小差异较大的簇不敏感,可能偏向于簇数量较多的聚类结果。
  1. Davies-Bouldin指数(DB指数)
  • 核心思想:衡量每个簇与最相似的其他簇之间的“平均相似度”,相似度越低说明聚类越好。
  • 取值特点:值越小,说明簇内越紧密且簇间越分离,聚类效果越好。
  • 优点:对簇的数量不敏感,计算简单。
  • 缺点:依赖于簇中心的定义,对非球形簇的适应性较差。
  1. Dunn指数(Dunn Index)
  • 核心思想:通过“最小簇间距离”与“最大簇内直径”的比值评估聚类质量,比值越大说明聚类越好。
  • 取值特点:值越大,说明簇间距离越大、簇内样本越集中,聚类效果越好。
  • 优点:对紧凑且分离良好的簇敏感,能反映簇的边界清晰度。
  • 缺点:计算复杂度高(需遍历所有样本对),对高维数据和噪声敏感。