mirror of https://github.com/hpcaitech/ColossalAI
[booster] implemented the torch ddp + resnet example (#3232)
* [booster] implemented the torch ddp + resnet example
* polish code
pull/3239/head
parent 1a229045af
commit 73d3e4d309
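
In short, this commit adds the `Booster` front end, a `TorchDDPPlugin`, checkpoint IO, and a ResNet/CIFAR10 example. A minimal usage sketch follows, condensed from the example training script in this diff; the model, dataset, batch size and checkpoint path are illustrative, and the script is meant to be launched with `colossalai run --nproc_per_node <N>` as shown in the example README further down.

```python
# Minimal sketch of the Booster + TorchDDPPlugin workflow introduced by this commit.
# Condensed from the example training script in this diff; concrete values are illustrative.
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})

plugin = TorchDDPPlugin()
booster = Booster(mixed_precision='fp16', plugin=plugin)    # or Booster(plugin=plugin) for fp32

model = torchvision.models.resnet18(num_classes=10).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True,
                                             transform=transforms.ToTensor(), download=True)
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True)

# wrap model/optimizer/dataloader for DDP (and AMP, if enabled) in one call
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

for images, labels in train_dataloader:
    loss = criterion(model(images.cuda()), labels.cuda())
    optimizer.zero_grad()
    booster.backward(loss, optimizer)    # handles loss scaling when AMP is enabled
    optimizer.step()

# checkpointing goes through the booster as well
booster.save_model(model, './checkpoint/model.pth')
```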
@@ -1,4 +1,3 @@
from .accelerator import Accelerator
from .booster import Booster
from .environment_table import EnvironmentTable
from .plugin import Plugin

@@ -8,6 +8,8 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import GeneralCheckpointIO

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin

@@ -61,19 +63,21 @@ class Booster:
        self.plugin = plugin

        # set accelerator
        if self.plugin and self.plugin.control_device:
        if self.plugin and self.plugin.control_device():
            self.accelerator = None
            warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
        else:
            self.accelerator = Accelerator(device)

        # set precision
        if mixed_precision is None or (self.plugin and self.plugin.control_precision):
            self.mixed_precision = None
        if self.plugin and self.plugin.control_precision():
            warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
            self.mixed_precision = None
        elif mixed_precision is None:
            self.mixed_precision = None
        else:
            # validate and set precision
            if isinstance(MixedPrecision, str):
            if isinstance(mixed_precision, str):
                # the user will take the default arguments for amp training
                self.mixed_precision = mixed_precision_factory(mixed_precision)
            elif isinstance(mixed_precision, MixedPrecision):

@@ -84,6 +88,11 @@ class Booster:
                    f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
                )

        if self.plugin is not None and self.plugin.control_checkpoint_io():
            self.checkpoint_io = self.plugin.get_checkpoint_io()
        else:
            self.checkpoint_io = GeneralCheckpointIO()

    def boost(
        self,
        model: nn.Module,

@@ -109,12 +118,13 @@ class Booster:
            model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
                model, optimizer, criterion, dataloader, lr_scheduler)

        if self.plugin and not self.plugin.control_device:
        if self.plugin and not self.plugin.control_device():
            # transform model for accelerator
            model = self.accelerator.configure(model)

        if self.mixed_precision and self.plugin and not self.plugin.control_precision:
        if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
            # transform model for mixed precision
            # when mixed_precision is specified and the plugin is not given or does not control the precision
            model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)

        return model, optimizer, criterion, dataloader, lr_scheduler

@@ -140,18 +150,25 @@ class Booster:
        assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
        return self.plugin.no_sync(model)

    def save(self,
             obj: Union[nn.Module, Optimizer, LRScheduler],
             path_like: str,
             plan: str = 'torch',
             **kwargs) -> None:
        # TODO: implement this method
        pass
    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
        self.checkpoint_io.load_model(model, checkpoint, strict)

    def load(self,
             obj: Union[nn.Module, Optimizer, LRScheduler],
             path_like: str,
             plan: str = 'torch',
             **kwargs) -> None:
        # TODO: implement this method
        pass
    def save_model(self,
                   model: nn.Module,
                   checkpoint: str,
                   prefix: str = None,
                   shard: bool = False,
                   size_per_shard: int = 1024):
        self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        self.checkpoint_io.load_optimizer(optimizer, checkpoint)

    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)

@@ -1,18 +0,0 @@
from typing import List

__all__ = ['EnvironmentTable']


class EnvironmentTable:

    def __init__(self, intra_op_world_sizes: List[int]):
        # TODO: implement this method
        pass

    @property
    def is_master(self) -> bool:
        # TODO: implement this method
        pass

    # TODO: implement more utility methods as given in
    # https://github.com/hpcaitech/ColossalAI/issues/3051

@@ -1,3 +0,0 @@
from .optimizer import OptimizerWrapper

__all__ = ['OptimizerWrapper']

@@ -5,7 +5,8 @@ import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer

from ..interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .mixed_precision_base import MixedPrecision

__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']

@@ -45,7 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
        scaled_loss.backward(*args, **kwargs)

    def step(self, *args, **kwargs) -> Optional[float]:
        return self.scaler.step(self.optim, *args, **kwargs)
        out = self.scaler.step(self.optim, *args, **kwargs)
        self.scaler.update()
        return out

    def scale_loss(self, loss: Tensor) -> Tensor:
        return self.scaler.scale(loss)

@@ -67,7 +70,7 @@ class TorchAMPOptimizer(OptimizerWrapper):
        super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)


class TorchAMPModule(nn.Module):
class TorchAMPModule(ModelWrapper):
    """
    Module wrapper for mixed precision training in FP16 using PyTorch AMP.

@@ -76,8 +79,7 @@ class TorchAMPModule(nn.Module):
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        super().__init__(module)

    def forward(self, *args, **kwargs):
        with torch.cuda.amp.autocast():

@@ -4,7 +4,7 @@ from typing import Callable, Tuple
import torch.nn as nn
from torch.optim import Optimizer

from ..interface import OptimizerWrapper
from colossalai.interface import OptimizerWrapper


class MixedPrecision(ABC):

@@ -6,34 +6,30 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.booster.interface import OptimizerWrapper
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper

__all__ = ['Plugin']


class Plugin(ABC):

    @property
    @abstractmethod
    def supported_devices(self) -> List[str]:
        pass

    @property
    @abstractmethod
    def supported_precisions(self) -> List[str]:
        pass

    @property
    @abstractmethod
    def control_precision(self) -> bool:
        pass

    @property
    @abstractmethod
    def control_device(self) -> bool:
        pass

    @property
    @abstractmethod
    def support_no_sync(self) -> bool:
        pass

@@ -49,3 +45,17 @@ class Plugin(ABC):
    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
        # implement this method
        pass

    @abstractmethod
    def control_checkpoint_io(self) -> bool:
        """
        Whether the plugin controls the checkpoint io
        """
        pass

    @abstractmethod
    def get_checkpoint_io(self) -> CheckpointIO:
        """
        Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
        """
        pass

@@ -11,13 +11,61 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.booster.interface import OptimizerWrapper
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .plugin_base import Plugin

__all__ = ['TorchDDPPlugin']


class TorchDDPCheckpointIO(GeneralCheckpointIO):

    def __init__(self) -> None:
        super().__init__()
        self.coordinator = DistCoordinator()

    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
        """
        Load model from checkpoint with automatic unwrapping.
        """
        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
        return super().load_unsharded_model(model, checkpoint, strict=strict)

    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
        """
        Save model to checkpoint but only on master process.
        """
        # the model should be unwrapped in self.save_model via ModelWrapper.unwrap
        if self.coordinator.is_master():
            super().save_unsharded_model(model, checkpoint)

    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
        """
        Save optimizer to checkpoint but only on master process.
        """
        if self.coordinator.is_master():
            super().save_unsharded_optimizer(optimizer, checkpoint)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint but only on master process.
        """
        if self.coordinator.is_master():
            super().save_lr_scheduler(lr_scheduler, checkpoint)


class TorchDDPModel(ModelWrapper):

    def __init__(self, module: nn.Module, *args, **kwargs) -> None:
        super().__init__(module)
        self.module = DDP(module, *args, **kwargs)

    def unwrap(self):
        return self.module.module


class TorchDDPPlugin(Plugin):
    """
    Plugin for PyTorch DDP.

@@ -138,10 +186,19 @@ class TorchDDPPlugin(Plugin):
        # cast model to cuda
        model = model.cuda()

        # convert model to sync bn
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

        # wrap the model with PyTorch DDP
        model = DDP(model, **self.ddp_kwargs)
        model = TorchDDPModel(model, **self.ddp_kwargs)

        if not isinstance(optimizer, OptimizerWrapper):
            optimizer = OptimizerWrapper(optimizer)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        return TorchDDPCheckpointIO()

@@ -1,13 +1,15 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from typing import Any, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

from colossalai.interface import ModelWrapper

__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']


@@ -37,15 +39,15 @@ class CheckpointIO(ABC):
    >>>
    >>> # save optimizer to checkpoint
    >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')

    """

    # ======================================
    # Abstract methods for implementation
    # Public methods
    # ======================================

    @abstractmethod
    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
    def load_model(self,
                   model: Union[nn.Module, ModelWrapper],
                   checkpoint: str,
                   strict: bool = True) -> Union[nn.Module, ModelWrapper]:
        """
        Load model from checkpoint.

@@ -59,14 +61,26 @@ class CheckpointIO(ABC):
            strict (bool): whether to strictly enforce that the param names in
                the checkpoint match the keys returned by this module's state_dict().
        """
        pass
        ckpt_path = Path(checkpoint)
        is_sharded = self.is_sharded_checkpoint(ckpt_path)

        origin_model = model

        if isinstance(model, ModelWrapper):
            model = model.unwrap()

        if is_sharded:
            self.load_sharded_model(model, ckpt_path, strict)
        else:
            self.load_unsharded_model(model, ckpt_path, strict)

        return origin_model

    @abstractmethod
    def save_model(self,
                   model: nn.Module,
                   model: Union[nn.Module, ModelWrapper],
                   checkpoint: str,
                   prefix: str = None,
                   shard: bool = False,
                   prefix: str = None,
                   size_per_shard: int = 1024):
        """
        Save model to checkpoint.

@@ -83,17 +97,24 @@ class CheckpointIO(ABC):

        Args:
            model (nn.Module): model to be saved.
            checkpoint: checkpoint path. The checkpoint path can be:
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'model.pt'
                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
            shard: whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                that the checkpoint path is a directory path instead of a file path.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
            prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
        """
        pass

    @abstractmethod
        if isinstance(model, ModelWrapper):
            model = model.unwrap()

        if shard:
            self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
        else:
            self.save_unsharded_model(model, checkpoint)

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        """
        Load optimizer from checkpoint.

@@ -102,19 +123,139 @@ class CheckpointIO(ABC):
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. This value is made compatible with the model checkpoints in the
        """
        pass
        ckpt_path = Path(checkpoint)
        is_sharded = self.is_sharded_checkpoint(ckpt_path)

    @abstractmethod
    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
        if is_sharded:
            self.load_sharded_optimizer(optimizer, ckpt_path)
        else:
            self.load_unsharded_optimizer(optimizer, ckpt_path)

    def save_optimizer(self,
                       optimizer: Optimizer,
                       checkpoint: str,
                       shard: bool = False,
                       prefix: str = None,
                       size_per_shard: int = 1024):
        """
        Save optimizer to checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint: checkpoint path. The checkpoint path can be:
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'model.pt'
                2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
                3. a path to a folder containing a unique .index.json file for sharded checkpoint
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
            prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
        """
        if shard:
            self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
        else:
            self.save_unsharded_optimizer(optimizer, checkpoint)

    # ========================================================
    # Abstract methods for model loading/saving implementation
    # ========================================================
    @abstractmethod
    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
        """
        Load model from sharded checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            checkpoint (str): checkpoint path. It should be a path to the .index.json file or a path to a directory which contains a .index.json file.
        """
        pass

    @abstractmethod
    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
        """
        Load model from unsharded checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
            strict (bool): whether to strictly enforce that the param names in
                the checkpoint match the keys returned by this module's state_dict().
        """
        pass

    @abstractmethod
    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
        """
        Save model to sharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
            checkpoint (Path): checkpoint path. It should be a directory path.
            prefix (str): prefix for the model checkpoint.
            size_per_shard (int): size per shard in MB.
        """
        pass

    @abstractmethod
    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
        """
        Save model to unsharded checkpoint.

        Args:
            model (nn.Module): model to be saved.
            checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
        """
        pass

    # ========================================================
    # Abstract methods for optimizer loading/saving implementation
    # ========================================================

    @abstractmethod
    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
        """
        Load optimizer from sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. It should be a path to the .index.json file or a path to a directory which contains a .index.json file.
            prefix (str): prefix for the optimizer checkpoint.
            size_per_shard (int): size per shard in MB.
        """
        pass

    @abstractmethod
    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
        """
        Load optimizer from unsharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. It should be a single file path pointing to an optimizer state binary.
        """
        pass

    @abstractmethod
    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
        """
        Save optimizer to sharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (Path): checkpoint path. It should be a directory path.
            prefix (str): prefix for the optimizer checkpoint.
            size_per_shard (int): size per shard in MB.
        """
        pass

    @abstractmethod
    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
        """
        Save optimizer to unsharded checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str): checkpoint path. It should be a single file path pointing to an optimizer state binary.
        """
        pass

@@ -10,57 +10,36 @@ __all__ = ['GeneralCheckpointIO']

class GeneralCheckpointIO(CheckpointIO):

    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
        checkpoint = Path(checkpoint)
        is_sharded = self.is_sharded_checkpoint(checkpoint)
    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
        index_file_path = self.get_sharded_checkpoint_index_file(checkpoint)

        if not is_sharded:
            checkpoint = self.load_state_dict(checkpoint)
            model.load_state_dict(checkpoint, strict=strict)
        else:
            # find the index file
            checkpoint_path = Path(checkpoint)
            index_file_path = self.get_sharded_checkpoint_index_file(checkpoint_path)
            # iterate over the shard checkpoint files
            # and load each
            shard_files = self.get_checkpoint_shard_filenames(index_file_path)
            for shard_file in shard_files:
                shard_checkpoint = self.load_state_dict(shard_file)
                model.load_state_dict(shard_checkpoint, strict=strict)

        # iterate over the shard checkpoint files
        # and load each
        shard_files = self.get_checkpoint_shard_filenames(index_file_path)
        for shard_file in shard_files:
            shard_checkpoint = self.load_state_dict(shard_file)
            model.load_state_dict(shard_checkpoint, strict=strict)
    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
        checkpoint = self.load_state_dict(str(checkpoint))
        model.load_state_dict(checkpoint, strict=strict)

        return model
    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
        # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Sharded model checkpoint is not supported yet.")

    def save_model(self,
                   model: nn.Module,
                   checkpoint: str,
                   prefix: str = None,
                   shard: bool = False,
                   size_per_shard: int = 1024):
        checkpoint = Path(checkpoint)
        if shard:
            # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
            raise NotImplementedError("Not implemented yet")
        else:
            self.save_checkpoint(model.state_dict(), checkpoint)
    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
        self.save_checkpoint(model.state_dict(), checkpoint)

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        checkpoint = Path(checkpoint)
        is_sharded = self.is_sharded_checkpoint(checkpoint)
    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

        if not is_sharded:
            checkpoint = self.load_state_dict(checkpoint)
            optimizer.load_state_dict(checkpoint)
        else:
            # TODO(FrankLeeeee): implement checkpoint loading from sharded checkpoint
            # This is not an urgent feature, so we can leave it for later
            # let's implement this when we test large-scale models
            pass
        return optimizer
    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
        checkpoint = self.load_state_dict(checkpoint)
        optimizer.load_state_dict(checkpoint)

    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
        if shard:
            # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
            pass
        else:
            self.save_checkpoint(optimizer.state_dict(), checkpoint)
    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
        self.save_checkpoint(optimizer.state_dict(), checkpoint)

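For orientation, a hedged sketch of how the `CheckpointIO` interface above is meant to be called, using the concrete `GeneralCheckpointIO` from this diff; the model and the paths are illustrative placeholders.

```python
# Illustrative sketch only: shows the load/save entry points defined by CheckpointIO.
import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(8, 8)    # any nn.Module (a ModelWrapper would be unwrapped automatically)
ckpt_io = GeneralCheckpointIO()

# unsharded checkpoint: a single file path
ckpt_io.save_model(model, 'model.pt')
ckpt_io.load_model(model, 'model.pt', strict=True)

# sharded checkpoint: a directory path indexed by a `model.index.json` file
# (GeneralCheckpointIO still raises NotImplementedError for this branch in this commit)
ckpt_io.save_model(model, './checkpoints/', shard=True, size_per_shard=1024)
```
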
@@ -1,3 +1,4 @@
import functools
import os
from contextlib import contextmanager

@@ -141,12 +142,12 @@ class DistCoordinator(metaclass=SingletonMeta):
        should_block = rank != executor_rank

        if should_block:
            dist.barrier(group=process_group)
            self.block_all(process_group)

        yield

        if not should_block:
            dist.barrier(group=process_group)
            self.block_all(process_group)

    def destroy(self, process_group: ProcessGroup = None):
        """

@@ -156,3 +157,38 @@ class DistCoordinator(metaclass=SingletonMeta):
            process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
        """
        dist.destroy_process_group(process_group)

    def block_all(self, process_group: ProcessGroup = None):
        """
        Block all processes in the process group.

        Args:
            process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the default process group.
        """
        dist.barrier(group=process_group)

    def on_master_only(self, process_group: ProcessGroup = None):
        """
        A function wrapper that only executes the wrapped function on the master process (rank 0).

        Example:
            >>> from colossalai.cluster import DistCoordinator
            >>> dist_coordinator = DistCoordinator()
            >>>
            >>> @dist_coordinator.on_master_only()
            >>> def print_on_master(msg):
            >>>     print(msg)
        """
        is_master = self.is_master(process_group)

        # define an inner function
        def decorator(func):

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if is_master:
                    return func(*args, **kwargs)

            return wrapper

        return decorator

@@ -0,0 +1,4 @@
from .model import ModelWrapper
from .optimizer import OptimizerWrapper

__all__ = ['OptimizerWrapper', 'ModelWrapper']

@@ -0,0 +1,25 @@
import torch.nn as nn


class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self):
        """
        Unwrap the model to return the original model for checkpoint saving/loading.
        """
        if isinstance(self.module, ModelWrapper):
            return self.module.unwrap()
        return self.module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

@@ -0,0 +1,5 @@
# New API Features

**The New API is not officially released yet.**

This folder contains some demonstrations of the new API. The new API is still under intensive development and will be released soon.

@@ -0,0 +1,2 @@
#!/usr/bin/env bash
echo "The CI integration will be completed when the API is stable"

@@ -0,0 +1,4 @@
data
checkpoint
ckpt-fp16
ckpt-fp32

@@ -0,0 +1,44 @@
# Distributed Data Parallel

## 🚀 Quick Start

This example provides a training script and an evaluation script. The training script provides an example of training ResNet on the CIFAR10 dataset from scratch.

- Training Arguments
  - `-r`, `--resume`: resume from checkpoint file path (see the resume example under Train below)
  - `-c`, `--checkpoint`: the folder to save checkpoints
  - `-i`, `--interval`: epoch interval to save checkpoints
  - `-f`, `--fp16`: use fp16

- Eval Arguments
  - `-e`, `--epoch`: select the epoch to evaluate
  - `-c`, `--checkpoint`: the folder where checkpoints are found


### Train

```bash
# train with torch DDP with fp32
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32

# train with torch DDP with mixed precision training
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 --fp16
```
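
To resume an interrupted run, pass the epoch of a previously saved checkpoint to `-r` (a hypothetical example; it assumes a checkpoint was saved at that epoch via the `-i` interval):

```bash
# resume fp32 training from the checkpoint saved at epoch 40 (illustrative epoch number)
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32 -r 40
```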

### Eval

```bash
# evaluate fp32 training
python eval.py -c ./ckpt-fp32 -e 80

# evaluate fp16 mixed precision training
python eval.py -c ./ckpt-fp16 -e 80
```

The expected accuracy performance is as follows:

| Model     | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 |
| --------- | ------------------------ | --------------------- | --------------------- |
| ResNet-18 | 85.85%                   | 85.03%                | 85.12%                |

**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**

@@ -0,0 +1,48 @@
import argparse

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epoch', type=int, default=80, help="evaluate the checkpoint saved at this epoch")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
args = parser.parse_args()

# ==============================
# Prepare Test Dataset
# ==============================
# CIFAR-10 dataset
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor())

# Data loader
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

# ==============================
# Load Model
# ==============================
model = torchvision.models.resnet18(num_classes=10).cuda()
state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth')
model.load_state_dict(state_dict)

# ==============================
# Run Evaluation
# ==============================
model.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.cuda()
        labels = labels.cuda()
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

@@ -0,0 +1,128 @@
import argparse
from pathlib import Path

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import MultiStepLR

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.cluster import DistCoordinator

# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint")
parser.add_argument('-f', '--fp16', action='store_true', help="use fp16")
args = parser.parse_args()

# ==============================
# Prepare Checkpoint Directory
# ==============================
Path(args.checkpoint).mkdir(parents=True, exist_ok=True)

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 80
LEARNING_RATE = 1e-3
START_EPOCH = args.resume if args.resume >= 0 else 0

# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()

# update the learning rate with linear scaling
# old_gpu_num / old_lr = new_gpu_num / new_lr
LEARNING_RATE *= coordinator.world_size

# ==============================
# Prepare Booster
# ==============================
plugin = TorchDDPPlugin()
if args.fp16:
    booster = Booster(mixed_precision='fp16', plugin=plugin)
else:
    booster = Booster(plugin=plugin)

# ==============================
# Prepare Train Dataset
# ==============================
transform = transforms.Compose(
    [transforms.Pad(4),
     transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32),
     transforms.ToTensor()])

# CIFAR-10 dataset
with coordinator.priority_execution():
    train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True)

# ====================================
# Prepare model, optimizer, criterion
# ====================================
# resnet18
model = torchvision.models.resnet18(num_classes=10).cuda()

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# lr scheduler
lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3)

# prepare dataloader with torch ddp plugin
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True)

# ==============================
# Resume from checkpoint
# ==============================
if args.resume >= 0:
    booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth')
    booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth')
    booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth')

# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, optimizer, criterion,
                                                                            train_dataloader, lr_scheduler)

# ==============================
# Train model
# ==============================
total_step = len(train_dataloader)

for epoch in range(START_EPOCH, NUM_EPOCHS):
    for i, (images, labels) in enumerate(train_dataloader):
        images = images.cuda()
        labels = labels.cuda()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        booster.backward(loss, optimizer)
        optimizer.step()

        if (i + 1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, NUM_EPOCHS, i + 1, total_step,
                                                                    loss.item()))

    lr_scheduler.step()

    # save a checkpoint every args.interval epochs
    if (epoch + 1) % args.interval == 0:
        booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth')
        booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth')
        booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth')

@@ -8,8 +8,8 @@ from torch.optim import SGD

import colossalai
from colossalai.booster import Booster
from colossalai.booster.interface import OptimizerWrapper
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.kit.model_zoo import model_zoo

@@ -34,7 +34,7 @@ def check_torch_ddp_plugin():

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    assert isinstance(model, DDP)
    assert isinstance(model.module, DDP)
    assert isinstance(optimizer, OptimizerWrapper)

    output = model(**data)

@@ -42,8 +42,8 @@ def test_unsharded_checkpoint():
    new_optimizer = Adam(new_model.parameters(), lr=0.001)

    # load the model and optimizer
    new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
    new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
    ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
    ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)

    # do recursive check for the optimizer state dict
    # if the value is a dict, compare its values