ColossalAI/docs/zero_zh.md

3.2 KiB
Raw Blame History

ZeRO优化器与offload

ZeRO优化器可以切分三种模型状态优化器状态、梯度、参数并将它们存储在不同的进程中以此来减少数据并行的存储冗余传统的数据并行需要将上述三种状态复制很多份保存在每一个进程中。与传统的做法相比ZeRO优化器可以极大地提高内存的存储效率并保持较好的通信效率。

  1. ZeRO Level 1: 优化器状态(如对于Adam优化器而言32比特的参数以及第一和第二动量的预测值被切分存储在不同的进程中这样每一个进程只需要更新它对应的那一部分参数。
  2. ZeRO Level 2: 用于更新模型参数的32比特的梯度在这一级被切分存储在不同的进程中这里梯度的切分与level 1中模型参数的切分是一一对应的每一个进程上的梯度恰好被用来更新该进程上的保存的模型参数。
  3. ZeRO Level 3: 16比特的模型参数在这一级被切分存储在不同的进程中ZeRO-3可以在前向传播和后向传播期间自动收集或切分这些参数。

使用ZeRO优化器

在Colossal-AI中启用ZeRO优化器只需要您在配置文件中进行配置即可下面是一些使用ZeRO-3的配置文件例子。

使用ZeRO优化器以及offload

这里我们使用Adam作为我们的初始优化器。

  1. 使用ZeRO来切分优化器状态level 1梯度level 2以及模型参数level 3
    optimizer = dict(
        type='Adam',
        lr=0.001,
        weight_decay=0
    )
    
    zero = dict(
        level=3,
        dynamic_loss_scale=True,
        clip_grad=1.0
    )
    
  2. 将优化器状态以及计算分配到CPU上
    zero = dict(
        offload_optimizer_config=dict(
            device='cpu',
            pin_memory=True,
            fast_init=True
        ),
        ...
    )
    
  3. 将模型参数分配到CPU上来节省显存
    zero = dict(
        offload_optimizer_config=dict(
            device='cpu',
            pin_memory=True,
            fast_init=True
        ),
        offload_param_config=dict(
            device='cpu',
            pin_memory=True,
            fast_init=OFFLOAD_PARAM_MAX_IN_CPU
        ),
        ...
    )
    
  4. 将参数分配到NVMe上来节省更多显存如果您的系统上安装了NVMe
    zero = dict(
        offload_optimizer_config=dict(
            device='nvme',
            pin_memory=True,
            fast_init=True,
            nvme_path='/nvme_data'
        ),
        offload_param_config=dict(
            device='nvme',
            pin_memory=True,
            max_in_cpu=OFFLOAD_PARAM_MAX_IN_CPU,
            nvme_path='/nvme_data'
        ),
        ...
    )
    

请注意使用ZeRO时fp16将会被自动激活。

使用ZeRO优化器进行训练

注意当使用ZeRO-3时如果您的模型过大以至于无法放入内存, 您应该使用colossalai.zero.zero3_model_context来构建您的模型:

from colossalai.zero import zero3_model_context

with zero3_model_context():
    model = Model()

如果您完成了上述配置,可以运行colossalai.initialize()来开始您的训练。