diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 3c6e539ba..6693b1f44 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -13,6 +13,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io.utils import save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.colo_parameter import ColoParameter
@@ -83,7 +84,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str):
+    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
@@ -91,14 +92,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # as there is communication when get state dict, this must be called on all processes
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
-            self.save_checkpoint(state_dict, checkpoint)
+            save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
         # TODO(ver217): optimizer state dict is sharded
-        super().save_unsharded_optimizer(optimizer, checkpoint)
+        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index e2abe11ba..c5e310c7e 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -33,20 +33,20 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         if self.coordinator.is_master():
-            super().save_unsharded_model(model, checkpoint)
+            super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
-            super().save_unsharded_optimizer(optimizer, checkpoint)
+            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
diff --git a/examples/tutorial/new_api/torch_ddp/README.md b/examples/tutorial/new_api/torch_ddp/README.md
index 62d5a083d..e120bacb0 100644
--- a/examples/tutorial/new_api/torch_ddp/README.md
+++ b/examples/tutorial/new_api/torch_ddp/README.md
@@ -2,10 +2,10 @@
 
 ## 🚀 Quick Start
 
-This example provides a training script and and evaluation script. The training script provides a an example of training ResNet on CIFAR10 dataset from scratch.
+This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch.
 
 - Training Arguments
-  - `-r, `--resume`: resume from checkpoint file path
+  - `-r`, `--resume`: resume from checkpoint file path
   - `-c`, `--checkpoint`: the folder to save checkpoints
   - `-i`, `--interval`: epoch interval to save checkpoints
   - `-f`, `--fp16`: use fp16
@@ -41,4 +41,4 @@ Expected accuracy performance will be:
 | --------- | ------------------------ | --------------------- | --------------------- |
 | ResNet-18 | 85.85%                   | 85.03%                | 85.12%                |
 
-**Note: the baseline is a adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
+**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**