[booster] implemented the torch ddp + resnet example (#3232)

* [booster] implemented the torch ddp + resnet example

* polish code
Frank Lee 2023-03-27 10:24:14 +08:00 committed by GitHub
parent 1a229045af
commit 73d3e4d309
22 changed files with 608 additions and 128 deletions


@ -1,4 +1,3 @@
from .accelerator import Accelerator
from .booster import Booster
from .environment_table import EnvironmentTable
from .plugin import Plugin


@ -8,6 +8,8 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import GeneralCheckpointIO
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
@ -61,19 +63,21 @@ class Booster:
self.plugin = plugin
# set accelerator
if self.plugin and self.plugin.control_device:
if self.plugin and self.plugin.control_device():
self.accelerator = None
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
else:
self.accelerator = Accelerator(device)
# set precision
if mixed_precision is None or (self.plugin and self.plugin.control_precision):
self.mixed_precision = None
if self.plugin and self.plugin.control_precision():
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
else:
# validate and set precision
if isinstance(MixedPrecision, str):
if isinstance(mixed_precision, str):
# the user will take the default arguments for amp training
self.mixed_precision = mixed_precision_factory(mixed_precision)
elif isinstance(mixed_precision, MixedPrecision):
@ -84,6 +88,11 @@ class Booster:
f'Expected the argument mixed_precision to be a string or an instance of MixedPrecision, but got {type(mixed_precision)}.'
)
if self.plugin is not None and self.plugin.control_checkpoint_io():
self.checkpoint_io = self.plugin.get_checkpoint_io()
else:
self.checkpoint_io = GeneralCheckpointIO()
def boost(
self,
model: nn.Module,
@ -109,12 +118,13 @@ class Booster:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
model, optimizer, criterion, dataloader, lr_scheduler)
if self.plugin and not self.plugin.control_device:
if self.plugin and not self.plugin.control_device():
# transform model for accelerator
model = self.accelerator.configure(model)
if self.mixed_precision and self.plugin and not self.plugin.control_precision:
if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
# transform model for mixed precision
# when mixed_precision is specified and the plugin is not given or does not control the precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
return model, optimizer, criterion, dataloader, lr_scheduler
@ -140,18 +150,25 @@ class Booster:
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)
def save(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
path_like: str,
plan: str = 'torch',
**kwargs) -> None:
# TODO: implement this method
pass
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
self.checkpoint_io.load_model(model, checkpoint, strict)
def load(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
path_like: str,
plan: str = 'torch',
**kwargs) -> None:
# TODO: implement this method
pass
def save_model(self,
model: nn.Module,
checkpoint: str,
prefix: str = None,
shard: bool = False,
size_per_shard: int = 1024):
self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
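With the placeholder `save`/`load` pair replaced by per-object methods, checkpointing is now driven entirely through `self.checkpoint_io`. A minimal single-process sketch of the new surface (no plugin, so `GeneralCheckpointIO` is used; the model, learning rate, and file names are illustrative, not part of this PR):

```python
import torch
import torch.nn as nn

from colossalai.booster import Booster

# single-process sketch: no plugin, so the Booster falls back to GeneralCheckpointIO
booster = Booster(mixed_precision='fp16')

model = nn.Linear(32, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# the old save/load stubs are gone; checkpointing is now per-object
booster.save_model(model, 'model.pth')
booster.save_optimizer(optimizer, 'optimizer.pth')
booster.load_model(model, 'model.pth')
booster.load_optimizer(optimizer, 'optimizer.pth')
```

When a plugin reports `control_checkpoint_io()`, the same calls are routed to the plugin's own `CheckpointIO` object instead.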


@ -1,18 +0,0 @@
from typing import List
__all__ = ['EnvironmentTable']
class EnvironmentTable:
def __init__(self, intra_op_world_sizes: List[int]):
# TODO: implement this method
pass
@property
def is_master(self) -> bool:
# TODO: implement this method
pass
# TODO: implement more utility methods as given in
# https://github.com/hpcaitech/ColossalAI/issues/3051


@ -1,3 +0,0 @@
from .optimizer import OptimizerWrapper
__all__ = ['OptimizerWrapper']


@ -5,7 +5,8 @@ import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from ..interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .mixed_precision_base import MixedPrecision
__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
@ -45,7 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
scaled_loss.backward(*args, **kwargs)
def step(self, *args, **kwargs) -> Optional[float]:
return self.scaler.step(self.optim, *args, **kwargs)
out = self.scaler.step(self.optim, *args, **kwargs)
self.scaler.update()
return out
def scale_loss(self, loss: Tensor) -> Tensor:
return self.scaler.scale(loss)
@ -67,7 +70,7 @@ class TorchAMPOptimizer(OptimizerWrapper):
super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
class TorchAMPModule(nn.Module):
class TorchAMPModule(ModelWrapper):
"""
Module wrapper for mixed precision training in FP16 using PyTorch AMP.
@ -76,8 +79,7 @@ class TorchAMPModule(nn.Module):
"""
def __init__(self, module: nn.Module):
super().__init__()
self.module = module
super().__init__(module)
def forward(self, *args, **kwargs):
with torch.cuda.amp.autocast():

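The change to `step` above also calls `self.scaler.update()` after stepping, which matches the canonical PyTorch AMP recipe: `GradScaler` needs `update()` once per iteration so the loss scale can adapt after the inf/NaN check. For reference, the plain PyTorch pattern the wrapper now mirrors (model and data are illustrative):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

data = torch.randn(4, 8).cuda()
target = torch.randn(4, 1).cuda()

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = F.mse_loss(model(data), target)
scaler.scale(loss).backward()   # backward on the scaled loss
scaler.step(optimizer)          # unscales grads, skips the step on inf/NaN
scaler.update()                 # adjust the loss scale for the next iteration
```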

@ -4,7 +4,7 @@ from typing import Callable, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from ..interface import OptimizerWrapper
from colossalai.interface import OptimizerWrapper
class MixedPrecision(ABC):


@ -6,34 +6,30 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.booster.interface import OptimizerWrapper
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper
__all__ = ['Plugin']
class Plugin(ABC):
@property
@abstractmethod
def supported_devices(self) -> List[str]:
pass
@property
@abstractmethod
def supported_precisions(self) -> List[str]:
pass
@property
@abstractmethod
def control_precision(self) -> bool:
pass
@property
@abstractmethod
def control_device(self) -> bool:
pass
@property
@abstractmethod
def support_no_sync(self) -> bool:
pass
@ -49,3 +45,17 @@ class Plugin(ABC):
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
# implement this method
pass
@abstractmethod
def control_checkpoint_io(self) -> bool:
"""
Whether the plugin controls the checkpoint io
"""
pass
@abstractmethod
def get_checkpoint_io(self) -> CheckpointIO:
"""
Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
"""
pass
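Concretely, a plugin that wants the Booster to use its own checkpointing must now answer `control_checkpoint_io()` and hand back a `CheckpointIO` from `get_checkpoint_io()`. A rough sketch of such a subclass, assuming the module layout in this diff (`MinimalPlugin` is hypothetical, and the remaining abstract members are elided):

```python
from colossalai.booster.plugin.plugin_base import Plugin   # module path as laid out in this PR
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO


class MinimalPlugin(Plugin):
    """Hypothetical plugin that keeps Booster defaults for device and precision
    but supplies its own checkpoint I/O object."""

    def control_device(self) -> bool:
        return False          # let the Booster's Accelerator place the model

    def control_precision(self) -> bool:
        return False          # let the Booster's MixedPrecision handle casting

    def control_checkpoint_io(self) -> bool:
        return True           # the Booster will call get_checkpoint_io()

    def get_checkpoint_io(self) -> CheckpointIO:
        return GeneralCheckpointIO()

    # remaining abstract members (supported_devices, supported_precisions,
    # support_no_sync, configure, ...) are omitted from this sketch
```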


@ -11,13 +11,61 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.booster.interface import OptimizerWrapper
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .plugin_base import Plugin
__all__ = ['TorchDDPPlugin']
class TorchDDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.save_model via ModelWrapper.unwrap
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
Save optimizer to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
class TorchDDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = DDP(module, *args, **kwargs)
def unwrap(self):
return self.module.module
class TorchDDPPlugin(Plugin):
"""
Plugin for PyTorch DDP.
@ -138,10 +186,19 @@ class TorchDDPPlugin(Plugin):
# cast model to cuda
model = model.cuda()
# convert model to sync bn
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
# wrap the model with PyTorch DDP
model = DDP(model, **self.ddp_kwargs)
model = TorchDDPModel(model, **self.ddp_kwargs)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
return True
def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()
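Because `control_checkpoint_io()` now returns True, a model boosted with `TorchDDPPlugin` is saved through `TorchDDPCheckpointIO`: the state dict is taken from the unwrapped module and written only on the master rank. A sketch assuming a distributed launch (the model, loss, and file name are illustrative):

```python
import torch.nn as nn
from torch.optim import SGD

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# run under a distributed launcher, e.g. `colossalai run --nproc_per_node 2 this_script.py`
colossalai.launch_from_torch(config={})

booster = Booster(plugin=TorchDDPPlugin())
model = nn.Linear(16, 4)                 # the plugin moves the model to CUDA itself
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# TorchDDPCheckpointIO unwraps the DDP module and writes only on the master rank,
# so 'ddp_model.pth' holds a plain state dict that a script like eval.py below
# can load with an ordinary torch.load()
booster.save_model(model, 'ddp_model.pth')
```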


@ -1,13 +1,15 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from typing import Any, Union
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper
__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
@ -37,15 +39,15 @@ class CheckpointIO(ABC):
>>>
>>> # save optimizer to checkpoint
>>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
"""
# ======================================
# Abstract methods for implementation
# Public methods
# ======================================
@abstractmethod
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
def load_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
@ -59,14 +61,26 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
pass
ckpt_path = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(ckpt_path)
origin_model = model
if isinstance(model, ModelWrapper):
model = model.unwrap()
if is_sharded:
self.load_sharded_model(model, ckpt_path, strict)
else:
self.load_unsharded_model(model, ckpt_path, strict)
return origin_model
@abstractmethod
def save_model(self,
model: nn.Module,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
prefix: str = None,
shard: bool = False,
prefix: str = None,
size_per_shard: int = 1024):
"""
Save model to checkpoint.
@ -83,17 +97,24 @@ class CheckpointIO(ABC):
Args:
model (nn.Module): model to be saved.
checkpoint: checkpoint path. The checkpoint path can be :
checkpoint (str): checkpoint path. The checkpoint path can be :
1. a file path, e.g. 'model.pt'
2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
shard: whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
"""
pass
@abstractmethod
if isinstance(model, ModelWrapper):
model = model.unwrap()
if shard:
self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
else:
self.save_unsharded_model(model, checkpoint)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
Load optimizer from checkpoint.
@ -102,19 +123,139 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. The same path formats as for model checkpoints are accepted.
"""
pass
ckpt_path = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(ckpt_path)
@abstractmethod
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
if is_sharded:
self.load_sharded_optimizer(optimizer, ckpt_path)
else:
self.load_unsharded_optimizer(optimizer, ckpt_path)
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
prefix: str = None,
size_per_shard: int = 1024):
"""
Save optimizer to checkpoint.
Args:
optimizer (Optimizer): optimizer to be saved.
checkpoint: checkpoint path. The checkpoint path can be :
checkpoint (str): checkpoint path. The checkpoint path can be :
1. a file path, e.g. 'model.pt'
2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
3. a path to a folder containing a unique .index.json file for sharded checkpoint
shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
else:
self.save_unsharded_optimizer(optimizer, checkpoint)
# ========================================================
# Abstract methods for model loading/saving implementation
# ========================================================
@abstractmethod
def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
"""
Load model from sharded checkpoint.
Args:
model (nn.Module): model to be loaded.
checkpoint (str): checkpoint path. It should be a path to the .index.json file or a path to a directory which contains a .index.json file.
"""
pass
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
"""
Load model from unsharded checkpoint.
Args:
model (nn.Module): model to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
pass
@abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
"""
Save model to sharded checkpoint.
Args:
model (nn.Module): model to be saved.
checkpoint (Path): checkpoint path. It should be a directory path.
prefix (str): prefix for the model checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
"""
Save model to unsharded checkpoint.
Args:
model (nn.Module): model to be saved.
checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
"""
pass
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
"""
Load optimizer from sharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
"""
Load optimizer from unsharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to the saved optimizer states.
"""
pass
@abstractmethod
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
"""
Save optimizer to sharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be saved.
checkpoint (Path): checkpoint path. It should be a directory path.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
"""
Save optimizer to unsharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be saved.
checkpoint (str): checkpoint path. It should be a single file path pointing to the saved optimizer states.
"""
pass


@ -10,57 +10,36 @@ __all__ = ['GeneralCheckpointIO']
class GeneralCheckpointIO(CheckpointIO):
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
checkpoint = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(checkpoint)
def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
index_file_path = self.get_sharded_checkpoint_index_file(checkpoint)
if not is_sharded:
checkpoint = self.load_state_dict(checkpoint)
model.load_state_dict(checkpoint, strict=strict)
else:
# find the index file
checkpoint_path = Path(checkpoint)
index_file_path = self.get_sharded_checkpoint_index_file(checkpoint_path)
# iterate over the shard checkpoint files
# and load each
shard_files = self.get_checkpoint_shard_filenames(index_file_path)
for shard_file in shard_files:
shard_checkpoint = self.load_state_dict(shard_file)
model.load_state_dict(shard_checkpoint, strict=strict)
# iterate over the shard checkpoint files
# and load each
shard_files = self.get_checkpoint_shard_filenames(index_file_path)
for shard_file in shard_files:
shard_checkpoint = self.load_state_dict(shard_file)
model.load_state_dict(shard_checkpoint, strict=strict)
def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
checkpoint = self.load_state_dict(str(checkpoint))
model.load_state_dict(checkpoint, strict=strict)
return model
def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
# TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def save_model(self,
model: nn.Module,
checkpoint: str,
prefix: str = None,
shard: bool = False,
size_per_shard: int = 1024):
checkpoint = Path(checkpoint)
if shard:
# TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
raise NotImplementedError("Not implemented yet")
else:
self.save_checkpoint(model.state_dict(), checkpoint)
def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
self.save_checkpoint(model.state_dict(), checkpoint)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
checkpoint = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(checkpoint)
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
if not is_sharded:
checkpoint = self.load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
else:
# TODO(FrankLeeeee): implement checkpoint loading from sharded checkpoint
# This is not an urgent feature, so we can leave it for later
# let's implement this when we test large-scale models
pass
return optimizer
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = self.load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
if shard:
# TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
pass
else:
self.save_checkpoint(optimizer.state_dict(), checkpoint)
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
self.save_checkpoint(optimizer.state_dict(), checkpoint)
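For unsharded checkpoints the public behaviour is unchanged: `save_model`/`save_optimizer` end up in the `*_unsharded_*` hooks above, while the sharded paths now raise `NotImplementedError` instead of silently passing. A small round-trip sketch (model size and file names are illustrative):

```python
import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO

ckpt_io = GeneralCheckpointIO()

model = nn.Linear(8, 2)
optimizer = Adam(model.parameters(), lr=1e-3)

# unsharded save: one file per object
ckpt_io.save_model(model, 'model.pt')
ckpt_io.save_optimizer(optimizer, 'optimizer.pt')

# load back into fresh objects
new_model = nn.Linear(8, 2)
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
ckpt_io.load_model(new_model, 'model.pt')
ckpt_io.load_optimizer(new_optimizer, 'optimizer.pt')

# sharded checkpoints are not supported yet in GeneralCheckpointIO:
# ckpt_io.save_model(model, './sharded/', shard=True)  # would raise NotImplementedError
```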


@ -1,3 +1,4 @@
import functools
import os
from contextlib import contextmanager
@ -141,12 +142,12 @@ class DistCoordinator(metaclass=SingletonMeta):
should_block = rank != executor_rank
if should_block:
dist.barrier(group=process_group)
self.block_all(process_group)
yield
if not should_block:
dist.barrier(group=process_group)
self.block_all(process_group)
def destroy(self, process_group: ProcessGroup = None):
"""
@ -156,3 +157,38 @@ class DistCoordinator(metaclass=SingletonMeta):
process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
"""
dist.destroy_process_group(process_group)
def block_all(self, process_group: ProcessGroup = None):
"""
Block all processes in the process group.
Args:
process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the default process group.
"""
dist.barrier(group=process_group)
def on_master_only(self, process_group: ProcessGroup = None):
"""
A function wrapper that only executes the wrapped function on the master process (rank 0).
Example:
>>> from colossalai.cluster import DistCoordinator
>>> dist_coordinator = DistCoordinator()
>>>
>>> @dist_coordinator.on_master_only()
>>> def print_on_master(msg):
>>> print(msg)
"""
is_master = self.is_master(process_group)
# define an inner function
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_master:
return func(*args, **kwargs)
return wrapper
return decorator
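The two new helpers are thin conveniences: `block_all` wraps `dist.barrier` on the given group, and `on_master_only` builds a decorator bound to the coordinator's rank. A usage sketch, assuming the script runs under a distributed launcher (the logging function is illustrative):

```python
import colossalai
from colossalai.cluster import DistCoordinator

# run under a distributed launcher, e.g. `colossalai run --nproc_per_node 2 this_script.py`
colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()


@coordinator.on_master_only()
def log_on_master(msg: str):
    # executes only on rank 0; other ranks get a no-op and receive None
    print(msg)


log_on_master(f'world size = {coordinator.world_size}')

# make every rank wait here, e.g. after rank 0 has finished downloading a dataset
coordinator.block_all()
```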


@ -0,0 +1,4 @@
from .model import ModelWrapper
from .optimizer import OptimizerWrapper
__all__ = ['OptimizerWrapper', 'ModelWrapper']


@ -0,0 +1,25 @@
import torch.nn as nn
class ModelWrapper(nn.Module):
"""
A wrapper class to define the common interface used by booster.
Args:
module (nn.Module): The model to be wrapped.
"""
def __init__(self, module: nn.Module) -> None:
super().__init__()
self.module = module
def unwrap(self):
"""
Unwrap the model to return the original model for checkpoint saving/loading.
"""
if isinstance(self.module, ModelWrapper):
return self.module.unwrap()
return self.module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
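Recursive `unwrap` is what lets checkpoint I/O reach the raw `nn.Module` even when several wrappers are stacked (for example an AMP wrapper around a DDP wrapper). A minimal illustration with plain `ModelWrapper`s (the double wrapping is only for demonstration):

```python
import torch
import torch.nn as nn

from colossalai.interface import ModelWrapper

inner = nn.Linear(4, 4)
wrapped = ModelWrapper(ModelWrapper(inner))    # nested wrapping only for demonstration

assert wrapped.unwrap() is inner               # unwrap() recurses through nested wrappers
out = wrapped(torch.randn(2, 4))               # forward() still delegates to the inner module
```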


@ -0,0 +1,5 @@
# New API Features
**The New API is not officially released yet.**
This folder contains demonstrations of the new API, which is still under intensive development and will be released soon.


@ -0,0 +1,2 @@
#!/usr/bin/env bash
echo "The CI integration will be completed when the API is stable"


@ -0,0 +1,4 @@
data
checkpoint
ckpt-fp16
ckpt-fp32


@ -0,0 +1,44 @@
# Distributed Data Parallel
## 🚀 Quick Start
This example provides a training script and an evaluation script. The training script shows how to train ResNet on the CIFAR-10 dataset from scratch.
- Training Arguments
- `-r`, `--resume`: resume training from the checkpoint saved at the given epoch
- `-c`, `--checkpoint`: the folder to save checkpoints
- `-i`, `--interval`: epoch interval to save checkpoints
- `-f`, `--fp16`: use fp16
- Eval Arguments
- `-e`, `--epoch`: select the epoch to evaluate
- `-c`, `--checkpoint`: the folder where checkpoints are found
### Train
```bash
# train with torch DDP with fp32
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32
# train with torch DDP with mixed precision training
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 --fp16
```
### Eval
```bash
# evaluate fp32 training
python eval.py -c ./ckpt-fp32 -e 80
# evaluate fp16 mixed precision training
python eval.py -c ./ckpt-fp16 -e 80
```
The expected accuracy is:
| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 |
| --------- | ------------------------ | --------------------- | --------------------- |
| ResNet-18 | 85.85% | 85.03% | 85.12% |
**Note: the baseline is adapted from this [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`.**


@ -0,0 +1,48 @@
import argparse
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
args = parser.parse_args()
# ==============================
# Prepare Test Dataset
# ==============================
# CIFAR-10 dataset
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor())
# Data loader
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)
# ==============================
# Load Model
# ==============================
model = torchvision.models.resnet18(num_classes=10).cuda()
state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth')
model.load_state_dict(state_dict)
# ==============================
# Run Evaluation
# ==============================
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.cuda()
labels = labels.cuda()
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))


@ -0,0 +1,128 @@
import argparse
from pathlib import Path
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import MultiStepLR
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.cluster import DistCoordinator
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint")
parser.add_argument('-f', '--fp16', action='store_true', help="use fp16")
args = parser.parse_args()
# ==============================
# Prepare Checkpoint Directory
# ==============================
Path(args.checkpoint).mkdir(parents=True, exist_ok=True)
# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 80
LEARNING_RATE = 1e-3
START_EPOCH = args.resume if args.resume >= 0 else 0
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()
# update the learning rate with linear scaling
# old_gpu_num / old_lr = new_gpu_num / new_lr
LEARNING_RATE *= coordinator.world_size
# ==============================
# Prepare Booster
# ==============================
plugin = TorchDDPPlugin()
if args.fp16:
booster = Booster(mixed_precision='fp16', plugin=plugin)
else:
booster = Booster(plugin=plugin)
# ==============================
# Prepare Train Dataset
# ==============================
transform = transforms.Compose(
[transforms.Pad(4),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32),
transforms.ToTensor()])
# CIFAR-10 dataset
with coordinator.priority_execution():
train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True)
# ====================================
# Prepare model, optimizer, criterion
# ====================================
# resnet18
model = torchvision.models.resnet18(num_classes=10).cuda()
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# lr scheduler
lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3)
# prepare dataloader with torch ddp plugin
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True)
# ==============================
# Resume from checkpoint
# ==============================
if args.resume >= 0:
booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth')
booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth')
booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth')
# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, optimizer, criterion,
train_dataloader, lr_scheduler)
# ==============================
# Train model
# ==============================
total_step = len(train_dataloader)
for epoch in range(START_EPOCH, NUM_EPOCHS):
for i, (images, labels) in enumerate(train_dataloader):
images = images.cuda()
labels = labels.cuda()
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
booster.backward(loss, optimizer)
optimizer.step()
if (i + 1) % 100 == 0:
print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, NUM_EPOCHS, i + 1, total_step,
loss.item()))
lr_scheduler.step()
# save a checkpoint every `--interval` epochs (default: 5)
if (epoch + 1) % args.interval == 0:
booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth')
booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth')
booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth')


@ -8,8 +8,8 @@ from torch.optim import SGD
import colossalai
from colossalai.booster import Booster
from colossalai.booster.interface import OptimizerWrapper
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.kit.model_zoo import model_zoo
@ -34,7 +34,7 @@ def check_torch_ddp_plugin():
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
assert isinstance(model, DDP)
assert isinstance(model.module, DDP)
assert isinstance(optimizer, OptimizerWrapper)
output = model(**data)


@ -42,8 +42,8 @@ def test_unsharded_checkpoint():
new_optimizer = Adam(new_model.parameters(), lr=0.001)
# load the model and optimizer
new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
# do recursive check for the optimizer state dict
# if the value is a dict, compare its values