精度问题主要来源及现象分析

训练一个大型,受数据集、模型结构、并行策略、超参等影响,模型训练过程中可能出现NaN、溢出、Loss发散等情况,这时需要对模型进行精度调试。在昇腾处理器上训练模型时,一般而言Loss总体呈现下降收敛趋势,即使出现偶尔的尖刺现象也可以通过跳过数据集、断点续训等方式规避。最终,当使用训练后得到的权重,采用常规数据集评估模型分数是否符合社区实践评分预期,即可视为精度调试已完成。

本文总结了模型训练过程中常见的问题及其调试方法,力求帮助用户将问题消除在训练开始之前,以及缩短模型精度问题定位的时间。

数据集

数据是增强模型能力的决定性因素之一。模型对数据集的质量敏感,无效数据、低质量数据会让模型学习到错误的模式,甚至影响模型收敛。重复数据会加剧语言模型生成内容重复的问题,并且导致模型过拟合。从指标上看,数据集问题可能会导致Grad Norm产生尖刺,意味着Grad Norm异常增大,导致训练不稳定。

  • 数据异常触发的训练不稳定:引起训练异常的语料包括字节码、unicode乱码、纯数字文本等。这些语料会无报警地被主流的subword tokenizer分割成token,并转成id序列正常进入训练迭代步骤,但由于这些语料本身没有任何语法和语义规律,根据它们来更新参数可能会影响模型的正常学习过程。当异常语料数量显著时,会导致模型无法正常收敛。
  • 语料混合比例与类型的重要性:不同语言不同语料类型(代码、文本)的混合比例和语料规模对模型训练精度也有重要影响。需要调节不同来源的语料的混合比例,不能直接使用原始语料合并训练。均衡的语料比例有助于提高模型的泛化能力,特定类型的语料可以提升模型特定方面的能力。
  • 数据规模与模型参数规模的匹配:参数规模并非越大越好,Hoffmann等人在Chinchilla系列模型中的研究表明,在给定算力的情况下,当语料的token数与模型的参数数量相当时,模型的表现相对更好。因此,不能盲目追求大语料,也应重视控制规模、提高质量和充分训练。

超参数配置

超参数对模型训练收敛有关键性影响,这里说的超参数主要包括:优化器选择、学习率设定、梯度裁剪阈值、Loss scale等。

  • 优化器与学习率:优化器的选择直接影响到模型训练过程中的更新策略,而学习率作为核心超参数之一,若设置不当,可能导致训练过程中出现诸如全局梯度范数(GNorm)突变、Loss曲线无法收敛、上升甚至剧烈波动等问题。因此,在训练大型模型时,需细致地搜索合适的学习率以保证收敛性能。

  • 梯度裁剪阈值:梯度裁剪对模型训练稳定性也至关重要,使用过高的梯度裁剪阈值可能会导致GNorm如图1中蓝线那样持续保持在高位。梯度裁剪在反向传播完成后能够降低某些张量的梯度值,这样可以阻止梯度爆炸,并使训练过程更加稳定。裁剪范围值可以通过试验配置,也可以使用文献中的常用值,也可以通过实验观察范围然后选择一个合理值。网络中的所有层通常都使用相同的梯度裁剪配置,一般来说输出层相比隐藏层允许更大范围的梯度。

    图 1 过高的梯度裁剪阈值

  • Loss Scale动态调整:在训练期间,Loss Scale作为一个动态调节的超参数,其大小变化同样反映着训练的稳定性状态。若Loss Scale长期低于1,这意味着有一些梯度值大到很容易上溢,同时存在的那些小梯度值会下溢为0,这些情况都使得训练将持续处于不稳定状态。

  • Batch尺寸的影响:Batch尺寸是另一个显著影响训练效果的超参数,它在满足内存需求与提升训练效率间寻求平衡。在模型的分布式训练场景中,用户倾向于选择较大的Batch尺寸以缩短训练时长,然而过大的Batch尺寸也可能导致Loss曲线呈现上升趋势。因此,在调整Batch尺寸时,需综合考虑其对训练效率和收敛性的影响,实现最优配置。

浮点数计算

浮点数计算的精度问题根源于计算机中使用特定格式的二进制数表示实数,而二进制只能表示有限数目的离散实数,对于某些十进制小数,只能近似表示。浮点数计算的精度问题只能消减而不能根除,所以理解不同浮点类型的差异对模型训练至关重要。

在CPU计算场景中,经常使用双精度浮点类型double,也称为FP64或者float64。但是在机器学习和AI加速器计算场景下,由于参数规模巨大,很难也没有必要使用双精度数来存储和计算。单精度浮点数和半精度浮点数为模型训练中常用的数据类型。

