[booster] fixed the torch ddp plugin with the new checkpoint api (#3442)

Frank Lee · 2023-04-06 09:43:51 +08:00 · committed by GitHub
parent 8f740deb53 · commit 7d8d825681
3 changed files with 11 additions and 10 deletions


@@ -13,6 +13,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io.utils import save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.colo_parameter import ColoParameter
@@ -83,7 +84,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str):
+    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
@@ -91,14 +92,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # as there is communication when get state dict, this must be called on all processes
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
-            self.save_checkpoint(state_dict, checkpoint)
+            save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
         # TODO(ver217): optimizer state dict is sharded
-        super().save_unsharded_optimizer(optimizer, checkpoint)
+        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """


@@ -33,20 +33,20 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         if self.coordinator.is_master():
-            super().save_unsharded_model(model, checkpoint)
+            super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
-            super().save_unsharded_optimizer(optimizer, checkpoint)
+            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """


@@ -2,10 +2,10 @@
 
 ## 🚀 Quick Start
 
-This example provides a training script and and evaluation script. The training script provides a an example of training ResNet on CIFAR10 dataset from scratch.
+This example provides a training script and an evaluation script. The training script provides an example of training ResNet on the CIFAR10 dataset from scratch.
 
 - Training Arguments
-  - `-r, `--resume`: resume from checkpoint file path
+  - `-r`, `--resume`: resume from checkpoint file path
   - `-c`, `--checkpoint`: the folder to save checkpoints
   - `-i`, `--interval`: epoch interval to save checkpoints
   - `-f`, `--fp16`: use fp16
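
For orientation only, the flags above suggest an `argparse` layout roughly like the sketch below; types and defaults are guesses, and the example's actual `train.py` may declare them differently:

```python
import argparse

# Hedged sketch of the training arguments listed above, not the example's
# real script.
parser = argparse.ArgumentParser(description="Train ResNet-18 on CIFAR10 from scratch")
parser.add_argument("-r", "--resume", type=str, default=None,
                    help="resume from checkpoint file path")
parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoints",
                    help="the folder to save checkpoints")
parser.add_argument("-i", "--interval", type=int, default=1,
                    help="epoch interval to save checkpoints")
parser.add_argument("-f", "--fp16", action="store_true",
                    help="use fp16 mixed precision")
args = parser.parse_args()
```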
@@ -41,4 +41,4 @@ Expected accuracy performance will be:
 | --------- | ------------------------ | --------------------- | --------------------- |
 | ResNet-18 | 85.85%                   | 85.03%                | 85.12%                |
 
-**Note: the baseline is a adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
+**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**