# Model Checkpoint

Author: Guangyang Lu

**Prerequisite:**
- [Launch Colossal-AI](./launch_colossalai.md)
- [Initialize Colossal-AI](./initialize_features.md)

**Example Code:**
- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint)

**This function is experimental.**

## Introduction

In this tutorial, you will learn how to save and load model checkpoints.

To leverage the parallel strategies in Colossal-AI, models and tensors have to be modified, so you cannot directly use `torch.save` or `torch.load` to save or load model checkpoints. Colossal-AI therefore provides its own API for this purpose.

Moreover, when loading a checkpoint, you are not required to use the same parallel strategy that was used when saving it.

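One way to picture why the loading strategy need not match the saving strategy: if a checkpoint stores each parameter in its full, unsharded form, the loader is free to split it again for any number of ranks. The sketch below illustrates that idea in plain Python, with lists standing in for tensors. The helpers `shard` and `gather` and the even 1D split are assumptions made for illustration only; they are not Colossal-AI's actual implementation.

```python
from typing import List


def shard(full: List[float], world_size: int) -> List[List[float]]:
    """Split a flat parameter evenly across `world_size` ranks (hypothetical 1D layout)."""
    if len(full) % world_size != 0:
        raise ValueError("parameter length must be divisible by world_size")
    chunk = len(full) // world_size
    return [full[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def gather(shards: List[List[float]]) -> List[float]:
    """Concatenate per-rank shards back into the full parameter."""
    return [x for part in shards for x in part]


# "Training" ran on 4 ranks, each holding a quarter of the weight.
weight = [float(i) for i in range(8)]
saved_shards = shard(weight, world_size=4)

# Saving stores the gathered, strategy-independent form.
checkpoint = {"weight": gather(saved_shards)}

# Loading re-splits the full parameter for a different layout (2 ranks).
loaded_shards = shard(checkpoint["weight"], world_size=2)
```

Because the stored value is independent of how it was split during training, the 4-rank run and the 2-rank run can share the same checkpoint.
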
## How to use

### Save

There are two ways to train a model in Colossal-AI: with the engine or with the trainer.
**Be aware that we only save the `state_dict`.** Therefore, when loading a checkpoint, you need to define the model first.

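Because only the `state_dict` is stored, a checkpoint holds parameter values but not the code that builds the model. The plain-Python sketch below mimics this with a toy class. `TinyModel` and its dictionary-based `state_dict` are hypothetical stand-ins for a `torch.nn.Module`, not Colossal-AI code.

```python
import pickle


class TinyModel:
    """Hypothetical stand-in for a model computing y = w * x + b."""

    def __init__(self) -> None:
        self.w = 0.0
        self.b = 0.0

    def state_dict(self) -> dict:
        # Only parameter values are exported, not the class or architecture.
        return {"w": self.w, "b": self.b}

    def load_state_dict(self, state: dict) -> None:
        self.w = state["w"]
        self.b = state["b"]

    def forward(self, x: float) -> float:
        return self.w * x + self.b


trained = TinyModel()
trained.w, trained.b = 2.0, 1.0

# The serialized checkpoint contains values only.
blob = pickle.dumps(trained.state_dict())

# To restore, the model must be constructed first and then filled in.
restored = TinyModel()
restored.load_state_dict(pickle.loads(blob))
```

Without the `TinyModel` definition at load time, the dictionary of values alone could not be turned back into a usable model.
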
#### Save when using engine

```python
import colossalai
from colossalai.utils import save_checkpoint

model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
for epoch in range(num_epochs):
    ...  # do some training
    # store the model's state_dict together with the current epoch
    save_checkpoint('xxx.pt', epoch, model)
```

#### Save when using trainer

```python
import colossalai
from colossalai.trainer import Trainer, hooks

model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...)
hook_list = [
    # the first argument (1) is the saving interval
    hooks.SaveCheckpointHook(1, 'xxx.pt', model),
    ...]

trainer.fit(...
            hooks=hook_list)
```

### Load

```python
from colossalai.utils import load_checkpoint

model = ...  # define the model first; the checkpoint only holds its state_dict
load_checkpoint('xxx.pt', model)
...  # train or test
```
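
Since `save_checkpoint` is given the epoch alongside the model, a natural use of a loaded checkpoint is to resume training where it stopped. The sketch below shows that control flow with plain-Python stand-ins. `toy_save`, `toy_load`, and the JSON file format are hypothetical and only illustrate the pattern; they are not the Colossal-AI API.

```python
import json
import os
import tempfile


def toy_save(path: str, epoch: int, state: dict) -> None:
    """Hypothetical helper: store the epoch together with the parameters."""
    with open(path, "w") as f:
        json.dump({"epoch": epoch, "model": state}, f)


def toy_load(path: str, state: dict) -> int:
    """Hypothetical helper: fill `state` in place and return the saved epoch."""
    with open(path) as f:
        ckpt = json.load(f)
    state.update(ckpt["model"])
    return ckpt["epoch"]


def train(state: dict, start_epoch: int, num_epochs: int, path: str) -> None:
    for epoch in range(start_epoch, num_epochs):
        state["steps"] += 1  # placeholder for one epoch of training
        toy_save(path, epoch, state)


path = os.path.join(tempfile.mkdtemp(), "ckpt.json")

# The first run stops after 3 of 5 epochs.
first = {"steps": 0}
train(first, start_epoch=0, num_epochs=3, path=path)

# The second run defines a fresh state, loads the checkpoint, and continues.
resumed = {"steps": 0}
last_epoch = toy_load(path, resumed)
train(resumed, start_epoch=last_epoch + 1, num_epochs=5, path=path)
```

The resumed run starts from the epoch after the one recorded in the checkpoint, so no epoch is repeated or skipped.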