diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index c14e602de..6f2adaf03 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -151,6 +151,16 @@ class Booster: return self.plugin.no_sync(model) def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + """Load model from checkpoint. + + Args: + model (nn.Module): A model boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local path. + It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Defaults to True. + """ self.checkpoint_io.load_model(model, checkpoint, strict) def save_model(self, @@ -159,16 +169,58 @@ class Booster: prefix: str = None, shard: bool = False, size_per_shard: int = 1024): + """Save model to checkpoint. + + Args: + model (nn.Module): A model boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local path. + It is a file path if ``shard=False``. Otherwise, it is a directory path. + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. + shard (bool, optional): Whether to save checkpoint a sharded way. + If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. + """ self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard) def load_optimizer(self, optimizer: Optimizer, checkpoint: str): + """Load optimizer from checkpoint. + + Args: + optimizer (Optimizer): An optimizer boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local path. + It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path. + """ self.checkpoint_io.load_optimizer(optimizer, checkpoint) def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024): + """Save optimizer to checkpoint. + Warning: Saving sharded optimizer checkpoint is not supported yet. + + Args: + optimizer (Optimizer): An optimizer boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local path. + It is a file path if ``shard=False``. Otherwise, it is a directory path. + shard (bool, optional): Whether to save checkpoint a sharded way. + If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. + """ self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """Save lr scheduler to checkpoint. + + Args: + lr_scheduler (LRScheduler): A lr scheduler boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local file path. + """ self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint) def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """Load lr scheduler from checkpoint. + + Args: + lr_scheduler (LRScheduler): A lr scheduler boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local file path. + """ self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint) diff --git a/docs/sidebars.json b/docs/sidebars.json index ed0ba5278..94f79dcd3 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -29,6 +29,7 @@ "basics/launch_colossalai", "basics/booster_api", "basics/booster_plugins", + "basics/booster_checkpoint", "basics/define_your_config", "basics/initialize_features", "basics/engine_trainer", diff --git a/docs/source/en/basics/booster_checkpoint.md b/docs/source/en/basics/booster_checkpoint.md new file mode 100644 index 000000000..adc0af60b --- /dev/null +++ b/docs/source/en/basics/booster_checkpoint.md @@ -0,0 +1,48 @@ +# Booster Checkpoint + +Author: [Hongxin Liu](https://github.com/ver217) + +**Prerequisite:** +- [Booster API](./booster_api.md) + +## Introduction + +We've introduced the [Booster API](./booster_api.md) in the previous tutorial. In this tutorial, we will introduce how to save and load checkpoints using booster. + +## Model Checkpoint + +{{ autodoc:colossalai.booster.Booster.save_model }} + +Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to saved checkpoint. It can be a file, if `shard=False`. Otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way. This is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers). + +{{ autodoc:colossalai.booster.Booster.load_model }} + +Model must be boosted by `colossalai.booster.Booster` before loading. It will detect the checkpoint format automatically, and load in corresponding way. + +## Optimizer Checkpoint + +> ⚠ Saving optimizer checkpoint in a sharded way is not supported yet. + +{{ autodoc:colossalai.booster.Booster.save_optimizer }} + +Optimizer must be boosted by `colossalai.booster.Booster` before saving. + +{{ autodoc:colossalai.booster.Booster.load_optimizer }} + +Optimizer must be boosted by `colossalai.booster.Booster` before loading. + +## LR Scheduler Checkpoint + +{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }} + +LR scheduler must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the local path to checkpoint file. + +{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }} + +LR scheduler must be boosted by `colossalai.booster.Booster` before loading. `checkpoint` is the local path to checkpoint file. + +## Checkpoint design + +More details about checkpoint design can be found in our discussion [A Unified Checkpoint System Design](https://github.com/hpcaitech/ColossalAI/discussions/3339). + + diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index c15c30c84..0362f095a 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -43,12 +43,16 @@ We've tested compatibility on some famous models, following models may not be su Compatibility problems will be fixed in the future. +> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. This will be fixed in the future. + ### Gemini Plugin This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. It also does not support local gradient accumulation. More details can be found in [Gemini Doc](../features/zero_with_chunk.md). {{ autodoc:colossalai.booster.plugin.GeminiPlugin }} +> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. This will be fixed in the future. + ### Torch DDP Plugin More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). @@ -62,3 +66,5 @@ More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/genera More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html). {{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} + + diff --git a/docs/source/zh-Hans/basics/booster_checkpoint.md b/docs/source/zh-Hans/basics/booster_checkpoint.md new file mode 100644 index 000000000..d75f18c90 --- /dev/null +++ b/docs/source/zh-Hans/basics/booster_checkpoint.md @@ -0,0 +1,48 @@ +# Booster Checkpoint + +作者: [Hongxin Liu](https://github.com/ver217) + +**前置教程:** +- [Booster API](./booster_api.md) + +## 引言 + +我们在之前的教程中介绍了 [Booster API](./booster_api.md)。在本教程中,我们将介绍如何使用 booster 保存和加载 checkpoint。 + +## 模型 Checkpoint + +{{ autodoc:colossalai.booster.Booster.save_model }} + +模型在保存前必须被 `colossalai.booster.Booster` 加速。 `checkpoint` 是要保存的 checkpoint 的路径。 如果 `shard=False`,它就是文件。 否则, 它就是文件夹。如果 `shard=True`,checkpoint 将以分片方式保存。当 checkpoint 太大而无法保存在单个文件中时,这很有用。我们的分片 checkpoint 格式与 [huggingface/transformers](https://github.com/huggingface/transformers) 兼容。 + +{{ autodoc:colossalai.booster.Booster.load_model }} + +模型在加载前必须被 `colossalai.booster.Booster` 加速。它会自动检测 checkpoint 格式,并以相应的方式加载。 + +## 优化器 Checkpoint + +> ⚠ 尚不支持以分片方式保存优化器 Checkpoint。 + +{{ autodoc:colossalai.booster.Booster.save_optimizer }} + +优化器在保存前必须被 `colossalai.booster.Booster` 加速。 + +{{ autodoc:colossalai.booster.Booster.load_optimizer }} + +优化器在加载前必须被 `colossalai.booster.Booster` 加速。 + +## 学习率调度器 Checkpoint + +{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }} + +学习率调度器在保存前必须被 `colossalai.booster.Booster` 加速。 `checkpoint` 是 checkpoint 文件的本地路径. + +{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }} + +学习率调度器在加载前必须被 `colossalai.booster.Booster` 加速。 `checkpoint` 是 checkpoint 文件的本地路径. + +## Checkpoint 设计 + +有关 Checkpoint 设计的更多详细信息,请参见我们的讨论 [A Unified Checkpoint System Design](https://github.com/hpcaitech/ColossalAI/discussions/3339). + + diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index e0258eb37..b15ceb1e3 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -43,12 +43,16 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 兼容性问题将在未来修复。 +> ⚠ 该插件现在只能加载自己保存的且具有相同进程数的优化器 Checkpoint。这将在未来得到解决。 + ### Gemini 插件 这个插件实现了基于Chunk内存管理和异构内存管理的 Zero-3。它可以训练大型模型而不会损失太多速度。它也不支持局部梯度累积。更多详细信息,请参阅 [Gemini 文档](../features/zero_with_chunk.md). {{ autodoc:colossalai.booster.plugin.GeminiPlugin }} +> ⚠ 该插件现在只能加载自己保存的且具有相同进程数的优化器 Checkpoint。这将在未来得到解决。 + ### Torch DDP 插件 更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). @@ -62,3 +66,5 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/fsdp.html). {{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} + +