ColossalAI/docs/source/zh-Hans/basics/model_checkpoint.md

1.8 KiB
Raw Blame History

模型Checkpoint

作者 : Guangyang Lu

⚠️ 此页面上的信息已经过时并将被废弃。请在Booster Checkpoint页面查阅更新。

预备知识:

示例代码:

函数是经验函数.

简介

本教程将介绍如何保存和加载模型Checkpoint。

为了充分利用Colossal-AI的强大并行策略我们需要修改模型和张量可以直接使用 torch.save 或者 torch.load 保存或加载模型Checkpoint。在Colossal-AI中我们提供了应用程序接口实现上述同样的效果。

但是,在加载时,你不需要使用与存储相同的保存策略。

使用方法

保存

有两种方法可以使用Colossal-AI训练模型即使用engine或使用trainer。 注意我们只保存 state_dict. 因此在加载Checkpoint时需要首先定义模型。

同 engine 保存

from colossalai.utils import save_checkpoint
model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
for epoch in range(num_epochs):
    ... # do some training
    save_checkpoint('xxx.pt', epoch, model)

用 trainer 保存

from colossalai.trainer import Trainer, hooks
model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...)
hook_list = [
            hooks.SaveCheckpointHook(1, 'xxx.pt', model)
            ...]

trainer.fit(...
            hook=hook_list)

加载

from colossalai.utils import load_checkpoint
model = ...
load_checkpoint('xxx.pt', model)
... # train or test