mirror of https://github.com/hpcaitech/ColossalAI
[plugin] add 3d parallel plugin (#4295)
* [amp] add mixed precision optimizer
* [plugin] add 3d parallel plugin
* [booster] support pipeline
* [plugin] 3d parallel plugin support clip grad norm
* [shardformer] fix sharder and add plugin test
* [plugin] rename 3d parallel plugin
* [ci] support testmon core pkg change detection (#4305)
* [hotfix] debug testmon
* [hotfix] fix llama
* [hotfix] fix p2p bugs
* [hotfix] fix requirements

Branch: pull/4445/head
parent b3f5d7a3ba
commit 261eab02fb
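At a glance, this commit adds a HybridParallelPlugin that combines data, pipeline and tensor parallelism behind the existing Booster API. The sketch below mirrors the test added at the end of this diff and is illustrative only: the launch method, the model constructor, the data iterator and the `outputs.loss` access are placeholders and assumptions, not code from the commit.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})    # assumes 4 ranks launched via torchrun: dp=1, pp=2, tp=2
plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='bf16')
booster = Booster(plugin=plugin)

model = build_model()    # placeholder: any supported nn.Module, e.g. a LLaMA model from the test's model zoo
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda outputs, inputs: outputs.loss    # matches the new Callable[[Any, Any], torch.Tensor] signature

model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
booster.execute_pipeline(train_iter, model, criterion, optimizer, return_loss=True)    # train_iter: placeholder batch iterator
optimizer.step()
optimizer.zero_grad()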
@@ -0,0 +1,149 @@
from typing import Dict, List

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper

from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin


class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

    def __init__(self,
                 working_params: List[Parameter],
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
                         max_scale)
        self.params = working_params

    def check_local_overflow(self) -> bool:
        for p in self.params:
            if p.grad is not None and not torch.isfinite(p.grad).all():
                return True
        return False


class MixedPrecisionOptimizer(OptimizerWrapper):

    def __init__(self,
                 optim: Optimizer,
                 precision: str = 'fp16',
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0.0):
        super().__init__(optim)
        if precision == 'fp16':
            working_params = []
            for group in self.optim.param_groups:
                for p in group['params']:
                    working_params.append(p)
            self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
                                                                initial_scale=initial_scale,
                                                                min_scale=min_scale,
                                                                growth_factor=growth_factor,
                                                                backoff_factor=backoff_factor,
                                                                growth_interval=growth_interval,
                                                                hysteresis=hysteresis,
                                                                max_scale=max_scale)
        elif precision == 'bf16':
            self.mixed_precision = BF16MixedPrecisionMixin()
        else:
            raise ValueError(f'Unsupported precision: {precision}')
        if max_norm > 0.0:
            raise NotImplementedError('max_norm is not supported yet.')
        self.max_norm = max_norm
        self.working_to_master_map: Dict[Parameter, Tensor] = {}
        self.master_to_working_map: Dict[Tensor, Parameter] = {}

        # create master weights
        for group in self.optim.param_groups:
            master_params = []
            for p in group['params']:
                if p.requires_grad:
                    master_p = p
                    if p.dtype != torch.float:
                        master_p = p.detach().float()
                    self.working_to_master_map[p] = master_p
                    self.master_to_working_map[master_p] = p
                    master_params.append(master_p)
            group['params'] = master_params

    def backward(self, loss: Tensor, *args, **kwargs):
        loss = self.mixed_precision.pre_backward(loss)
        loss.backward(*args, **kwargs)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
        tensor.backward(grad)

    def zero_grad(self, *args, **kwargs):
        for p in self.working_to_master_map.keys():
            p.grad = None
        self.mixed_precision.pre_zero_grad()
        return super().zero_grad(*args, **kwargs)

    def _unscale_and_clip_grads(self, total_norm: float) -> None:
        div_scale = 1.0
        if self.mixed_precision is not None:
            div_scale = self.mixed_precision.get_grad_div_scale()

        if self.max_norm > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
            if clip > 1:
                div_scale = clip * div_scale

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad.data.mul_(1. / div_scale)

    def _compute_grad_norm(self) -> float:
        if self.max_norm <= 0.:
            return 0.
        grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
        if len(grads) == 0:
            return 0.
        device = grads[0].device
        # TODO(ver217): support tp
        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        return total_norm.item()

    def step(self, *args, **kwargs):
        if self.mixed_precision.should_skip_step():
            self.zero_grad()
            return
        # prepare grads
        for group in self.optim.param_groups:
            for p in group['params']:
                working_param = self.master_to_working_map[p]
                if p is working_param:
                    continue
                if working_param.grad is not None:
                    p.grad = working_param.grad.data.float()
                    working_param.grad = None
        total_norm = self._compute_grad_norm()
        self._unscale_and_clip_grads(total_norm)
        self.optim.step(*args, **kwargs)
        # update working params
        for group in self.optim.param_groups:
            for p in group['params']:
                working_param = self.master_to_working_map[p]
                if p is working_param:
                    continue
                working_param.data.copy_(p.data)
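A rough standalone sketch of how MixedPrecisionOptimizer behaves (normally the plugin below constructs it; the toy Linear model, the Adam inner optimizer and the single-GPU setting are assumptions, and bf16 is used here because the fp16 overflow check may expect an initialized distributed group):

import torch
from torch.optim import Adam

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer

# bf16 "working" weights; the wrapper keeps fp32 master copies for the inner optimizer to step on
working_model = torch.nn.Linear(8, 8).to(dtype=torch.bfloat16, device='cuda')
optim = MixedPrecisionOptimizer(Adam(working_model.parameters()), precision='bf16')

loss = working_model(torch.randn(4, 8, dtype=torch.bfloat16, device='cuda')).mean()
optim.backward(loss)    # pre_backward hook (loss scaling in the fp16 case), then loss.backward()
optim.step()            # casts working grads onto the fp32 master params, steps, copies results back
optim.zero_grad()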
@@ -1,6 +1,6 @@
 import warnings
 from contextlib import contextmanager
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, Union

 import torch
 import torch.nn as nn
@@ -14,6 +14,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
+from .plugin.pp_plugin_base import PipelinePluginBase

 __all__ = ['Booster']

@@ -144,14 +145,15 @@ class Booster:
     def execute_pipeline(self,
                          data_iter: Iterator,
                          model: nn.Module,
-                         criterion: Callable[[torch.Tensor], torch.Tensor],
+                         criterion: Callable[[Any, Any], torch.Tensor],
                          optimizer: Optimizer,
                          return_loss: bool = True,
-                         return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
-        # TODO: implement this method
+                         return_outputs: bool = False) -> dict:
         # run pipeline forward backward pass
         # return loss or outputs if needed
-        pass
+        assert isinstance(self.plugin,
+                          PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
+        return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)

     def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
         """Context manager to disable gradient synchronization across DP process groups.
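With this change the criterion handed to Booster.execute_pipeline receives the raw model outputs and the micro-batch inputs and must return a scalar loss. A hedged example of a matching criterion (the .loss attribute is an assumption for HuggingFace-style outputs, not something the commit prescribes):

def criterion(outputs, inputs):    # matches Callable[[Any, Any], torch.Tensor]
    return outputs.loss if hasattr(outputs, 'loss') else outputs[0].mean()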
@@ -1,9 +1,10 @@
 from .gemini_plugin import GeminiPlugin
+from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin

-__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
+__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']

 import torch
 from packaging import version
@@ -0,0 +1,316 @@
import random
from contextlib import nullcontext
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.zero.low_level import LowLevelZeroOptimizer

from .pp_plugin_base import PipelinePluginBase

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2


class HybridParallelModule(ModelWrapper):

    def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
        self.stage_manager = shard_config.pipeline_stage_manager
        self.dp_group = dp_group
        shardformer = ShardFormer(shard_config)
        module, self.shared_params = shardformer.optimize(module)
        # TODO(ver217): add input type cast
        self.shared_param_process_groups = []
        for shared_param in self.shared_params:
            if len(shared_param) > 0:
                self.shared_param_process_groups.append(
                    self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
        if precision == 'fp16':
            module = module.half().cuda()
        elif precision == 'bf16':
            module = module.to(dtype=torch.bfloat16).cuda()
        # TODO(ver217): support TP+DP
        super().__init__(module)

    def sync_shared_params(self):
        for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
            param = shared_param[self.stage_manager.stage]
            dist.all_reduce(param.grad, group=group)

    def no_sync(self) -> Iterator[None]:
        # no sync grads across data parallel
        return nullcontext()

    def sync_grads(self):
        # sync grad across data parallel
        if self.dp_group.size() == 1:
            return
        for p in self.module.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, group=self.dp_group)


def init_pipeline_optimizer(optim: Optimizer, model: Module):
    params = set(model.parameters())
    new_param_groups = []
    for group in optim.param_groups:
        params = [p for p in group['params'] if p in params]
        new_param_groups.append({**group, 'params': params})
    optim.__setstate__({'param_groups': new_param_groups})


class HybridParallelOptimizer(MixedPrecisionOptimizer):

    def __init__(self,
                 optim: Optimizer,
                 model: Module,
                 use_pipeline: bool,
                 precision: str = 'fp16',
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0):
        if use_pipeline:
            init_pipeline_optimizer(optim, model)
        super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
                         hysteresis, max_scale, max_norm)


class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):

    def __init__(
            self,
            optimizer: Optimizer,
            model: Module,
            use_pipeline: bool,
            initial_scale: int = 2**16,    # grad scaler config
            min_scale: int = 1,
            growth_factor: float = 2.,
            backoff_factor: float = .5,
            growth_interval: int = 2000,
            hysteresis: int = 2,
            max_scale: int = 2**24,
            clip_grad_norm: float = 0.0,    # grad clipping
            verbose: bool = False,
            reduce_bucket_size: int = 1024 * 1024,    # communication
            communication_dtype: Optional[torch.dtype] = None,
            overlap_communication: bool = True,
            partition_grad: bool = False,    # stage 2 flag
            cpu_offload: bool = False,    # cpu offload
            dp_process_group: Optional[ProcessGroup] = None,    # the dp pg for comm
            tp_process_group: Optional[ProcessGroup] = None,    # if using tp
            forced_dtype: Optional[torch.dtype] = None):
        if use_pipeline:
            init_pipeline_optimizer(optimizer, model)
        super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
                         hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype,
                         overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group,
                         forced_dtype)


class HybridParallelPlugin(PipelinePluginBase):

    def __init__(
        self,
        tp_size: int,
        pp_size: int,
        precision: str = 'fp16',
        zero_stage: int = 0,
        cpu_offload: bool = False,
        enable_fused_normalization: bool = False,
        num_microbatches: Optional[int] = None,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0,
    ) -> None:
        super().__init__()
        assert dist.get_world_size() % (
            tp_size * pp_size
        ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
        # TODO(ver217): support zero
        assert zero_stage == 0, 'zero is not support yet'
        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
        self.precision = precision
        self.zero_stage = zero_stage
        self.cpu_offload = cpu_offload
        self.enable_fused_normalization = enable_fused_normalization
        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
        self.stage_manager = None
        self.schedule = None
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
            assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
            assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
            self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                        pipeline_stage_manager=self.stage_manager,
                                        enable_tensor_parallelism=self.tp_size > 1,
                                        enable_fused_normalization=self.enable_fused_normalization)
        self.amp_config = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
        )
        self.max_norm = max_norm

    @property
    def enable_pipeline_parallelism(self) -> bool:
        return self.pp_size > 1

    def supported_devices(self) -> List[str]:
        return ['cuda']

    def supported_precisions(self) -> List[str]:
        return ['fp16', 'bf16']

    def control_device(self) -> bool:
        return True

    def control_precision(self) -> bool:
        return True

    def support_no_sync(self) -> bool:
        return False

    def control_checkpoint_io(self) -> bool:
        return True

    def configure(
        self,
        model: Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        if not isinstance(model, ModelWrapper):
            model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if self.zero_stage == 0:
                optimizer = HybridParallelOptimizer(optimizer,
                                                    model,
                                                    use_pipeline=self.enable_pipeline_parallelism,
                                                    precision=self.precision,
                                                    max_norm=self.max_norm,
                                                    **self.amp_config)
            else:
                optimizer = HybridParallelZeroOptimizer(optimizer,
                                                        model,
                                                        use_pipeline=self.enable_pipeline_parallelism,
                                                        partition_grad=(self.zero_stage == 2),
                                                        cpu_offload=self.cpu_offload,
                                                        dp_process_group=self.dp_group,
                                                        tp_process_group=self.tp_group,
                                                        verbose=True,
                                                        clip_grad_norm=self.max_norm,
                                                        **self.amp_config)
        return model, optimizer, criterion, dataloader, lr_scheduler

    def execute_pipeline(self,
                         data_iter: Iterator,
                         model: HybridParallelModule,
                         criterion: Callable[[Any, Any], torch.Tensor],
                         optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer],
                         return_loss: bool = True,
                         return_outputs: bool = False) -> dict:
        assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
        # return loss or outputs if needed
        ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
        with ctx:
            outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
                                                          return_outputs)
        # model.sync_shared_params()
        if isinstance(optimizer, HybridParallelZeroOptimizer):
            optimizer.sync_grad()
        else:
            model.sync_grads()
        return outputs

    def prepare_dataloader(self,
                           dataset,
                           batch_size,
                           shuffle=False,
                           seed=1024,
                           drop_last=False,
                           pin_memory=False,
                           num_workers=0,
                           **kwargs):
        r"""
        Prepare a dataloader for distributed training. The dataloader will be wrapped by
        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.

        Args:
            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
            seed (int, optional): Random worker seed for sampling, defaults to 1024.
            add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
                is not divisible by the batch size. If False and the size of dataset is not divisible by
                the batch size, then the last batch will be smaller, defaults to False.
            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
                `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

        Returns:
            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
        """
        _kwargs = kwargs.copy()
        sampler = DistributedSampler(dataset,
                                     num_replicas=self.pg_mesh.size(DP_AXIS),
                                     rank=self.pg_mesh.coordinate(DP_AXIS),
                                     shuffle=shuffle)

        # Deterministic dataloader
        def seed_worker(worker_id):
            worker_seed = seed
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(dataset,
                          batch_size=batch_size,
                          sampler=sampler,
                          worker_init_fn=seed_worker,
                          drop_last=drop_last,
                          pin_memory=pin_memory,
                          num_workers=num_workers,
                          **_kwargs)

    def get_checkpoint_io(self) -> CheckpointIO:
        return None

    def no_sync(self, model: Module) -> Iterator[None]:
        raise NotImplementedError
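For reference, the DP/PP/TP axes above split the launched ranks as world_size = dp_size * pp_size * tp_size, with dp_size derived from the other two. A small illustrative calculation (the numbers are an example, not from the commit):

# illustrative only: 8 ranks with tp_size=2 and pp_size=2 leave dp_size=2
world_size, tp_size, pp_size = 8, 2, 2
assert world_size % (tp_size * pp_size) == 0
dp_size = world_size // (tp_size * pp_size)    # -> 2
# ProcessGroupMesh(dp_size, pp_size, tp_size) then gives every rank a (dp, pp, tp) coordinate,
# and get_group_along_axis(...) builds the per-axis process groups the plugin uses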
@@ -0,0 +1,21 @@
from abc import abstractmethod
from typing import Any, Callable, Iterator

import torch

from colossalai.interface import ModelWrapper, OptimizerWrapper

from .plugin_base import Plugin


class PipelinePluginBase(Plugin):

    @abstractmethod
    def execute_pipeline(self,
                         data_iter: Iterator,
                         model: ModelWrapper,
                         criterion: Callable[[Any, Any], torch.Tensor],
                         optimizer: OptimizerWrapper,
                         return_loss: bool = True,
                         return_outputs: bool = False) -> dict:
        pass
@@ -7,9 +7,9 @@ from typing import Any, List, Optional, Union

 import torch
 import torch.distributed as dist
+from packaging.version import Version
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
-from version_parser.version import Version

 from .stage_manager import PipelineStageManager

@@ -223,9 +223,6 @@ class LlamaPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = LlamaPipelineForwards.llama_model_forward(
@@ -311,9 +308,6 @@ class LlamaPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = LlamaPipelineForwards.llama_model_forward(
             self.model,
@@ -1,5 +1,5 @@
 from types import MethodType
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Union

 import torch.nn as nn
 from torch import Tensor
@@ -39,8 +39,8 @@ class ModelSharder(object):
         self._preprocess()
         # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None)
         shared_params = self.policy.get_shared_params()
-        self._release_unheld_layers()
-        self._replace_module()
+        held_layers = self._release_unheld_layers()
+        self._replace_module(include=held_layers)
         self._materialize()
         self._postprocess()
         return shared_params
@@ -51,7 +51,7 @@ class ModelSharder(object):
     def _postprocess(self) -> None:
         self.model = self.policy.postprocess()

-    def _replace_module(self,) -> None:
+    def _replace_module(self, include: Optional[Set[nn.Module]] = None) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one

@@ -64,8 +64,13 @@ class ModelSharder(object):
             param_replacement = module_description.param_replacement
             sub_module_replacement = module_description.sub_module_replacement
             method_replacement = module_description.method_replacement
-            self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement,
-                                          method_replacement, sub_module_replacement)
+            self._recursive_replace_layer(self.model,
+                                          layer_cls,
+                                          attr_replacement,
+                                          param_replacement,
+                                          method_replacement,
+                                          sub_module_replacement,
+                                          include=include)

     def _recursive_replace_layer(
         self,
@@ -75,6 +80,7 @@ class ModelSharder(object):
         param_replacement: List[Callable],
         method_replacement: Dict[str, Callable],
         sub_module_replacement: List[SubModuleReplacementDescription],
+        include: Optional[Set[nn.Module]] = None,
     ) -> None:
         r"""
         Reverse the replace layer operation
@@ -87,23 +93,30 @@ class ModelSharder(object):
             method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
             sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
         """
+        # released layers are not shardable
+        can_replace_param_or_layer = include is None or module in include
         if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
            (module.__class__ == origin_cls):
             if attr_replacement is not None:
                 self._replace_attr(module, attr_replacement)

-            if param_replacement is not None:
+            if param_replacement is not None and can_replace_param_or_layer:
                 self._replace_param(module, param_replacement)

             if method_replacement is not None:
                 self._replace_method(module, method_replacement)

-            if sub_module_replacement is not None:
+            if sub_module_replacement is not None and can_replace_param_or_layer:
                 self._replace_sub_module(module, sub_module_replacement)

         for name, child in module.named_children():
-            self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
-                                          sub_module_replacement)
+            self._recursive_replace_layer(child,
+                                          origin_cls,
+                                          attr_replacement,
+                                          param_replacement,
+                                          method_replacement,
+                                          sub_module_replacement,
+                                          include=include)

     def _replace_attr(
         self,
@@ -185,13 +198,15 @@ class ModelSharder(object):

             setattr_(org_layer, suffix, replace_layer)

-    def _release_unheld_layers(self) -> None:
+    def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
         r"""
         Release the unheld layers in the model
         """
         if self.shard_config and self.shard_config.pipeline_stage_manager:
             held_layers = self.policy.get_held_layers()
             set_tensors_to_none(self.model, exclude=set(held_layers))
+            return set(held_layers)
+        return None

     def _materialize(self) -> None:
         r"""
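Taken together, the sharder changes gate parameter and sub-module replacement on membership in the set of layers held by the current pipeline stage. A simplified restatement of the resulting flow inside ModelSharder.shard() (paraphrased from the hunks above, not additional code from the commit):

# paraphrased: released (un-held) layers are skipped when sharding under pipeline parallelism
held_layers = self._release_unheld_layers()    # set of modules kept on this stage, or None without a stage manager
self._replace_module(include=held_layers)      # include=None means 'replace everywhere'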
@@ -0,0 +1,99 @@
from contextlib import nullcontext
from typing import Optional

import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
    try:
        if init_method == 'lazy':
            ctx = LazyInitContext()
        else:
            ctx = nullcontext()
        plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='bf16')
        booster = Booster(plugin=plugin)
        with ctx:
            model = model_fn()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        data = data_gen_fn()

        data = {
            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
            for k, v in data.items()
        }

        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

        data_iter = iter([data])

        def _criterion(outputs, inputs):
            outputs = output_transform_fn(outputs)
            output_key = list(outputs.keys())[0]
            loss = criterion(outputs[output_key])
            return loss

        booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True, return_outputs=False)
        optimizer.step()

    except Exception as e:
        return repr(e)


@parameterize('init_method', ['none', 'lazy'])
def check_3d_plugin(init_method: str = 'none', early_stop: bool = True):
    """check gemini plugin over model zoo

    Args:
        early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
    """
    is_support_meta = is_compatible_with_meta()
    if not is_support_meta and init_method == 'lazy':
        return

    passed_models = []
    failed_info = {}    # (model_name, error) pair

    # TODO(ver217): add more models
    for name, (model_fn, data_gen_fn, output_transform_fn, _,
               _) in model_zoo.get_sub_registry('transformers_llama_for_casual_lm').items():
        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()

        if err is None:
            passed_models.append(name)
        else:
            failed_info[name] = err
            if early_stop:
                break

    if dist.get_rank() == 0:
        print(f'Init method: {init_method}')
        print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
        print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
    assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])


def run_dist(rank, world_size, port, early_stop: bool = True):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_3d_plugin(early_stop=early_stop)


@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    spawn(run_dist, 4, early_stop=early_stop)


if __name__ == '__main__':
    test_gemini_plugin(early_stop=False)