Model Checkpoint

Author: Guangyang Lu

Prerequisite:

Example Code:

This feature is experimental.

Introduction

In this tutorial, you will learn how to save and load model checkpoints.

Colossal-AI's parallel strategies modify models and tensors, so you cannot save or load model checkpoints directly with torch.save or torch.load. Colossal-AI therefore provides its own API for saving and loading checkpoints.

Moreover, you do not need to load a checkpoint with the same parallel strategy that was used to save it.
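To see why a checkpoint can be independent of the parallel strategy, it helps to think of each sharded parameter as being gathered into its full form before it is re-partitioned. The following is a conceptual sketch only: plain Python lists stand in for tensors, and the helpers `split_evenly`, `gather`, and `reshard` are hypothetical names invented for this illustration. It does not describe how Colossal-AI implements this internally.

```python
# Conceptual sketch only: lists stand in for tensors, and the helper names
# below are hypothetical, not Colossal-AI APIs.

def split_evenly(values, num_parts):
    """Partition a flat list into num_parts equal, contiguous shards."""
    if len(values) % num_parts != 0:
        raise ValueError("length must be divisible by num_parts")
    size = len(values) // num_parts
    return [values[i * size:(i + 1) * size] for i in range(num_parts)]


def gather(shards):
    """Concatenate shards back into the full, strategy-independent parameter."""
    return [v for shard in shards for v in shard]


def reshard(shards, new_num_parts):
    """Convert shards saved under one partitioning into another partitioning."""
    return split_evenly(gather(shards), new_num_parts)


full_weight = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
saved_shards = split_evenly(full_weight, 4)   # saved with 4-way partitioning
loaded_shards = reshard(saved_shards, 2)      # loaded with 2-way partitioning
```

Because the full parameter can be reconstructed from the shards, the number of partitions at load time does not have to match the number at save time.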

How to use

Save

There are two ways to train a model in Colossal-AI: with an engine or with a trainer. Note that only the model's state_dict is saved, so you must define the model before loading a checkpoint.
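The consequence of saving only the state_dict can be illustrated without any deep learning library. In the minimal sketch below, `TinyModel` is a hypothetical stand-in for a real model, and `pickle` stands in for the checkpoint format. Neither is part of Colossal-AI. The point is that the file holds parameter values only, so the object that receives them has to be built first.

```python
import io
import pickle


class TinyModel:
    """Hypothetical stand-in for a model; only its parameters go into state_dict."""

    def __init__(self, in_features, out_features):
        # The architecture (shapes) comes from the constructor, not the checkpoint.
        self.weight = [[0.0] * in_features for _ in range(out_features)]
        self.bias = [0.0] * out_features

    def state_dict(self):
        return {"weight": self.weight, "bias": self.bias}

    def load_state_dict(self, state):
        self.weight = state["weight"]
        self.bias = state["bias"]


# Saving: only the parameter dictionary is serialized, not the TinyModel object.
trained = TinyModel(2, 1)
trained.weight = [[0.5, -1.5]]
trained.bias = [0.25]
buffer = io.BytesIO()
pickle.dump(trained.state_dict(), buffer)

# Loading: construct the model first, then fill it with the saved values.
buffer.seek(0)
restored = TinyModel(2, 1)
restored.load_state_dict(pickle.load(buffer))
```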

Save when using engine

import colossalai
from colossalai.utils import save_checkpoint

model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
for epoch in range(num_epochs):
    ...  # do some training
    save_checkpoint('xxx.pt', epoch, model)  # same path each epoch, so the file is overwritten

Save when using trainer

import colossalai
from colossalai.trainer import Trainer, hooks

model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...)
hook_list = [
    hooks.SaveCheckpointHook(1, 'xxx.pt', model),
    ...
]

trainer.fit(..., hooks=hook_list)

Load

from colossalai.utils import load_checkpoint
model = ...
load_checkpoint('xxx.pt', model)
... # train or test
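A common reason to store the epoch number alongside the model state is to resume an interrupted run. The sketch below shows that pattern with the standard library only. The helpers `save_ckpt`, `load_ckpt`, and `train` are hypothetical and built on `pickle`; they are not the Colossal-AI functions. Check the API reference before assuming that `load_checkpoint` exposes the saved epoch in the same way.

```python
import os
import pickle
import tempfile


def save_ckpt(path, epoch, state):
    """Hypothetical helper: store the epoch number next to the model state."""
    with open(path, "wb") as f:
        pickle.dump({"epoch": epoch, "model": state}, f)


def load_ckpt(path):
    """Hypothetical helper: return (last finished epoch, model state)."""
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    return ckpt["epoch"], ckpt["model"]


def train(path, num_epochs, stop_after=None):
    """Run a toy training loop that resumes from `path` if it exists."""
    start_epoch, state = 0, {"steps": 0}
    if os.path.exists(path):
        last_epoch, state = load_ckpt(path)
        start_epoch = last_epoch + 1
    for epoch in range(start_epoch, num_epochs):
        state["steps"] += 1  # stands in for one epoch of training
        save_ckpt(path, epoch, state)
        if stop_after is not None and epoch == stop_after:
            break  # simulate an interruption
    return state


with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "toy.pt")
    train(ckpt_path, num_epochs=5, stop_after=1)  # interrupted after epochs 0 and 1
    final_state = train(ckpt_path, num_epochs=5)  # resumes at epoch 2
    last_epoch, _ = load_ckpt(ckpt_path)
```

The second call to `train` skips the epochs that already finished, so the toy model ends up with exactly five training steps in total.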