nn_trick
主要记录一些深度网络升点的技巧。参考资料
升点
AI performance = data(70%) + model(20%) + trick(10%),数据是对AI性能影响最大的。
数据处理
-
数据增强 (Data Augmentation)
(1)NLP:回译,词性替换
(2)CV:
resize、 crop、flip、ratate、blur、HSV变化、affine(仿射)、perspective(透视)、Mixup、cutout、cutmix、Random Erasing(随机擦除)、Mosaic(马赛克)、CopyPaste、GANs domain transfer等 -
pseudo label / meta pseudo label (伪标签,半监督学习,比赛常用)
(1)pseudo label:伪标签是一种半监督学习方法,旨在利用有限的标注数据和大量的未标注数据来提升模型的性能。
其核心思想是通过一个初步训练的模型对未标注数据进行预测,并将这些预测结果作为“伪标签”,然后将这些伪标签与标注数据一起用于进一步训练模型。
(2)meta pseudo label是一种改进的伪标签方法,旨在解决传统伪标签方法中存在的确认偏差问题。它通过引入一个动态更新的教师模型,根据学生模型在标注数据上的表现来调整伪标签的生成过程。
学生模型使用教师模型生成的伪标签数据进行训练,教师模型根据学生模型在标注数据上的表现进行更新,以生成更准确的伪标签。 -
噪声数据删除
(1)最大熵删除法:构建一个模型来使得信息熵最大,从而区分噪声和真实数据。
(2)cleanlab -
错误标注数据修改:交叉验证训练多个模型,取模型预测结果一致且prob比threshold大的数据(或者topN)。多个模型可以采用不同的seed,不同的训练集测试机,或者不同的模型结果(bert与textcnn等),找出覆盖部分模型预测与标柱数据不一致的标注错误数据进行修改。
模型
1… 模型选择
先行业数据预训练,使用行业数据进行预训练是最优选择。
再领域再训练,在已经有一定通用预训练基础的模型上,使用特定领域的数据进行进一步训练。
最后考虑使用公开的模型进行 finetune(微调)
tricks
-
尝试模型初始化方法,不同的分布,分布参数。
当你的模型从0开始训练时,你需要给一个初始的网络权重。此时如何选择初始权重对网络有不小的影响。下面是几种常见的初始化方法,排名越前效果越好。arxiv
(1) LSUV init
(2) Kaming init
(3) pytorch default
(4) random init -
不同的预训练模型
举例:ImageNet 1K 和 ImageNet 21K ,两者的分类数不同,使得模型对图片细粒度信息的关注不同,后者效果更好。 -
warmup cosine lr scheduler方法(学习率调整)
学习率:训练过程中参数更新的步长大小,即朝着最优方向前进的步子大小
这是一种学习率优化算法,大致思路是:先预热(逐渐增加),然后余弦退火(像余弦函数一样周期变化,衰减)。这种方法对大模型效果很好。 -
对抗训练提升鲁棒性
常用方法:对抗权重扰动(AWP) 实现代码 -
随机权重平均(SWA)
通过对训练过程中的模型权重进行Avg融合,提升模型鲁棒性,PyTorch有官方实现。 -
test time augmentation(测试时增强)
在测试时使用数据增强的方法处理输入数据,然后对多个结果进行融合(平均、投票等方法) -
结构重参数化
在训练阶段,复杂的网络结构能够捕捉到更多的特征和信息,从而实现更高的性能。然而,在推理阶段,复杂的结构会导致计算资源的浪费和推理速度的下降。结构重参数化通过将训练阶段的复杂结构转换为更简单的等效结构,从而在保持性能的同时提高推理效率。
(1)实现方法1:卷积层与批归一化层的合并。在训练阶段,卷积层后通常会接一个批归一化(Batch Normalization, BN)层。在推理阶段,可以将这两个层合并为一个等效的卷积层,减少计算复杂度。
(2)实现方法2:多分支结构的合并 -
Gradient Checkpointing(梯度检查点)
通过减少显存占用来训练更大规模的模型,同时允许使用更大的批量大小以加速收敛。
在深度神经网络的训练中,反向传播需要存储每一层的激活值来计算梯度。对于深层网络,这些激活值会占用大量显存。梯度检查点技术通过在前向传播时丢弃部分激活值,并在反向传播时重新计算这些值,从而显著减少显存占用。
训练时间增长为代价,来训练更大的模型。 -
Random Seed
深度学习中有许多随机,设定随机种子来使得可重复。
(1)数据集的随机划分(训练集、验证集、测试集)。
(2)数据增强(如随机裁剪、随机翻转)。
(3)权重的随机初始化。
(4)批量归一化(Batch Normalization)中的随机抽样。
(5)Dropout 中的随机丢弃。 -
其它方法
蒸馏,