单精度浮点数

FP32或者float32是一种IEEE 754标准定义的浮点格式,它使用32个比特位表示一个浮点数,这些比特位包括一比特正负符号位,8个比特的指数位和23个比特的小数位。比特位结构布局如下:

当指数位不为0时,这32比特表示的浮点数V称为规格数,它由下列公式定义:

当exponent为0时,它表达的浮点数称为非规格数,由下列公式定义:

NPU或主流AI处理器在fast模式下都只能对规格数进行运算,非规格数将被转换为零,并且除法和平方根运算不会被计算到最接近真实值的浮点数值。

当exponent比特位全1时,表达的浮点数有特殊含义。如果此时fraction为0,则V为正负无穷大,如果fraction不为0,则V为NaN。

FP32格式能表示的绝对值最小规范浮点数为2-126,约等于1.18e-38。绝对值最大规范浮点数为2127 * [1 +(223 - 1)/223],约等于3.4e38。在1附近,两个相邻浮点数差距为2-23,约为1.2e-7。所有位置上相邻浮点数的相对差距大约为1.2e-7。可见FP32的浮点精度很高。

半精度浮点数

由于模型参数规模巨大,对AI加速器设备内存要求很大,所以使用16位半精度格式比32位单精度更有优势。另外半精度运算速度更快,精度对机器学习模型往往也能够满足要求。和主流AI处理器一样,最新的昇腾AI加速器也支持两种半精度浮点格式即FP16和BF16。

  • FP16格式

    FP16或者float16是一种IEEE 754标准定义的半精度浮点格式,它使用16个比特位表示一个浮点数,这些比特位包括1个比特的正负符号位(sign),5个比特的指数位(exponent)和10个比特的小数位(fraction)。比特位结构布局如下:

    和FP32一样,如果指数位非0,则能表示一个规格浮点数V,公式如下:

    FP16非规格数这里不再赘述。

    在1附近,相邻两个浮点数的间隔是2-10。任意位置的两个相邻浮点数的相对间隔也是2-10,大约是千分之一。FP16的精度误差带来的严重问题就是加法的大数吃小数。例如:

    x = torch.tensor(1, dtype=torch.float16)
    y = torch.tensor(0.0001, dtype=torch.float16)
    

    x+y仍然是1.0,因为1.0001 FP16没法表示,1后面FP16能表达的最小数为1.0009765625。

    FP16能表达的最大绝对值规范浮点数为215 * (1+1023/1024)=65504,最小绝对值规格浮点数为2-14,约为6.1e-5。这意味着超出65504和小于6.1e-5的数FP16都没法表示,而这些数在模型训练中都很常见。为了避免下溢,FP16需要结合Loss scale才能用于模型训练。

  • BF16格式

    为了解决FP16表达范围偏小的问题,谷歌大脑研究组提出了bfloat16浮点格式,或者叫BF16。BF16的比特位结构布局如下:

    BF16相对于FP16,增大指数位宽到8(与FP32一样),将小数位宽减小到7。这样可以增大浮点表达范围,但同时牺牲了表达精度。当exponent不为0时,BF16浮点值公式如下:

    在1.0附近,BF16可表示的相邻浮点数的间隔为1/128。任意位置的相邻浮点数相对间隔是1/128,可见BF16的精度是很低的,只是FP16的十分之一精度。BF16最大可表示的规范浮点数为2127 * (1+127/128),约为3.4e38, 绝对值最小的规范浮点数为2-126,约为1.176e-38。

    BF16的精度在大部分模型训练场景下仍然是足够的,但要注意某些特定运算,例如涉及累加,需要用更高精度的FP32完成。

综上所述,选择合适的浮点数精度对于优化模型训练过程至关重要,尤其是在处理因浮点精度引起的收敛问题时,需要结合文献研究成果和实践经验,并通过大量的对比实验进行调试和优化。

混合精度计算

尽管低精度计算能够提供非常显著的训练速度提升,但低精度运算相比FP32也会引起数值错误和数值稳定性问题。所以使用低精度运算来加速训练的时候,理解张量运算核心的数值行为也非常重要。

FP16的取值范围为±65504,能表示的绝对值最小的规格数为2-14和-2-14,即为+/-0.000061035,相邻两个浮点数的相对差距约为千分之一。举例来说,比1大的第一个浮点数为1+1/1024,这意味着没法表示1.0001。使用FP16相对FP32会有更大精度误差。下面的代码在主流AI处理器上分别用FP32和FP16计算一个10000维随机张量(中位数0,方差1的正态分布)均值的相对误差。用FP16计算时,平均相对误差在0.1%,最大误差在23%。在模型训练场景下,因为很多张量都很小,且维度高,所以需留心FP16的精度问题。同样的代码如果使用BF16计算,相对误差均值在1%,最大误差接近4800%。因为累加运算有不确定性,程序运行两次会有不同的结果。

