[plugin] add 3d parallel plugin (#4295)

* [amp] add mixed precision optimizer

* [plugin] add 3d parallel plugin

* [booster] support pipeline

* [plugin] 3d parallel plugin support clip grad norm

* [shardformer] fix sharder and add plugin test

* [plugin] rename 3d parallel plugin

* [ci] support testmon core pkg change detection (#4305)

* [hotfix] debug testmon

* [hotfix] fix llama

* [hotfix] fix p2p bugs

* [hotfix] fix requirements
pull/4445/head
Hongxin Liu 2023-07-26 00:53:57 +08:00
parent b3f5d7a3ba
commit 261eab02fb
9 changed files with 621 additions and 24 deletions
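Taken together, the new plugin plugs into the existing Booster API: boost the model and optimizer with a HybridParallelPlugin, then drive training through booster.execute_pipeline. A minimal sketch, closely following the test added in this commit (model_fn, data_iter and the criterion are placeholders, and the distributed environment is assumed to be initialized already, e.g. via colossalai.launch on a world size divisible by tp_size * pp_size):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam

def train_one_batch(model_fn, data_iter):
    # 2-way tensor parallel x 2-way pipeline parallel; remaining ranks form the DP group
    plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='bf16')
    booster = Booster(plugin=plugin)

    model = model_fn()                                    # placeholder model factory
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    criterion = lambda outputs, inputs: outputs.loss      # placeholder: (outputs, inputs) -> scalar loss

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    # pipeline forward/backward now goes through the booster, which delegates to the plugin
    out = booster.execute_pipeline(data_iter, model, criterion, optimizer,
                                   return_loss=True, return_outputs=False)
    optimizer.step()
    optimizer.zero_grad()
    return out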

View File

@@ -0,0 +1,149 @@
from typing import Dict, List
import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
working_params: List[Parameter],
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
self.params = working_params
def check_local_overflow(self) -> bool:
for p in self.params:
if p.grad is not None and not torch.isfinite(p.grad).all():
return True
return False
class MixedPrecisionOptimizer(OptimizerWrapper):
def __init__(self,
optim: Optimizer,
precision: str = 'fp16',
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0):
super().__init__(optim)
if precision == 'fp16':
working_params = []
for group in self.optim.param_groups:
for p in group['params']:
working_params.append(p)
self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
elif precision == 'bf16':
self.mixed_precision = BF16MixedPrecisionMixin()
else:
raise ValueError(f'Unsupported precision: {precision}')
if max_norm > 0.0:
raise NotImplementedError('max_norm is not supported yet.')
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}
# create master weights
for group in self.optim.param_groups:
master_params = []
for p in group['params']:
if p.requires_grad:
master_p = p
if p.dtype != torch.float:
master_p = p.detach().float()
self.working_to_master_map[p] = master_p
self.master_to_working_map[master_p] = p
master_params.append(master_p)
group['params'] = master_params
def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
loss.backward(*args, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)
def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys():
p.grad = None
self.mixed_precision.pre_zero_grad()
return super().zero_grad(*args, **kwargs)
def _unscale_and_clip_grads(self, total_norm: float) -> None:
div_scale = 1.0
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()
if self.max_norm > 0.:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
p.grad.data.mul_(1. / div_scale)
def _compute_grad_norm(self) -> float:
if self.max_norm <= 0.:
return 0.
grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
if len(grads) == 0:
return 0.
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
return total_norm.item()
def step(self, *args, **kwargs):
if self.mixed_precision.should_skip_step():
self.zero_grad()
return
# prepare grads
for group in self.optim.param_groups:
for p in group['params']:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
                # move the (scaled) fp16 grad onto the fp32 master param and free the working grad
                if working_param.grad is not None:
                    p.grad = working_param.grad.data.float()
                    working_param.grad = None
total_norm = self._compute_grad_norm()
self._unscale_and_clip_grads(total_norm)
self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
for p in group['params']:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
working_param.data.copy_(p.data)
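For context, a minimal standalone sketch of how this wrapper behaves, assuming a single GPU and a toy fp16 model (not taken from the commit):

import torch
from torch.optim import SGD

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer

model = torch.nn.Linear(16, 4).half().cuda()    # working params stay in fp16
optim = MixedPrecisionOptimizer(SGD(model.parameters(), lr=0.1), precision='fp16')

x = torch.randn(8, 16, dtype=torch.float16, device='cuda')
loss = model(x).float().mean()

optim.zero_grad()
optim.backward(loss)    # scales the loss via the fp16 mixin before calling backward
optim.step()            # skips the step on overflow; otherwise unscales the grads,
                        # steps on the fp32 master params and copies them back to fp16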

View File

@@ -1,6 +1,6 @@
import warnings
from contextlib import contextmanager
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Iterator, List, Optional, Union
import torch
import torch.nn as nn
@@ -14,6 +14,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
from .plugin.pp_plugin_base import PipelinePluginBase
__all__ = ['Booster']
@@ -144,14 +145,15 @@ class Booster:
def execute_pipeline(self,
data_iter: Iterator,
model: nn.Module,
criterion: Callable[[torch.Tensor], torch.Tensor],
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optimizer,
return_loss: bool = True,
return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
# TODO: implement this method
return_outputs: bool = False) -> dict:
# run pipeline forward backward pass
# return loss or outputs if needed
pass
assert isinstance(self.plugin,
PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
"""Context manager to disable gradient synchronization across DP process groups.

View File

@@ -1,9 +1,10 @@
from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']
import torch
from packaging import version

View File

@@ -0,0 +1,316 @@
import random
from contextlib import nullcontext
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
class HybridParallelModule(ModelWrapper):
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
shardformer = ShardFormer(shard_config)
module, self.shared_params = shardformer.optimize(module)
# TODO(ver217): add input type cast
self.shared_param_process_groups = []
for shared_param in self.shared_params:
if len(shared_param) > 0:
                # remember the group so that sync_shared_params can all-reduce over it
                self.shared_param_process_groups.append(
                    self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
if precision == 'fp16':
module = module.half().cuda()
elif precision == 'bf16':
module = module.to(dtype=torch.bfloat16).cuda()
# TODO(ver217): support TP+DP
super().__init__(module)
def sync_shared_params(self):
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
param = shared_param[self.stage_manager.stage]
dist.all_reduce(param.grad, group=group)
def no_sync(self) -> Iterator[None]:
# no sync grads across data parallel
return nullcontext()
def sync_grads(self):
# sync grad across data parallel
if self.dp_group.size() == 1:
return
for p in self.module.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, group=self.dp_group)
def init_pipeline_optimizer(optim: Optimizer, model: Module):
    model_params = set(model.parameters())
    new_param_groups = []
    for group in optim.param_groups:
        # keep only the parameters still held by this pipeline stage's model shard
        params = [p for p in group['params'] if p in model_params]
        new_param_groups.append({**group, 'params': params})
optim.__setstate__({'param_groups': new_param_groups})
class HybridParallelOptimizer(MixedPrecisionOptimizer):
def __init__(self,
optim: Optimizer,
model: Module,
use_pipeline: bool,
precision: str = 'fp16',
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0):
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
hysteresis, max_scale, max_norm)
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
backoff_factor: float = .5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype,
overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group,
forced_dtype)
class HybridParallelPlugin(PipelinePluginBase):
def __init__(
self,
tp_size: int,
pp_size: int,
precision: str = 'fp16',
zero_stage: int = 0,
cpu_offload: bool = False,
enable_fused_normalization: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
) -> None:
super().__init__()
assert dist.get_world_size() % (
tp_size * pp_size
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
# TODO(ver217): support zero
        assert zero_stage == 0, 'zero is not supported yet'
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
self.enable_fused_normalization = enable_fused_normalization
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_fused_normalization=self.enable_fused_normalization)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
)
self.max_norm = max_norm
@property
def enable_pipeline_parallelism(self) -> bool:
return self.pp_size > 1
def supported_devices(self) -> List[str]:
return ['cuda']
def supported_precisions(self) -> List[str]:
return ['fp16', 'bf16']
def control_device(self) -> bool:
return True
def control_precision(self) -> bool:
return True
def support_no_sync(self) -> bool:
return False
def control_checkpoint_io(self) -> bool:
return True
def configure(
self,
model: Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
optimizer = HybridParallelOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config)
else:
optimizer = HybridParallelZeroOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
partition_grad=(self.zero_stage == 2),
cpu_offload=self.cpu_offload,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.amp_config)
return model, optimizer, criterion, dataloader, lr_scheduler
def execute_pipeline(self,
data_iter: Iterator,
model: HybridParallelModule,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer],
return_loss: bool = True,
return_outputs: bool = False) -> dict:
assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
# return loss or outputs if needed
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
return_outputs)
# model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer):
optimizer.sync_grad()
else:
model.sync_grads()
return outputs
def prepare_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""
        Prepare a dataloader for distributed training. The dataset is wrapped with a
        `torch.utils.data.distributed.DistributedSampler` and loaded through `torch.utils.data.DataLoader`.

        Args:
            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
            batch_size (int): The number of samples per batch on each data-parallel rank.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
            seed (int, optional): Random worker seed for sampling, defaults to 1024.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset,
num_replicas=self.pg_mesh.size(DP_AXIS),
rank=self.pg_mesh.coordinate(DP_AXIS),
shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
def get_checkpoint_io(self) -> CheckpointIO:
return None
def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
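A short sketch of how prepare_dataloader is meant to be used (the dataset, batch size and worker count are arbitrary; a distributed environment with world_size == dp_size * pp_size * tp_size is assumed):

import torch
from torch.utils.data import TensorDataset

from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='fp16')

dataset = TensorDataset(torch.randn(1024, 32), torch.randint(0, 10, (1024,)))
dataloader = plugin.prepare_dataloader(dataset,
                                       batch_size=16,
                                       shuffle=True,
                                       drop_last=True,
                                       num_workers=2)
# The sampler is built from the DP axis of the process group mesh, so only
# data-parallel ranks see different shards; TP/PP ranks read the same data.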

View File

@@ -0,0 +1,21 @@
from abc import abstractmethod
from typing import Any, Callable, Iterator
import torch
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .plugin_base import Plugin
class PipelinePluginBase(Plugin):
@abstractmethod
def execute_pipeline(self,
data_iter: Iterator,
model: ModelWrapper,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: OptimizerWrapper,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
pass

View File

@@ -7,9 +7,9 @@ from typing import Any, List, Optional, Union
import torch
import torch.distributed as dist
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
from version_parser.version import Version
from .stage_manager import PipelineStageManager

View File

@@ -223,9 +223,6 @@ class LlamaPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward(
@@ -311,9 +308,6 @@ class LlamaPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
transformer_outputs = LlamaPipelineForwards.llama_model_forward(
self.model,

View File

@@ -1,5 +1,5 @@
from types import MethodType
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union
import torch.nn as nn
from torch import Tensor
@@ -39,8 +39,8 @@ class ModelSharder(object):
self._preprocess()
        # get shared params before releasing unheld layers; this avoids misjudging shared params (None is None)
shared_params = self.policy.get_shared_params()
self._release_unheld_layers()
self._replace_module()
held_layers = self._release_unheld_layers()
self._replace_module(include=held_layers)
self._materialize()
self._postprocess()
return shared_params
@@ -51,7 +51,7 @@ class ModelSharder(object):
def _postprocess(self) -> None:
self.model = self.policy.postprocess()
def _replace_module(self,) -> None:
def _replace_module(self, include: Optional[Set[nn.Module]] = None) -> None:
r"""
Replace the module according to the policy, and replace the module one by one
@@ -64,8 +64,13 @@
param_replacement = module_description.param_replacement
sub_module_replacement = module_description.sub_module_replacement
method_replacement = module_description.method_replacement
self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement,
method_replacement, sub_module_replacement)
self._recursive_replace_layer(self.model,
layer_cls,
attr_replacement,
param_replacement,
method_replacement,
sub_module_replacement,
include=include)
def _recursive_replace_layer(
self,
@@ -75,6 +80,7 @@
param_replacement: List[Callable],
method_replacement: Dict[str, Callable],
sub_module_replacement: List[SubModuleReplacementDescription],
include: Optional[Set[nn.Module]] = None,
) -> None:
r"""
        Recursively replace the layers according to the policy
@@ -87,23 +93,30 @@
method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
            sub_module_replacement (List[SubModuleReplacementDescription]): The list of sub-module replacement descriptions defined in the policy
"""
# released layers are not shardable
can_replace_param_or_layer = include is None or module in include
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
(module.__class__ == origin_cls):
if attr_replacement is not None:
self._replace_attr(module, attr_replacement)
if param_replacement is not None:
if param_replacement is not None and can_replace_param_or_layer:
self._replace_param(module, param_replacement)
if method_replacement is not None:
self._replace_method(module, method_replacement)
if sub_module_replacement is not None:
if sub_module_replacement is not None and can_replace_param_or_layer:
self._replace_sub_module(module, sub_module_replacement)
for name, child in module.named_children():
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
sub_module_replacement)
self._recursive_replace_layer(child,
origin_cls,
attr_replacement,
param_replacement,
method_replacement,
sub_module_replacement,
include=include)
def _replace_attr(
self,
@@ -185,13 +198,15 @@
setattr_(org_layer, suffix, replace_layer)
def _release_unheld_layers(self) -> None:
def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
r"""
Release the unheld layers in the model
"""
if self.shard_config and self.shard_config.pipeline_stage_manager:
held_layers = self.policy.get_held_layers()
set_tensors_to_none(self.model, exclude=set(held_layers))
return set(held_layers)
return None
def _materialize(self) -> None:
r"""

View File

@@ -0,0 +1,99 @@
from contextlib import nullcontext
from typing import Optional
import torch
import torch.distributed as dist
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
try:
if init_method == 'lazy':
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='bf16')
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data_iter = iter([data])
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
output_key = list(outputs.keys())[0]
loss = criterion(outputs[output_key])
return loss
booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True, return_outputs=False)
optimizer.step()
except Exception as e:
return repr(e)
@parameterize('init_method', ['none', 'lazy'])
def check_3d_plugin(init_method: str = 'none', early_stop: bool = True):
"""check gemini plugin over model zoo
Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
is_support_meta = is_compatible_with_meta()
if not is_support_meta and init_method == 'lazy':
return
passed_models = []
failed_info = {} # (model_name, error) pair
# TODO(ver217): add more models
for name, (model_fn, data_gen_fn, output_transform_fn, _,
_) in model_zoo.get_sub_registry('transformers_llama_for_casual_lm').items():
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
else:
failed_info[name] = err
if early_stop:
break
if dist.get_rank() == 0:
print(f'Init method: {init_method}')
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_3d_plugin(early_stop=early_stop)
@rerun_if_address_is_in_use()
def test_3d_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)
if __name__ == '__main__':
    test_3d_plugin(early_stop=False)