python non-distributed.py --dtype fp16 --device npu
python non-distributed.py --dtype bf16 --device npu
import torch
import argparse
parser = argparse.ArgumentParser(
                    prog='test_mean',
                    description='check the effect to final result under different floating point precision')
parser.add_argument("-t", "--dtype", default="fp16",help="precision selection")
parser.add_argument("-d", "--device", default="cuda", help="accelerator selection")
args = parser.parse_args()
# accelerator selection
if args.device == 'cuda':
    device = torch.device('cuda:0')
elif args.device == 'npu':
    import torch_npu
    device = torch.device('npu:0')
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
fn = torch.mean

def test(fn):
    N, batch = 10000, 1000
    # calculations based on FP16 or BF16
    x = torch.randn((batch, N))
    x_16 = x.to(device, dtype=dtype)
    x_mean_16 = fn(x_16, dim=-1, dtype=dtype)
    
    # calculations based on FP32
    x_fp32 = x.to(device)
    x_mean_fp32 = fn(x_fp32, dim=-1, dtype=torch.float32)

    # error analysis 
    rel_diff = abs((x_mean_16.float() - x_mean_fp32) / x_mean_fp32)
    return rel_diff
rel_diff= test(fn=fn)
max_rel_diff = rel_diff.max() # maximum calculation error
avg_rel_diff = rel_diff.mean() # average calculation error
print("average relative difference is: ", avg_rel_diff)
print("max relative difference is: ", max_rel_diff)

因为FP16浮点表示范围很小,所以很容易发生上溢,运算结果产生Inf/NaN。某些硬件会将Inf转成最大可表示数,一定程度缓解上溢,但仍然会带来计算误差。

BF16通过增加指数位宽,降低尾数位宽获得更大的浮点数可表示范围,同时也牺牲了可表示精度。它的取值范围为±9.2E38,能表示的绝对值最小的规格数是±2-126,约为1.18*10-38,相邻两个浮点数的相对差距在百分之一。举例来说,比1大的第一个浮点数为1+1/128,这意味着没法表示1.001。BF16的精度更差,但不易发生上溢下溢,在精度不敏感的运算场景比FP16更方便。现有的大型transformer模型实验证明,使用BF16进行混合精度训练不太影响模型的收敛。

分布式通信算子有累加操作时,也会有低精度数的误差累积问题,如下代码所示,主流AI处理器8卡的FP16 all_reduce均值误差大约在千分之一,最大误差约为2倍。而用BF16时,均值相对误差大约在1.5%,最大相对误差有9倍。因为通信算子的确定性计算问题,同样的输入运行两次都会有不同的结果。

torchrun --nnodes=1 --nproc-per-node=8 distributed.py --dtype fp16 --device npu
torchrun --nnodes=1 --nproc-per-node=8 distributed.py --dtype bf16 --device npu
import os
import torch
import torch.distributed as dist

import argparse
parser = argparse.ArgumentParser(
                    prog='fp precision test',
                    description='check the effect to final result under different floating point precision')
parser.add_argument("-t", "--dtype", default="fp16")
parser.add_argument("-d", "--device", default="cuda", help='accelerator selection')
args = parser.parse_args()

dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
# distributed environment
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
# accelerator selection
if args.device == 'cuda':
    device = torch.device(f"cuda:{local_rank}")
    dist.init_process_group(backend='nccl', rank=local_rank, world_size=world_size)
elif args.device == 'npu':
    import torch_npu
    device = torch.device(f"npu:{local_rank}")
    torch.npu.set_device(device)
    dist.init_process_group(backend='hccl', rank=local_rank, world_size=world_size)
if __name__ == "__main__":
    print("start all reduce!")
    hccl_input = torch.randn(10000)
    print(local_rank, hccl_input.shape)

    # collective communication 
    hccl_input16 = hccl_input.to(device=device,dtype=dtype)
    hccl_input32 = hccl_input.to(device=device)
    torch.distributed.all_reduce(hccl_input16)
    torch.distributed.all_reduce(hccl_input32)
    if local_rank == 0:
        rel_diff = abs((hccl_input16.float() - hccl_input32) / hccl_input32)
        diff_mean = torch.mean(rel_diff)
        diff_max = torch.max(rel_diff)
        print("diff_mean=", diff_mean)
        print("max diff=", diff_max)

总结而言,选择何种精度进行混合精度训练需依据具体任务需求权衡精度损失与训练速度提升,同时考虑分布式环境下的误差累积效应。