mirror of https://github.com/hpcaitech/ColossalAI
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial examplepull/4504/head
parent
285fe7ba71
commit
27061426f7
|
@ -1,13 +1,11 @@
|
|||
import gc
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Callable, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -16,7 +14,6 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralC
|
|||
from colossalai.checkpoint_io.utils import (
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
get_shard_filename,
|
||||
load_shard_state_dict,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
|
@ -24,8 +21,7 @@ from colossalai.checkpoint_io.utils import (
|
|||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.zero.gemini import ZeroOptimizer
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.memory_tracer import MemStats
|
||||
|
||||
from .dp_plugin_base import DPPluginBase
|
||||
|
@ -132,11 +128,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
As there is communication when getting state dict, this must be called on all processes.
|
||||
"""
|
||||
|
||||
# If optimizer is wrapped, unwrap it.
|
||||
if isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = optimizer.unwrap()
|
||||
|
||||
assert isinstance(optimizer, ZeroOptimizer)
|
||||
assert isinstance(optimizer, GeminiOptimizer)
|
||||
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
|
@ -183,11 +175,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
if not os.path.isfile(checkpoint_index_file):
|
||||
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
|
||||
|
||||
# If optimizer is wrapped, unwrap it.
|
||||
if isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = optimizer.unwrap()
|
||||
|
||||
assert isinstance(optimizer, ZeroOptimizer)
|
||||
assert isinstance(optimizer, GeminiOptimizer)
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
|
@ -220,47 +208,6 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
super().save_lr_scheduler(lr_scheduler, checkpoint)
|
||||
|
||||
|
||||
class GeminiModel(ModelWrapper):
|
||||
|
||||
def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
|
||||
super().__init__(module)
|
||||
self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)
|
||||
|
||||
def unwrap(self):
|
||||
# as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
|
||||
return self.module
|
||||
|
||||
|
||||
class GeminiOptimizer(OptimizerWrapper):
|
||||
|
||||
def __init__(self,
|
||||
module: GeminiDDP,
|
||||
optimizer: Optimizer,
|
||||
zero_optim_config: dict,
|
||||
optim_kwargs: dict,
|
||||
verbose: bool = False) -> None:
|
||||
optimizer = zero_optim_wrapper(module,
|
||||
optimizer,
|
||||
optim_config=zero_optim_config,
|
||||
**optim_kwargs,
|
||||
verbose=verbose)
|
||||
super().__init__(optimizer)
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs):
|
||||
self.optim.backward(loss)
|
||||
|
||||
def clip_grad_by_norm(self,
|
||||
max_norm: Union[float, int],
|
||||
norm_type: Union[float, int] = 2,
|
||||
error_if_nonfinite: bool = False,
|
||||
*args,
|
||||
**kwargs) -> Tensor:
|
||||
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
|
||||
|
||||
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
|
||||
raise NotImplementedError('Gemini does not support clip_grad_by_value')
|
||||
|
||||
|
||||
class GeminiPlugin(DPPluginBase):
|
||||
"""
|
||||
Plugin for Gemini.
|
||||
|
@ -277,8 +224,20 @@ class GeminiPlugin(DPPluginBase):
|
|||
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
|
||||
|
||||
Args:
|
||||
device (torch.device): device to place the model.
|
||||
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
|
||||
chunk_config_dict (dict, optional): chunk configuration dictionary.
|
||||
chunk_init_device (torch.device, optional): device to initialize the chunk.
|
||||
placement_policy (str, optional): "static" and "auto". Defaults to "static".
|
||||
shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
|
||||
If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
|
||||
offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
|
||||
If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
|
||||
offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
|
||||
For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
|
||||
If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
|
||||
When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`.
|
||||
Defaults to 0.0.
|
||||
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
|
||||
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
|
||||
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
|
||||
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
|
||||
|
@ -310,8 +269,14 @@ class GeminiPlugin(DPPluginBase):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
placement_policy: str = "cpu",
|
||||
chunk_config_dict: Optional[dict] = None,
|
||||
chunk_init_device: Optional[torch.device] = None,
|
||||
placement_policy: str = "static",
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
|
||||
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
|
||||
precision: str = "fp16",
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
|
@ -335,8 +300,14 @@ class GeminiPlugin(DPPluginBase):
|
|||
super().__init__()
|
||||
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
|
||||
self.gemini_config = dict(
|
||||
device=(device or get_current_device()),
|
||||
chunk_config_dict=chunk_config_dict,
|
||||
chunk_init_device=(chunk_init_device or get_current_device()),
|
||||
placement_policy=placement_policy,
|
||||
shard_param_frac=shard_param_frac,
|
||||
offload_optim_frac=offload_optim_frac,
|
||||
offload_param_frac=offload_param_frac,
|
||||
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
|
||||
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
|
||||
pin_memory=pin_memory,
|
||||
force_outputs_fp32=force_outputs_fp32,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
|
@ -393,12 +364,15 @@ class GeminiPlugin(DPPluginBase):
|
|||
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
|
||||
|
||||
# wrap the model with Gemini
|
||||
model = GeminiModel(model, self.gemini_config, self.verbose)
|
||||
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
|
||||
|
||||
if optimizer is not None and \
|
||||
not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
|
||||
self.verbose)
|
||||
optimizer = GeminiOptimizer(optimizer,
|
||||
model.unwrap(),
|
||||
**self.zero_optim_config,
|
||||
**self.optim_kwargs,
|
||||
verbose=self.verbose)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
||||
|
|
|
@ -3,9 +3,15 @@ from typing import Optional
|
|||
import torch
|
||||
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from colossalai.tensor.const import TensorType
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.tensor.tensor_spec import ColoTensorSpec
|
||||
|
||||
from .colo_tensor import _convert_output
|
||||
|
||||
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
|
||||
|
||||
|
||||
def is_no_hook_op(func) -> bool:
|
||||
return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS
|
||||
|
||||
|
||||
def filter_colo_parameters(*args, **kwargs):
|
||||
|
@ -41,53 +47,25 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
|||
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
data: Optional[torch.Tensor] = None,
|
||||
requires_grad: bool = True,
|
||||
spec: ColoTensorSpec = None) -> 'ColoParameter':
|
||||
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter':
|
||||
if data is None:
|
||||
data = torch.empty(0)
|
||||
return torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||
|
||||
def __init__(self,
|
||||
data: Optional[torch.Tensor] = None,
|
||||
requires_grad: bool = True,
|
||||
spec: ColoTensorSpec = None) -> None:
|
||||
ColoTensor.__init__(self, data, spec)
|
||||
self._type = TensorType.MODEL
|
||||
# a list contains modules sharing this ColoParameter with others.
|
||||
self._shared_param_modules = []
|
||||
|
||||
@property
|
||||
def shared_param_modules(self):
|
||||
return self._shared_param_modules
|
||||
|
||||
@staticmethod
|
||||
def from_torch_tensor(tensor: torch.Tensor,
|
||||
requires_grad: bool = True,
|
||||
spec: ColoTensorSpec = None) -> 'ColoParameter':
|
||||
tensor = tensor.as_subclass(ColoParameter)
|
||||
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
|
||||
return tensor
|
||||
|
||||
def __repr__(self):
|
||||
return super(ColoParameter, self).__repr__()
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=..., kwargs=None):
|
||||
if ColoParamOpHookManager.has_hook():
|
||||
if not func.__name__.startswith('__'):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
params = filter_colo_parameters(*args, **kwargs)
|
||||
if len(params) > 0:
|
||||
with torch._C.DisableTorchFunction():
|
||||
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
|
||||
args, kwargs = replace_args(args, kwargs, new_args)
|
||||
ret = super().__torch_function__(func, types, args, kwargs)
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = ColoParamOpHookManager.post_op(params, ret)
|
||||
return ret
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
|
||||
params = filter_colo_parameters(*args, **kwargs)
|
||||
if len(params) > 0:
|
||||
with torch._C.DisableTorchFunction():
|
||||
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
|
||||
args, kwargs = replace_args(args, kwargs, new_args)
|
||||
ret = super().__torch_function__(func, types, args, kwargs)
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = ColoParamOpHookManager.post_op(params, ret)
|
||||
return _convert_output(ret, func)
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
|
@ -96,9 +74,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
|||
else:
|
||||
with torch._C.DisableTorchFunction():
|
||||
data = self.data.clone()
|
||||
tensor = ColoParameter(data,
|
||||
self.requires_grad,
|
||||
spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
|
||||
tensor = ColoParameter(data, self.requires_grad)
|
||||
memo[id(self)] = tensor
|
||||
return tensor
|
||||
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
import operator
|
||||
from copy import copy
|
||||
from functools import lru_cache, reduce
|
||||
from typing import Callable, Optional, Set
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Set
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.tensor.dist_spec_mgr import DistSpecManager
|
||||
from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
|
||||
from colossalai.tensor.process_group import ProcessGroup
|
||||
from colossalai.tensor.tensor_spec import ColoTensorSpec
|
||||
|
||||
from .const import TensorType
|
||||
from .op_wrapper import _COLOSSAL_OPS
|
||||
INPALCE_MAPPING = {
|
||||
torch.Tensor.add_: torch.Tensor.add,
|
||||
torch.Tensor.sub_: torch.Tensor.sub,
|
||||
torch.Tensor.mul_: torch.Tensor.mul,
|
||||
torch.Tensor.div_: torch.Tensor.div
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
|
@ -25,61 +22,37 @@ def _get_my_nowrap_functions() -> Set[Callable]:
|
|||
}
|
||||
|
||||
|
||||
def _convert_output(output, colo_spec: ColoTensorSpec):
|
||||
if type(output) == torch.Tensor:
|
||||
return ColoTensor.from_torch_tensor(output, colo_spec)
|
||||
def _convert(output):
|
||||
if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
|
||||
output.__class__ = ColoTensor
|
||||
elif isinstance(output, (list, tuple)):
|
||||
return type(output)(_convert_output(o, colo_spec) for o in output)
|
||||
else:
|
||||
output = type(output)(_convert(o) for o in output)
|
||||
return output
|
||||
|
||||
|
||||
def _convert_output(output, func):
|
||||
if func in _get_my_nowrap_functions():
|
||||
return output
|
||||
|
||||
|
||||
def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
|
||||
for elem in args:
|
||||
if isinstance(elem, ColoTensor):
|
||||
pg = elem.get_process_group()
|
||||
dp = elem.dist_spec
|
||||
return ColoTensorSpec(pg, dp)
|
||||
elif isinstance(elem, (list, tuple)):
|
||||
spec = _get_spec_from_args(elem, {})
|
||||
if spec is not None:
|
||||
return spec
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, ColoTensor):
|
||||
pg = v.get_process_group()
|
||||
dp = v.dist_spec
|
||||
return ColoTensorSpec(pg, dp)
|
||||
return None
|
||||
return _convert(output)
|
||||
|
||||
|
||||
class ColoTensor(torch.Tensor):
|
||||
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
|
||||
|
||||
The Colotensor can be initialized with a PyTorch tensor in the following ways.
|
||||
|
||||
>>> pg = ProcessGroup()
|
||||
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
|
||||
>>> # The tensor passed in is a tensor after sharding but not a global tensor.
|
||||
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
|
||||
>>> dims=[0],
|
||||
>>> num_partitions=[world_size])
|
||||
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
|
||||
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
It is only used to trigger the torch function hook.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
||||
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
|
||||
"""
|
||||
torch_major = int(torch.__version__.split('.')[0])
|
||||
torch_minor = int(torch.__version__.split('.')[1])
|
||||
|
||||
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
|
||||
def __new__(cls, data: torch.Tensor) -> 'ColoTensor':
|
||||
"""
|
||||
The signature of the __new__ has to be consistent with the torch.Tensor.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
||||
spec (TensorSpec, optional): the tensor spec of initialization.
|
||||
|
||||
Returns:
|
||||
ColoTensor: a ColoTensor wrappers the data.
|
||||
|
@ -88,86 +61,6 @@ class ColoTensor(torch.Tensor):
|
|||
data = torch.empty(0)
|
||||
return torch.Tensor._make_subclass(cls, data, data.requires_grad)
|
||||
|
||||
def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
|
||||
# If not set spec, use a DP process group and replicate dist spec
|
||||
if spec is None:
|
||||
self.has_initialized = False
|
||||
self.dist_spec = ReplicaSpec()
|
||||
self.compute_spec = None
|
||||
self.process_group = ProcessGroup()
|
||||
else:
|
||||
self.has_initialized = True
|
||||
self.dist_spec = spec.dist_attr
|
||||
self.compute_spec = spec.compute_attr
|
||||
if spec.pg is None:
|
||||
self.process_group = ProcessGroup()
|
||||
else:
|
||||
self.process_group = spec.pg
|
||||
|
||||
self._type = TensorType.NONMODEL
|
||||
|
||||
def has_compute_spec(self) -> bool:
|
||||
return self.compute_spec is not None
|
||||
|
||||
def is_model_data(self) -> bool:
|
||||
return self._type == TensorType.MODEL
|
||||
|
||||
def get_process_group(self) -> 'ProcessGroup':
|
||||
return self.process_group
|
||||
|
||||
def set_process_group(self, pg: ProcessGroup):
|
||||
"""set_process_group
|
||||
change the pg of the ColoTensor. Note that the valid use cases is limited.
|
||||
It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica.
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup): target pg
|
||||
|
||||
"""
|
||||
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
|
||||
# if the new pg is the same as the old pg, just returns
|
||||
if self.process_group == pg:
|
||||
return
|
||||
assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
|
||||
"Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
|
||||
assert self.dist_spec.placement.value == 'r', \
|
||||
"Can not set_process_group on a ColoTensor whose dist spec is not Replica"
|
||||
|
||||
self.process_group = pg
|
||||
|
||||
def get_tp_world_size(self) -> int:
|
||||
return self.process_group.tp_world_size()
|
||||
|
||||
def get_dp_world_size(self) -> int:
|
||||
"""get_dp_world_size
|
||||
get the dp world size of the tensor.
|
||||
|
||||
Returns:
|
||||
int: dp world size
|
||||
"""
|
||||
return self.process_group.dp_world_size()
|
||||
|
||||
def set_dist_spec(self, dist_spec: _DistSpec):
|
||||
"""set_dist_spec
|
||||
set dist spec and change the payloads.
|
||||
|
||||
Args:
|
||||
dist_spec (_DistSpec): target dist spec.
|
||||
"""
|
||||
assert isinstance(dist_spec, _DistSpec)
|
||||
assert self.process_group is not None
|
||||
self._redistribute(dist_spec)
|
||||
|
||||
def set_tensor_spec(self, dist_spec, compute_spec):
|
||||
if dist_spec is not None:
|
||||
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
|
||||
self.set_dist_spec(dist_spec)
|
||||
if compute_spec is not None:
|
||||
self.compute_spec = compute_spec
|
||||
|
||||
def has_compute_pattern(self, compute_pattern):
|
||||
return self.compute_spec.compute_pattern == compute_pattern
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
|
@ -175,9 +68,6 @@ class ColoTensor(torch.Tensor):
|
|||
|
||||
if not all(issubclass(cls, t) for t in types):
|
||||
return NotImplemented
|
||||
global _COLOSSAL_OPS
|
||||
if func in _COLOSSAL_OPS:
|
||||
func = _COLOSSAL_OPS[func]
|
||||
|
||||
if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
|
||||
# in order to trigger pre-op hook in the forward of checkpoint module
|
||||
|
@ -189,94 +79,16 @@ class ColoTensor(torch.Tensor):
|
|||
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
|
||||
return backward_tensor.backward(**tensor_kwargs)
|
||||
|
||||
# replace the in-place function
|
||||
if func in INPALCE_MAPPING:
|
||||
func = INPALCE_MAPPING[func]
|
||||
# set the 'inplace' kwargs to False
|
||||
if 'inplace' in kwargs:
|
||||
kwargs['inplace'] = False
|
||||
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = func(*args, **kwargs)
|
||||
if func in _get_my_nowrap_functions():
|
||||
return ret
|
||||
else:
|
||||
colo_spec = _get_spec_from_args(args, kwargs)
|
||||
return _convert_output(ret, colo_spec)
|
||||
|
||||
def __repr__(self):
|
||||
output_list = [super(ColoTensor, self).__repr__()]
|
||||
output_list.append(str(self.process_group))
|
||||
output_list.append(str(self.dist_spec))
|
||||
if self.compute_spec is not None:
|
||||
output_list.append(str(self.compute_spec))
|
||||
return "\n".join(output_list)
|
||||
|
||||
def _redistribute(self, dist_spec: _DistSpec) -> None:
|
||||
"""_redistribute
|
||||
Note the function will not handle the logic of backward propagation!
|
||||
It is used during model tensor initializations as an internal function.
|
||||
|
||||
Args:
|
||||
dist_spec (_DistSpec): the target dist. spec.
|
||||
"""
|
||||
assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
|
||||
with DistSpecManager.no_grad():
|
||||
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
|
||||
self.dist_spec = dist_spec
|
||||
|
||||
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
|
||||
"""redistribute
|
||||
Redistribute the tensor among processes. The rule is like this:
|
||||
|
||||
1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
|
||||
DP process group not changed.
|
||||
|
||||
2. If the pg is not not None and not equal to the current process group.
|
||||
First, convert the tensor as replicated among the TP process group.
|
||||
Second, reset the process group to the new pg.
|
||||
Third, convert the tensor (new replicated both among the tp process group) to the new dist_spec.
|
||||
|
||||
Args:
|
||||
dist_spec (_DistSpec): the new dist spec.
|
||||
pg (Optional[ProcessGroup], optional): the new process group . Defaults to None.
|
||||
|
||||
Returns:
|
||||
ColoTensor: a redistributed colotensor
|
||||
"""
|
||||
if pg is not None and pg != self.get_process_group():
|
||||
# if the pg is not equal, convert the current tensor to replicated
|
||||
handled = self.redistribute(ReplicaSpec())
|
||||
else:
|
||||
handled = self
|
||||
pg = self.process_group
|
||||
|
||||
ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
|
||||
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
|
||||
|
||||
def to_replicate_(self):
|
||||
"""to_replicate_
|
||||
|
||||
an inline member function, converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
self._redistribute(dist_spec=ReplicaSpec())
|
||||
|
||||
def to_replicate(self) -> 'ColoTensor':
|
||||
"""to_replicate
|
||||
|
||||
converting dist spec of the tensor to ReplicaSpec()
|
||||
"""
|
||||
return self.redistribute(ReplicaSpec())
|
||||
|
||||
@staticmethod
|
||||
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
|
||||
"""from_torch_tensor
|
||||
|
||||
A static method builds a `ColoTensor` from a PyTorch Tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank not a global tensor.
|
||||
spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ColoTensor: a ColoTensor
|
||||
"""
|
||||
tensor = tensor.as_subclass(ColoTensor)
|
||||
tensor.__init__(tensor, spec=spec)
|
||||
return tensor
|
||||
return _convert_output(ret, func)
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
if id(self) in memo:
|
||||
|
@ -284,60 +96,6 @@ class ColoTensor(torch.Tensor):
|
|||
else:
|
||||
with torch._C.DisableTorchFunction():
|
||||
data = self.data.clone()
|
||||
tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
|
||||
tensor = ColoTensor(data)
|
||||
memo[id(self)] = tensor
|
||||
return tensor
|
||||
|
||||
# override builtin functions which must use tensor in replicate placement #
|
||||
|
||||
def size_local(self, *args) -> torch.Size:
|
||||
with torch._C.DisableTorchFunction():
|
||||
return super().size(*args)
|
||||
|
||||
def size_global(self, *args) -> torch.Size:
|
||||
"""size_global
|
||||
|
||||
override the torch building size()
|
||||
the shape passed in must be in a replicate placement.
|
||||
|
||||
Returns:
|
||||
torch.Size: the global tensor shape
|
||||
"""
|
||||
if self.is_replicate():
|
||||
return self.size_local(*args)
|
||||
spec = self.dist_spec
|
||||
dims = spec.dims
|
||||
num_partitions = spec.num_partitions
|
||||
# import inspect
|
||||
# print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
|
||||
size_list = list(self.size_local())
|
||||
for dim, num_partition in zip(dims, num_partitions):
|
||||
size_list[dim] *= num_partition
|
||||
if args == ():
|
||||
return torch.Size(size_list)
|
||||
else:
|
||||
return size_list[args[0]]
|
||||
|
||||
def numel_global(self):
|
||||
"""Returns the number of elements in the tensor when it's replicated.
|
||||
"""
|
||||
return reduce(operator.mul, self.size_global(), 1)
|
||||
|
||||
# Some API for dist spec check
|
||||
|
||||
def is_replicate(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
|
||||
or (len(self.dist_spec.num_partitions) == 1
|
||||
and self.dist_spec.num_partitions[0] == 1) \
|
||||
or (self.process_group.tp_world_size() == 1)
|
||||
|
||||
def is_shard_1dcol(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
|
||||
|
||||
def is_shard_1drow(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
||||
|
||||
def is_sharded(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD
|
||||
|
|
|
@ -3,9 +3,7 @@ from contextlib import contextmanager
|
|||
from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from colossalai.tensor.tensor_spec import ColoTensorSpec
|
||||
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
class ColoParamOpHook(ABC):
|
||||
|
@ -82,26 +80,18 @@ class ColoParamOpHookManager:
|
|||
@staticmethod
|
||||
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
|
||||
ColoParamOpHookManager._trigger_pre_forward(params)
|
||||
grad_args, rear_args = _get_grad_args(*args)
|
||||
colo_info = _get_colo_tensors_info(*grad_args)
|
||||
rets = PreFwdPostBwd.apply(params, *grad_args)
|
||||
update_args = _update_colo_tensors(colo_info, *rets)
|
||||
if rear_args is None:
|
||||
return update_args
|
||||
else:
|
||||
arg_zero = (tuple(update_args),)
|
||||
return arg_zero + rear_args
|
||||
# auto grad function can only recognize torch.Tensor, thus we have to flatten the input
|
||||
# if one of the input requires grad, all the output will be treated as requires grad
|
||||
# and will have grad fn even the corresponding input does not require grad
|
||||
# we have to extract tensors requiring grad into flat list and then merge them back
|
||||
grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
|
||||
new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
|
||||
return _merge_args(new_grad_args, other_args, grad_flags, spec)
|
||||
|
||||
@staticmethod
|
||||
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
|
||||
ColoParamOpHookManager._trigger_post_forward(params)
|
||||
colo_info = _get_colo_tensors_info(arg)
|
||||
ret = PostFwdPreBwd.apply(params, arg)
|
||||
res = _update_colo_tensors(colo_info, ret)
|
||||
if len(res) == 1:
|
||||
return res[0]
|
||||
else:
|
||||
return res
|
||||
return PostFwdPreBwd.apply(params, arg)
|
||||
|
||||
@staticmethod
|
||||
def has_hook() -> bool:
|
||||
|
@ -141,57 +131,24 @@ def _is_grad_tensor(obj) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _has_grad_tensor(obj) -> bool:
|
||||
if isinstance(obj, tuple) or isinstance(obj, list):
|
||||
for x in obj:
|
||||
if _has_grad_tensor(x):
|
||||
return True
|
||||
return False
|
||||
elif isinstance(obj, dict):
|
||||
for x in obj.values():
|
||||
if _has_grad_tensor(x):
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
return _is_grad_tensor(obj)
|
||||
|
||||
|
||||
def _get_grad_args(*args):
|
||||
# if there is no grad tensors, do nothing
|
||||
if not _has_grad_tensor(args):
|
||||
return args, None
|
||||
# returns the identical args if there is a grad tensor
|
||||
for obj in args:
|
||||
if _is_grad_tensor(obj):
|
||||
return args, None
|
||||
# otherwise, the first argument should be a tuple of grad tensors
|
||||
# if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
|
||||
arg_zero = args[0]
|
||||
if not isinstance(arg_zero, tuple):
|
||||
raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
|
||||
check_grad_flag = False
|
||||
for obj in arg_zero:
|
||||
check_grad_flag |= _is_grad_tensor(obj)
|
||||
if not check_grad_flag:
|
||||
raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
|
||||
return arg_zero, args[1:]
|
||||
|
||||
|
||||
def _get_colo_tensors_info(*args) -> list:
|
||||
info = []
|
||||
for arg in args:
|
||||
if isinstance(arg, ColoTensor):
|
||||
info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
|
||||
def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
|
||||
flat_args, spec = tree_flatten(args)
|
||||
grad_args = []
|
||||
other_args = []
|
||||
grad_flags = []
|
||||
for arg in flat_args:
|
||||
flag = _is_grad_tensor(arg)
|
||||
grad_flags.append(flag)
|
||||
if flag:
|
||||
grad_args.append(arg)
|
||||
else:
|
||||
info.append(None)
|
||||
return info
|
||||
other_args.append(arg)
|
||||
assert len(grad_args) > 0
|
||||
return grad_args, other_args, grad_flags, spec
|
||||
|
||||
|
||||
def _update_colo_tensors(info, *args) -> list:
|
||||
ret = []
|
||||
for t_info, arg in zip(info, args):
|
||||
if t_info is not None:
|
||||
t_cls, spec = t_info
|
||||
arg = t_cls.from_torch_tensor(arg, spec=spec)
|
||||
ret.append(arg)
|
||||
return ret
|
||||
def _merge_args(grad_args, other_args, grad_flags, spec):
|
||||
grad_iter = iter(grad_args)
|
||||
other_iter = iter(other_args)
|
||||
flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
|
||||
return tree_unflatten(flat_args, spec)
|
||||
|
|
|
@ -2,8 +2,7 @@ from .gemini import (
|
|||
ColoInitContext,
|
||||
GeminiAdamOptimizer,
|
||||
GeminiDDP,
|
||||
ZeroDDP,
|
||||
ZeroOptimizer,
|
||||
GeminiOptimizer,
|
||||
get_static_torch_model,
|
||||
post_process_colo_init_ctx,
|
||||
)
|
||||
|
@ -11,6 +10,6 @@ from .low_level import LowLevelZeroOptimizer
|
|||
from .wrapper import zero_model_wrapper, zero_optim_wrapper
|
||||
|
||||
__all__ = [
|
||||
'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
|
||||
'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
|
||||
'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
|
||||
]
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
|
||||
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
|
||||
from .gemini_ddp import GeminiDDP, ZeroDDP
|
||||
from .gemini_ddp import GeminiDDP
|
||||
from .gemini_mgr import GeminiManager
|
||||
from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
|
||||
from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
|
||||
from .utils import get_static_torch_model
|
||||
|
||||
__all__ = [
|
||||
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
|
||||
'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
|
||||
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP',
|
||||
'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
|
||||
]
|
||||
|
|
|
@ -4,8 +4,8 @@ from typing import Dict, List, Optional
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
|
@ -55,7 +55,7 @@ class Chunk:
|
|||
|
||||
def __init__(self,
|
||||
chunk_size: int,
|
||||
process_group: ColoProcessGroup,
|
||||
process_group: ProcessGroup,
|
||||
dtype: torch.dtype,
|
||||
init_device: Optional[torch.device] = None,
|
||||
cpu_shard_init: bool = False,
|
||||
|
@ -69,7 +69,7 @@ class Chunk:
|
|||
|
||||
Args:
|
||||
chunk_size (int): the number of elements in the chunk
|
||||
process_group (ColoProcessGroup): the process group of this chunk
|
||||
process_group (ProcessGroup): the process group of this chunk
|
||||
dtype (torch.dtype): the data type of the chunk
|
||||
init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
|
||||
The default value is None, which is the current GPU
|
||||
|
@ -83,7 +83,7 @@ class Chunk:
|
|||
self.chunk_size = chunk_size
|
||||
self.utilized_size = 0
|
||||
|
||||
self.torch_pg = process_group.dp_process_group()
|
||||
self.torch_pg = process_group
|
||||
self.pg_size = dist.get_world_size(self.torch_pg)
|
||||
self.pg_rank = dist.get_rank(self.torch_pg)
|
||||
|
||||
|
@ -218,7 +218,7 @@ class Chunk:
|
|||
return False
|
||||
else:
|
||||
return self.tensor_state_cnter[TensorState.HOLD] + \
|
||||
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
|
||||
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
|
||||
|
||||
@property
|
||||
def can_reduce(self):
|
||||
|
|
|
@ -2,8 +2,9 @@ from collections import deque
|
|||
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .chunk import Chunk, ChunkFullError, TensorState
|
||||
|
@ -27,16 +28,17 @@ class ChunkManager:
|
|||
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
|
||||
v['init_device'] = self.device
|
||||
|
||||
self.chunk_groups: Dict[str, Deque] = dict()
|
||||
self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
|
||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
|
||||
self.accessed_chunks: Set[Chunk] = set()
|
||||
self.accessed_mem: int = 0
|
||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||
|
||||
def register_tensor(self,
|
||||
tensor: ColoTensor,
|
||||
tensor: torch.Tensor,
|
||||
group_type: str,
|
||||
config_key: int,
|
||||
process_group: ProcessGroup,
|
||||
cpu_offload: bool = False,
|
||||
pin_memory: bool = False) -> None:
|
||||
"""
|
||||
|
@ -51,7 +53,7 @@ class ChunkManager:
|
|||
pin_memory: whether the chunk is pinned in the cpu memory
|
||||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
|
||||
assert isinstance(tensor, torch.Tensor), "Please feed Tensor to this ChunkManager"
|
||||
assert config_key in self.dp_degree_chunk_size_dict
|
||||
|
||||
chunk_size = self.dp_degree_chunk_size_dict[config_key]
|
||||
|
@ -73,12 +75,12 @@ class ChunkManager:
|
|||
|
||||
if tensor.numel() > chunk_size:
|
||||
chunk_size = tensor.numel()
|
||||
dp_size = tensor.get_dp_world_size()
|
||||
dp_size = dist.get_world_size(process_group)
|
||||
chunk_size = chunk_size + (-chunk_size % dp_size)
|
||||
|
||||
chunk = Chunk(
|
||||
chunk_size=chunk_size,
|
||||
process_group=tensor.process_group,
|
||||
process_group=process_group,
|
||||
dtype=tensor.dtype,
|
||||
cpu_shard_init=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
|
@ -220,7 +222,7 @@ class ChunkManager:
|
|||
msg.append(f'[{i}] {chunk}\n')
|
||||
return ''.join(msg)
|
||||
|
||||
def __get_chunk_group(self, group_name: str) -> Deque:
|
||||
def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
|
||||
"""Register a chunk group.
|
||||
"""
|
||||
if group_name not in self.chunk_groups:
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
|
|||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
|
@ -59,7 +60,7 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
|
|||
return left + acc
|
||||
|
||||
|
||||
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
|
||||
def _tensor_numel(local_param: ColoParameter) -> int:
|
||||
"""_tensor_numel
|
||||
|
||||
Get the number of elements of a tensor.
|
||||
|
@ -71,15 +72,12 @@ def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
|
|||
Returns:
|
||||
int: the number of elements.
|
||||
"""
|
||||
if strict_ddp_flag and type(local_param) is ColoParameter:
|
||||
return local_param.numel_global()
|
||||
else:
|
||||
# if local_param is not ColoParameter, we assume it's replicated
|
||||
return local_param.numel()
|
||||
# TODO(ver217): support dtensor here
|
||||
return local_param.numel()
|
||||
|
||||
|
||||
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
|
||||
strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
|
||||
process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
|
||||
"""classify_params_by_dp_degree
|
||||
|
||||
Classify the parameters by their dp degree
|
||||
|
@ -97,13 +95,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
|
|||
# assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
|
||||
if is_ddp_ignored(param):
|
||||
continue
|
||||
|
||||
if strict_ddp_flag or type(param) is not ColoParameter:
|
||||
# if model is not initialized with ColoInitContext, we assume it's replicated
|
||||
# TODO(ver217): integrate DTensor
|
||||
param_key = dist.get_world_size()
|
||||
else:
|
||||
param_key = param.process_group.dp_world_size()
|
||||
param_key = dist.get_world_size(process_group)
|
||||
|
||||
if param_key not in params_dict:
|
||||
params_dict[param_key] = []
|
||||
|
@ -119,6 +111,7 @@ def search_chunk_configuration(
|
|||
min_chunk_size_m: float = 32,
|
||||
filter_exlarge_params: bool = True,
|
||||
strict_ddp_flag: bool = False,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
|
||||
"""search_chunk_configuration
|
||||
|
||||
|
@ -149,7 +142,7 @@ def search_chunk_configuration(
|
|||
min_chunk_size = round(min_chunk_size_m * 1024**2)
|
||||
assert search_range >= 0
|
||||
|
||||
params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
|
||||
params_dict = classify_params_by_dp_degree(param_order, process_group)
|
||||
size_lcm = np.lcm.reduce(list(params_dict.keys()))
|
||||
config_dict: Dict[int, Dict] = dict()
|
||||
total_param_size = 0
|
||||
|
@ -157,7 +150,7 @@ def search_chunk_configuration(
|
|||
size_dict: Dict[int, List[int]] = dict()
|
||||
for dp_degree in params_dict:
|
||||
params_list = params_dict[dp_degree]
|
||||
size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
|
||||
size_list = [_tensor_numel(p) for p in params_list]
|
||||
group_acc_size = sum(size_list)
|
||||
total_param_size += group_acc_size
|
||||
|
||||
|
|
|
@ -2,19 +2,20 @@ import itertools
|
|||
from collections import OrderedDict
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
|
||||
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
from colossalai.checkpoint_io.utils import calculate_tensor_size
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyTensor
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.tensor import ReplicaSpec
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.nn.parallel.data_parallel import _cast_float, free_storage
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import get_current_device, is_ddp_ignored
|
||||
|
||||
|
@ -30,14 +31,13 @@ except ImportError:
|
|||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
|
||||
__all__ = [
|
||||
'ZeroDDP',
|
||||
'GeminiDDP',
|
||||
]
|
||||
|
||||
|
||||
class ZeroDDP(ColoDDP):
|
||||
"""ZeRO DDP for ColoTensor.
|
||||
Warning: Nested ZeroDDP is not supported now.
|
||||
class GeminiDDP(ModelWrapper):
|
||||
"""ZeRO DDP.
|
||||
Warning: Nested GeminiDDP is not supported now.
|
||||
It is designed to be used with ChunkManager and GeminiManager.
|
||||
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
|
||||
|
||||
|
@ -54,20 +54,54 @@ class ZeroDDP(ColoDDP):
|
|||
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
gemini_manager: GeminiManager,
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
scatter_after_inference: bool = True,
|
||||
mixed_precision: torch.dtype = torch.float16) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
chunk_config_dict: Optional[dict] = None,
|
||||
chunk_init_device: torch.device = torch.device('cpu'),
|
||||
placement_policy: str = "static",
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
|
||||
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
|
||||
search_range_m: int = 32, # chunk search options
|
||||
hidden_dim: Optional[int] = None, # chunk search options
|
||||
min_chunk_size_m: float = 32, # chunk search options
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
scatter_after_inference: bool = True,
|
||||
mixed_precision: torch.dtype = torch.float16,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstats: Optional[MemStats] = None, # genimi memory stats
|
||||
verbose: bool = False) -> None:
|
||||
assert mixed_precision in (torch.float16, torch.bfloat16)
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
if chunk_config_dict is not None:
|
||||
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
|
||||
else:
|
||||
# some ugly hotfix for the compatibility with Lightning
|
||||
if search_range_m is None:
|
||||
search_range_m = 32
|
||||
self.chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=chunk_init_device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_m=search_range_m,
|
||||
min_chunk_size_m=min_chunk_size_m,
|
||||
strict_ddp_flag=strict_ddp_mode,
|
||||
process_group=process_group,
|
||||
verbose=verbose)
|
||||
self.gemini_manager = GeminiManager(placement_policy,
|
||||
self.chunk_manager,
|
||||
memstats,
|
||||
shard_param_frac=shard_param_frac,
|
||||
offload_optim_frac=offload_optim_frac,
|
||||
offload_param_frac=offload_param_frac,
|
||||
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
|
||||
steady_cuda_cap_ratio=steady_cuda_cap_ratio)
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = GeminiZeROHook(gemini_manager)
|
||||
self.fp32_params: List[ColoTensor] = list()
|
||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
||||
self.fp32_params: List[torch.Tensor] = list()
|
||||
self.fp16_params: List[ColoParameter] = list()
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
||||
|
@ -75,6 +109,7 @@ class ZeroDDP(ColoDDP):
|
|||
self.name2param: Dict[str, nn.Parameter] = dict()
|
||||
self.scatter_after_inference = scatter_after_inference
|
||||
self.mixed_precision = mixed_precision
|
||||
self.dp_process_group = process_group or _get_default_group()
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
|
@ -88,20 +123,67 @@ class ZeroDDP(ColoDDP):
|
|||
for p in module.parameters():
|
||||
param_order.append(p)
|
||||
|
||||
self._init_chunks(param_order=param_order,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
cpu_offload=self.gemini_manager.policy_name != 'cuda',
|
||||
pin_memory=pin_memory)
|
||||
|
||||
for name, param in module.named_parameters():
|
||||
self.param2name[param] = name
|
||||
for m_name, m_var in module.named_modules():
|
||||
for p_name, p_var in m_var.named_parameters(recurse=False):
|
||||
param_name = m_name + '.' + p_name if m_name else p_name
|
||||
self.name2param[param_name] = p_var
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
|
||||
self._init_chunks(param_order=param_order,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
cpu_offload=self.gemini_manager.policy_name != 'cuda',
|
||||
pin_memory=pin_memory)
|
||||
super().__init__(module)
|
||||
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
|
||||
self._cast_buffers()
|
||||
# register grad hook
|
||||
for p in module.parameters():
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
if p.requires_grad:
|
||||
p.register_hook(partial(self.grad_handle, p))
|
||||
|
||||
def parameters(self, recurse: bool = True):
|
||||
return self.module.parameters(recurse)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True):
|
||||
return self.module.named_parameters(prefix, recurse)
|
||||
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True):
|
||||
return self.module.named_buffers(prefix, recurse)
|
||||
|
||||
def named_children(self):
|
||||
return self.module.named_children()
|
||||
|
||||
def named_modules(self,
|
||||
memo: Optional[Set[torch.nn.Module]] = None,
|
||||
prefix: str = '',
|
||||
remove_duplicate: bool = True):
|
||||
return self.module.named_modules(memo, prefix, remove_duplicate)
|
||||
|
||||
@staticmethod
|
||||
def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
|
||||
"""Sets parameters to be ignored by DDP.
|
||||
This method must be called before initializing ColoDDP.
|
||||
|
||||
Example:
|
||||
>>> params_to_ignore = []
|
||||
>>> for p in module.parameters():
|
||||
>>> if should_ignore(p):
|
||||
>>> params_to_ignore.append(p)
|
||||
>>> ColoDDP.set_params_to_ignore(params_to_ignore)
|
||||
>>> module = ColoDDP(module)
|
||||
|
||||
Args:
|
||||
params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
|
||||
"""
|
||||
for p in params_to_ignore:
|
||||
p._ddp_to_ignore = True
|
||||
|
||||
def unwrap(self):
|
||||
# as save/load state dict is overwrited, only return self
|
||||
return self
|
||||
|
||||
def _get_non_persistent_buffers_set(self,
|
||||
module,
|
||||
|
@ -207,7 +289,7 @@ class ZeroDDP(ColoDDP):
|
|||
error_params.append(self.param2name[param])
|
||||
error_str = "\n\t".join(error_params)
|
||||
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
|
||||
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
|
||||
"The most possible reason is that the model is not compatible with GeminiDDP.\n",
|
||||
f"{error_str}")
|
||||
self._setup_grads_ptr()
|
||||
self._logger.debug(
|
||||
|
@ -227,6 +309,7 @@ class ZeroDDP(ColoDDP):
|
|||
self._post_backward()
|
||||
|
||||
def grad_handle(self, p, grad):
|
||||
setattr(p, "_gemini_reduced", True)
|
||||
empty_grad = torch.empty_like(grad)
|
||||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
|
@ -533,7 +616,7 @@ class ZeroDDP(ColoDDP):
|
|||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.optim_update()
|
||||
chunk_16.payload.copy_(chunk_32.payload)
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
|
@ -557,17 +640,11 @@ class ZeroDDP(ColoDDP):
|
|||
unexpected_keys.append(key)
|
||||
|
||||
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
|
||||
ddp_pg = ColoProcessGroup()
|
||||
dp_world_size = dist.get_world_size(self.dp_process_group)
|
||||
for p in param_order.generate():
|
||||
self._preprocess_param(p)
|
||||
assert type(p) is ColoParameter
|
||||
|
||||
# gather sharded parameters in the strict ddp mode
|
||||
if strict_ddp_mode:
|
||||
if not p.is_replicate():
|
||||
p.set_dist_spec(ReplicaSpec())
|
||||
p.set_process_group(pg=ddp_pg)
|
||||
|
||||
# ignore the parameters with no gradient
|
||||
if not p.requires_grad:
|
||||
self.set_params_to_ignore([p])
|
||||
|
@ -578,38 +655,37 @@ class ZeroDDP(ColoDDP):
|
|||
continue
|
||||
|
||||
# create a fp32 parameter
|
||||
fp32_data = p.data.float()
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
fp32_p = p.data.float()
|
||||
# create a fp16 parameter
|
||||
p.data = p.data.to(self.mixed_precision)
|
||||
|
||||
# register the fp16 parameter and fp32 parameter in the chunk manager
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
self.chunk_manager.register_tensor(tensor=p,
|
||||
group_type='fp16_param',
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
self.chunk_manager.register_tensor(tensor=fp32_p,
|
||||
group_type='fp32_param',
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
self.fp16_params.append(p)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
|
||||
self.chunk_manager.close_all_groups()
|
||||
|
||||
self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
|
||||
# move master weights to corresponding device and setup paired chunks
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
chunk_32.init_pair(chunk_16)
|
||||
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk_16.keep_gathered:
|
||||
self.grads_device[p] = get_current_device()
|
||||
if chunk_32.device_type != self.grads_device[p].type:
|
||||
self.chunk_manager.move_chunk(chunk_32, self.grads_device[p])
|
||||
|
||||
def _cast_buffers(self):
|
||||
for buffer in self.module.buffers():
|
||||
|
@ -727,67 +803,3 @@ class _StateDictSharder:
|
|||
self.current_block[name] = tensor
|
||||
self.current_block_size += tensor_size
|
||||
return ret_block, ret_block_size
|
||||
|
||||
|
||||
class GeminiDDP(ZeroDDP):
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
device: torch.device,
|
||||
placement_policy: str = "cpu",
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
scatter_after_inference: bool = True,
|
||||
search_range_m: int = 32,
|
||||
hidden_dim: Optional[int] = None,
|
||||
min_chunk_size_m: float = 32,
|
||||
memstats: Optional[MemStats] = None,
|
||||
mixed_precision: torch.dtype = torch.float16,
|
||||
verbose: bool = False) -> None:
|
||||
"""
|
||||
A torch.Module wrapper using ZeRO-DP and Gemini.
|
||||
ZeRO is for parallel. Gemini is for memory management.
|
||||
WARNING: The class will modify the module inline!
|
||||
|
||||
Example:
|
||||
model is initialized under the context of ColoInitContext
|
||||
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
|
||||
>>> logits = model(x)
|
||||
>>> loss = criterion(logits, labels)
|
||||
>>> model.backward(loss)
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): the model to be wrapped.
|
||||
device (torch.device): device to place the model.
|
||||
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
|
||||
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
|
||||
search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
|
||||
hidden_dim (int, optional): the hidden dimension of DNN.
|
||||
Users can provide this argument to speed up searching.
|
||||
If users do not know this argument before training, it is ok. We will use a default value 1024.
|
||||
min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
|
||||
If the aggregate size of parameters is still smaller than the minimum chunk size,
|
||||
all parameters will be compacted into one small chunk.
|
||||
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
|
||||
"""
|
||||
# some ugly hotfix for the compatibility with Lightning
|
||||
if search_range_m is None:
|
||||
search_range_m = 32
|
||||
|
||||
chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_m=search_range_m,
|
||||
min_chunk_size_m=min_chunk_size_m,
|
||||
strict_ddp_flag=strict_ddp_mode,
|
||||
verbose=verbose)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
|
||||
super().__init__(module,
|
||||
gemini_manager,
|
||||
pin_memory,
|
||||
force_outputs_fp32,
|
||||
strict_ddp_mode,
|
||||
scatter_after_inference,
|
||||
mixed_precision=mixed_precision)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import functools
|
||||
from time import time
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -26,7 +26,11 @@ class GeminiManager:
|
|||
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
|
||||
"""
|
||||
|
||||
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||
def __init__(self,
|
||||
placement_policy: str,
|
||||
chunk_manager: ChunkManager,
|
||||
memstats: Optional[MemStats] = None,
|
||||
**placement_kwargs) -> None:
|
||||
|
||||
assert placement_policy in PlacementPolicyFactory.get_policy_names()
|
||||
self.policy_name = placement_policy
|
||||
|
@ -37,7 +41,7 @@ class GeminiManager:
|
|||
self._memstats = memstats
|
||||
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
|
||||
self._memstats) if policy_cls.need_mem_stats else None
|
||||
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
|
||||
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
|
||||
self._compute_list: List[Tuple[Chunk, ...]] = []
|
||||
self._compute_idx: int = -1
|
||||
|
||||
|
@ -133,10 +137,6 @@ class GeminiManager:
|
|||
if self._warmup and self._placement_policy.need_mem_stats:
|
||||
self._compute_list.append(chunks)
|
||||
|
||||
@property
|
||||
def default_device(self):
|
||||
return self._placement_policy.get_default_device()
|
||||
|
||||
def sample_overall_data(self):
|
||||
if self._mem_stats_collector:
|
||||
self._mem_stats_collector.sample_overall_data()
|
||||
|
@ -159,6 +159,6 @@ class GeminiManager:
|
|||
def is_cuda_margin_mem_avail(self) -> bool:
|
||||
return self._placement_policy.need_mem_stats
|
||||
|
||||
@staticmethod
|
||||
def get_default_device(policy_name: str) -> torch.device:
|
||||
return PlacementPolicyFactory.get_default_device(policy_name)
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
self._placement_policy.setup_grads_device(params, grads_device_map)
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
|
||||
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -11,15 +11,16 @@ from torch.optim import Optimizer
|
|||
|
||||
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
|
||||
from colossalai.checkpoint_io.utils import calculate_tensor_size
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
|
||||
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
|
||||
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .gemini_ddp import ZeroDDP
|
||||
from .gemini_ddp import GeminiDDP
|
||||
|
||||
__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
|
||||
__all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer']
|
||||
|
||||
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
|
||||
|
||||
|
@ -27,7 +28,7 @@ _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
|
|||
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
||||
|
||||
def __init__(self,
|
||||
module: ZeroDDP,
|
||||
module: GeminiDDP,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
|
@ -46,11 +47,11 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
|||
self.module.overflow_counter = 0
|
||||
|
||||
|
||||
class ZeroOptimizer(ColossalaiOptimizer):
|
||||
"""A wrapper for optimizer. ``ZeroDDP`` and ``ZeroOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
|
||||
class GeminiOptimizer(OptimizerWrapper):
|
||||
"""A wrapper for optimizer. ``GeminiDDP`` and ``GeminiOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
|
||||
|
||||
Note:
|
||||
You must use ``ZeroDDP`` with ``ZeroOptimizer``.
|
||||
You must use ``GeminiDDP`` with ``GeminiOptimizer``.
|
||||
|
||||
Note:
|
||||
Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`,
|
||||
|
@ -58,7 +59,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
Args:
|
||||
optim (Optimizer): An Optimizer instance.
|
||||
module (ZeroDDP): A ``ZeroDDP`` instance.
|
||||
module (GeminiDDP): A ``GeminiDDP`` instance.
|
||||
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
||||
which will be used when using hybrid CPU optimizer.
|
||||
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
|
||||
|
@ -70,15 +71,15 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
growth_interval (float, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000.
|
||||
hysteresis (float, optional): Hysteresis used by DynamicGradScaler. Defaults to 2.
|
||||
max_scale (int, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32.
|
||||
clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
|
||||
max_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
|
||||
norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0)
|
||||
is supported in ZeroOptimizer. Defaults to 2.0.
|
||||
is supported in GeminiOptimizer. Defaults to 2.0.
|
||||
verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optim: Optimizer,
|
||||
module: ZeroDDP,
|
||||
module: GeminiDDP,
|
||||
gpu_margin_mem_ratio: float = 0.0,
|
||||
initial_scale: float = 2**32,
|
||||
min_scale: float = 1,
|
||||
|
@ -87,12 +88,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
clipping_norm: float = 0.0,
|
||||
max_norm: float = 0.0,
|
||||
norm_type: float = 2.0,
|
||||
verbose: bool = False,
|
||||
**defaults: Any):
|
||||
super().__init__(optim)
|
||||
assert isinstance(module, ZeroDDP)
|
||||
assert isinstance(module, GeminiDDP)
|
||||
assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
|
||||
f"{_AVAIL_OPTIM_LIST}"
|
||||
self.module = module
|
||||
|
@ -101,8 +102,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
|
||||
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
|
||||
self.chunk16_set: Set[Chunk] = set()
|
||||
self.clipping_flag = clipping_norm > 0.0
|
||||
self.max_norm = clipping_norm
|
||||
self.clipping_flag = max_norm > 0.0
|
||||
self.max_norm = max_norm
|
||||
self.verbose = verbose
|
||||
self.param_groups_backup = list()
|
||||
|
||||
|
@ -111,7 +112,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
self.id_to_fake_params: Dict[int, Parameter] = dict()
|
||||
|
||||
if self.clipping_flag:
|
||||
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
|
||||
assert norm_type == 2.0, "GeminiOptimizer only supports L2 norm now"
|
||||
|
||||
ddp_param_list = []
|
||||
for name, param in module.named_parameters():
|
||||
|
@ -735,8 +736,19 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
yield current_block, current_block_size
|
||||
|
||||
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
|
||||
raise NotImplementedError('Gemini does not support clip_grad_by_value')
|
||||
|
||||
class GeminiAdamOptimizer(ZeroOptimizer):
|
||||
def clip_grad_by_norm(self,
|
||||
max_norm: Union[float, int],
|
||||
norm_type: Union[float, int] = 2,
|
||||
error_if_nonfinite: bool = False,
|
||||
*args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
|
||||
|
||||
|
||||
class GeminiAdamOptimizer(GeminiOptimizer):
|
||||
|
||||
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
|
||||
optimizer = HybridAdam(model.parameters(), **defaults)
|
||||
|
|
|
@ -9,7 +9,7 @@ class MemStats(object):
|
|||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Store the non model data statistics used for Gemini and ZeroOptimizer.
|
||||
Store the non model data statistics used for Gemini and GeminiOptimizer.
|
||||
"""
|
||||
# (preop_step, List[param])
|
||||
self._step_param_dict = dict()
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import functools
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from time import time
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
@ -7,6 +8,7 @@ import torch
|
|||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.chunk import Chunk
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .memory_tracer import ChunkMemStatsCollector
|
||||
|
@ -17,7 +19,8 @@ class PlacementPolicy(ABC):
|
|||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
**kwargs) -> None:
|
||||
self.chunk_manager = chunk_manager
|
||||
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
|
||||
|
||||
|
@ -25,57 +28,87 @@ class PlacementPolicy(ABC):
|
|||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_default_device() -> torch.device:
|
||||
return torch.device('cpu')
|
||||
@abstractmethod
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CPUPlacementPolicy(PlacementPolicy):
|
||||
class StaticPlacementPolicy(PlacementPolicy):
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
shard_param_frac: float = 1.0,
|
||||
offload_optim_frac: float = 0.0,
|
||||
offload_param_frac: float = 0.0,
|
||||
**kwargs) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
|
||||
warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
|
||||
offload_param_frac = 0.0
|
||||
self.shard_param_frac = shard_param_frac
|
||||
self.offload_optim_frac = offload_optim_frac
|
||||
self.offload_param_frac = offload_param_frac
|
||||
# these should be initialized in setup_grads_device
|
||||
self.keep_gathered_chunk_mem = 0.0
|
||||
self.keep_cuda_chunk_mem = 0.0
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
volume = 0
|
||||
start = time()
|
||||
can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
|
||||
can_offload_chunk_mem = can_shard_chunk_mem
|
||||
for chunk in can_evict_chunks:
|
||||
if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
|
||||
break
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
# real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem
|
||||
can_shard_chunk_mem -= chunk.chunk_mem
|
||||
for chunk in can_evict_chunks:
|
||||
if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
|
||||
break
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
volume += chunk.chunk_mem
|
||||
return volume, time() - start
|
||||
# real saved mem is shard_mem, for simplicity we use chunk_mem
|
||||
can_offload_chunk_mem -= chunk.chunk_mem
|
||||
return 0, 0.0
|
||||
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
|
||||
|
||||
class CUDAPlacementPolicy(PlacementPolicy):
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
return 0, 0
|
||||
|
||||
@staticmethod
|
||||
def get_default_device() -> torch.device:
|
||||
return get_current_device()
|
||||
offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
|
||||
offloaded_optim_chunk_mem = 0
|
||||
chunks = set(self.chunk_manager.get_chunk(p) for p in params)
|
||||
for chunk in chunks:
|
||||
params = chunk.get_tensors()
|
||||
# init offload optim settings
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
|
||||
device = get_current_device()
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
|
||||
offloaded_optim_chunk_mem += chunk.chunk_mem
|
||||
for p in params:
|
||||
grads_device_map[p] = device
|
||||
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
||||
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
|
||||
|
||||
|
||||
class AutoPlacementPolicy(PlacementPolicy):
|
||||
|
||||
need_mem_stats: bool = True
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
_warmup_non_model_data_ratio: float = 0.8
|
||||
_steady_cuda_cap_ratio: float = 0.9
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
warmup_non_model_data_ratio: float = 0.8,
|
||||
steady_cuda_cap_ratio: float = 0.9,
|
||||
**kwargs) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
|
||||
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
|
||||
|
||||
def evict_tensors(self,
|
||||
can_evict_chunks: List[Chunk],
|
||||
|
@ -105,11 +138,11 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|||
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
|
||||
if warmup:
|
||||
# We designate a part of CUDA memory for model data in warmup iterations.
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
|
||||
else:
|
||||
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
||||
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
|
||||
cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
|
||||
cuda_capacity *= self._steady_cuda_cap_ratio
|
||||
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||
freed_cuda_model_data = 0
|
||||
|
@ -145,89 +178,22 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
@staticmethod
|
||||
def set_warmup_non_model_data_ratio(ratio: float) -> None:
|
||||
ratio = float(ratio)
|
||||
assert 0.0 < ratio < 1.0
|
||||
AutoPlacementPolicy._warmup_non_model_data_ratio = ratio
|
||||
|
||||
@staticmethod
|
||||
def set_steady_cuda_cap_ratio(ratio: float) -> None:
|
||||
ratio = float(ratio)
|
||||
assert 0.0 < ratio < 1.0
|
||||
AutoPlacementPolicy._steady_cuda_cap_ratio = ratio
|
||||
|
||||
|
||||
class ConstPlacementPolicy(PlacementPolicy):
|
||||
|
||||
need_mem_stats: bool = False
|
||||
_accessed_memory_boundary = 512 * 1024**2
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self,
|
||||
can_evict_chunks: List[Chunk],
|
||||
cuda_demand: int = 0,
|
||||
warmup: bool = True,
|
||||
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
|
||||
compute_idx: int = 0,
|
||||
**kwargs) -> Tuple[int, float]:
|
||||
"""
|
||||
See the docstrings in the class `AutoPlacementPolicy`.
|
||||
"""
|
||||
start = time()
|
||||
used_accessed_memory = self.chunk_manager.accessed_mem
|
||||
avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory
|
||||
freed_accessed_memory = 0
|
||||
|
||||
if avail_accessed_memory < cuda_demand:
|
||||
to_free_memory = cuda_demand - avail_accessed_memory
|
||||
to_free_chunks = can_evict_chunks
|
||||
|
||||
if not warmup:
|
||||
# sort all chunks
|
||||
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
|
||||
|
||||
for chunk in to_free_chunks:
|
||||
if freed_accessed_memory >= to_free_memory:
|
||||
break
|
||||
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
freed_accessed_memory += chunk.chunk_mem
|
||||
|
||||
if freed_accessed_memory < to_free_memory:
|
||||
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
|
||||
f"Need {to_free_memory}, freed {freed_accessed_memory}")
|
||||
return freed_accessed_memory, time() - start
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
|
||||
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
|
||||
for i in range(len(compute_list) - 1, compute_idx, -1):
|
||||
for chunk in compute_list[i]:
|
||||
if chunk in next_compute_idx:
|
||||
next_compute_idx[chunk] = i
|
||||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
@staticmethod
|
||||
def set_const_memory_boundary(cuda_memory_mb: int) -> None:
|
||||
boundary = int(cuda_memory_mb * 1024**2)
|
||||
assert boundary > 0
|
||||
ConstPlacementPolicy._accessed_memory_boundary = boundary
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
for p in params:
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
# init offload optim settings
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk.keep_gathered:
|
||||
grads_device_map[p] = get_current_device()
|
||||
else:
|
||||
grads_device_map[p] = torch.device('cpu')
|
||||
|
||||
|
||||
class PlacementPolicyFactory:
|
||||
policies: Dict[str, Type[PlacementPolicy]] = {
|
||||
'cpu': CPUPlacementPolicy,
|
||||
'cuda': CUDAPlacementPolicy,
|
||||
'auto': AutoPlacementPolicy,
|
||||
'const': ConstPlacementPolicy
|
||||
'static': StaticPlacementPolicy,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
@ -239,8 +205,3 @@ class PlacementPolicyFactory:
|
|||
@staticmethod
|
||||
def get_policy_names():
|
||||
return tuple(PlacementPolicyFactory.policies.keys())
|
||||
|
||||
@staticmethod
|
||||
def get_default_device(policy_name: str) -> torch.device:
|
||||
policy_cls = PlacementPolicyFactory.create(policy_name)
|
||||
return policy_cls.get_default_device()
|
||||
|
|
|
@ -64,13 +64,13 @@ def get_static_torch_model(zero_ddp_model,
|
|||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
only_rank_0=True) -> torch.nn.Module:
|
||||
"""Get a static torch.nn.Module model from the given ZeroDDP module.
|
||||
You should notice that the original ZeroDDP model is not modified.
|
||||
"""Get a static torch.nn.Module model from the given GeminiDDP module.
|
||||
You should notice that the original GeminiDDP model is not modified.
|
||||
Thus, you can use the original model in further training.
|
||||
But you should not use the returned torch model to train, this can cause unexpected errors.
|
||||
|
||||
Args:
|
||||
zero_ddp_model (ZeroDDP): a zero ddp model
|
||||
zero_ddp_model (GeminiDDP): a zero ddp model
|
||||
device (torch.device): the device of the final torch model
|
||||
dtype (torch.dtype): the dtype of the final torch model
|
||||
only_rank_0 (bool): if True, only rank0 has the converted torch model
|
||||
|
@ -78,8 +78,8 @@ def get_static_torch_model(zero_ddp_model,
|
|||
Returns:
|
||||
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
|
||||
"""
|
||||
from colossalai.zero.gemini.gemini_ddp import ZeroDDP
|
||||
assert isinstance(zero_ddp_model, ZeroDDP)
|
||||
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
|
||||
assert isinstance(zero_ddp_model, GeminiDDP)
|
||||
|
||||
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
|
||||
colo_model = zero_ddp_model.module
|
||||
|
|
|
@ -109,6 +109,6 @@ def zero_optim_wrapper(model: nn.Module,
|
|||
config_dict['clip_grad_norm'] = max_norm
|
||||
return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
|
||||
else:
|
||||
from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
|
||||
from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer
|
||||
config_dict['clipping_norm'] = max_norm
|
||||
return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose)
|
||||
return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose)
|
||||
|
|
|
@ -54,32 +54,38 @@ We also provide a lightweight chunk search mechanism to help users automatically
|
|||
|
||||
We will use `GeminiDDP` to use ZeRO with chunk-based memory management. This is our new torch.Module wrapper which uses ZeRO-DP and Gemini. ZeRO is for parallelism and Gemini is for memory management.
|
||||
|
||||
Also Make sure that your model is initialized under the context of ColoInitContext.
|
||||
Gemini allows LazyInitContext, which can save memory when initializing large models with multi-GPUs.
|
||||
|
||||
If your model has `N` billion parameters and your GPU memory is `M` GB, we recommend you use LazyInitContext when `4N >= M`. Otherwise, LazyInitContext is optional.
|
||||
|
||||
<!--- doc-test-ignore-start -->
|
||||
```python
|
||||
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
|
||||
with LazyInitContext(default_device=torch.device('cuda')):
|
||||
model = gpt2_medium(checkpoint=True)
|
||||
```
|
||||
<!--- doc-test-ignore-end -->
|
||||
|
||||
Define the model parameters as follows:
|
||||
We've provided `Booster` API which is user-friendly. We recommend you use `Booster` API. But if you still want to use low level API, you can read below content of this section.
|
||||
|
||||
Wrap the model with `GeminiDDP`.
|
||||
|
||||
<!--- doc-test-ignore-start -->
|
||||
```python
|
||||
chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_m=search_range_m,
|
||||
min_chunk_size_m=min_chunk_size_m)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = GeminiDDP(model, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m)
|
||||
```
|
||||
<!--- doc-test-ignore-end -->
|
||||
|
||||
`hidden_dim` is the hidden dimension of DNN. Users can provide this argument to speed up searching. If users do not know this argument before training, it is ok. We will use a default value 1024. `min_chunk_size_m` is a floating point, being the minimum chunk size divided by 2^20 (e.g., if min_chunk_size_m=2.5, then the minimum chunk size should be 2.5*(2^20)).If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk.
|
||||
|
||||
Initialization of the optimizer.
|
||||
<!--- doc-test-ignore-start -->
|
||||
```python
|
||||
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
|
||||
```
|
||||
<!--- doc-test-ignore-start -->
|
||||
|
||||
Training
|
||||
<!--- doc-test-ignore-start -->
|
||||
```python
|
||||
optimizer.zero_grad()
|
||||
outputs = model(input_ids, attn_mask)
|
||||
|
@ -87,6 +93,7 @@ loss = criterion(outputs, input_ids)
|
|||
optimizer.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
<!--- doc-test-ignore-start -->
|
||||
> ⚠️ Note: Please do not use `loss.backward()`, the standard way of writing is `optimizer.backward(loss)`.
|
||||
|
||||
### Train GPT
|
||||
|
@ -142,46 +149,6 @@ class GPTLMLoss(nn.Module):
|
|||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
```
|
||||
|
||||
Define tensor parallel and parameter sharding strategies for tensor parallelism:
|
||||
|
||||
```python
|
||||
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||
for mn, module in model.named_modules():
|
||||
for pn, param in module.named_parameters(recurse=False):
|
||||
if hasattr(param, 'visited'):
|
||||
continue
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
if 'mlp.c_fc' in mn:
|
||||
if 'weight' in pn or 'bias' in pn:
|
||||
split_param_col_tp1d(param, pg)
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'mlp.c_proj' in mn:
|
||||
if 'weight' in pn:
|
||||
split_param_row_tp1d(param, pg)
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'wte' in mn or 'wpe' in mn:
|
||||
split_param_col_tp1d(param, pg)
|
||||
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||
split_param_col_tp1d(param, pg)
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
|
||||
param.visited = True
|
||||
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
||||
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
param.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(0, param, pg)
|
||||
|
||||
|
||||
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(-1, param, pg)
|
||||
```
|
||||
|
||||
Write a function to get random inputs:
|
||||
|
||||
|
@ -198,7 +165,7 @@ Finally, we define a model which uses Gemini + ZeRO DDP and define our training
|
|||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.zero import ColoInitContext
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.booster.plugin import GeminiPlugin
|
||||
|
||||
def main():
|
||||
|
@ -214,17 +181,13 @@ def main():
|
|||
optimizer = HybridAdam(model.parameters(), lr=0.001)
|
||||
|
||||
torch.manual_seed(123)
|
||||
default_pg = ProcessGroup(tp_degree=args.tp_degree)
|
||||
default_dist_spec = ShardSpec([-1], [args.tp_degree])
|
||||
# build GPT model
|
||||
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
|
||||
with ColoInitContext(default_device=torch.device('cuda')):
|
||||
model = gpt2_medium(checkpoint=True)
|
||||
pg = default_pg
|
||||
# Tensor Parallelism (TP)
|
||||
tensor_parallelize(model, pg)
|
||||
|
||||
# Gemini + ZeRO DP, Note it must be used after TP
|
||||
plugin = GeminiPlugin(placement_policy='cuda', max_norm=1.0, initial_scale=2**5)
|
||||
|
||||
# Gemini + ZeRO DP
|
||||
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||
|
||||
|
|
|
@ -53,32 +53,37 @@
|
|||
|
||||
我们将运用`GeminiDDP`的方式来使用基于Chunk内存管理的ZeRO。这是我们新包装的torch.Module ,它使用 ZeRO-DP 和 Gemini,其中ZeRO 用于并行,Gemini 用于内存管理。
|
||||
|
||||
同样需要确保你的模型是在 `ColoInitContext` 的上下文中初始化的。
|
||||
Gemini支持惰性初始化, 它可以节省多卡初始化大模型时的显存使用.
|
||||
|
||||
如果你的模型有 `N` billion 个参数,你的 GPU 内存为 `M` GB, 当 `4N >= M` 时,我们推荐使用 LazyInitContext。否则,LazyInitContext 是可选的。
|
||||
|
||||
<!--- doc-test-ignore-start -->
|
||||
```python
|
||||
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
|
||||
with LazyInitContext(default_device=torch.device('cuda')):
|
||||
model = gpt2_medium(checkpoint=True)
|
||||
```
|
||||
<!--- doc-test-ignore-end -->
|
||||
|
||||
定义模型参数如下:
|
||||
我们提供了 `Booster` API,它用户友好。我们推荐你使用 `Booster` API。如果您仍然想使用底层 API,您可以继续阅读本节其他内容。
|
||||
|
||||
使用 `GeminiDDP` 包装模型。
|
||||
|
||||
<!--- doc-test-ignore-start -->
|
||||
```python
|
||||
chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_m=search_range_m,
|
||||
min_chunk_size_m=min_chunk_size_m)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
model = GeminiDDP(model, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m)
|
||||
```
|
||||
<!--- doc-test-ignore-end -->
|
||||
|
||||
`hidden dim`是DNN的隐藏维度。用户可以提供这个参数来加快搜索速度。如果用户在训练前不知道这个参数也可以。 我们将使用默认值 1024。`min_chunk_size_m`是以兆(2^20)为单位的最小块大小。如果参数的总大小仍然小于最小块大小,则所有参数将被压缩为一个小块。
|
||||
|
||||
初始化优化器。
|
||||
<!--- doc-test-ignore-start -->
|
||||
```python
|
||||
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
|
||||
```
|
||||
<!--- doc-test-ignore-end -->
|
||||
|
||||
<!--- doc-test-ignore-start -->
|
||||
训练
|
||||
```python
|
||||
optimizer.zero_grad()
|
||||
|
@ -87,6 +92,7 @@ loss = criterion(outputs, input_ids)
|
|||
optimizer.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
<!--- doc-test-ignore-end -->
|
||||
> ⚠️ 注意:请不要使用`loss.backward()`,规范写法是`optimizer.backward(loss)`。
|
||||
|
||||
### 训练GPT
|
||||
|
@ -143,47 +149,6 @@ class GPTLMLoss(nn.Module):
|
|||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
```
|
||||
|
||||
定义张量并行和参数分片策略:
|
||||
|
||||
```python
|
||||
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||
for mn, module in model.named_modules():
|
||||
for pn, param in module.named_parameters(recurse=False):
|
||||
if hasattr(param, 'visited'):
|
||||
continue
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
if 'mlp.c_fc' in mn:
|
||||
if 'weight' in pn or 'bias' in pn:
|
||||
split_param_col_tp1d(param, pg)
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'mlp.c_proj' in mn:
|
||||
if 'weight' in pn:
|
||||
split_param_row_tp1d(param, pg)
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'wte' in mn or 'wpe' in mn:
|
||||
split_param_col_tp1d(param, pg)
|
||||
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||
split_param_col_tp1d(param, pg)
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
|
||||
param.visited = True
|
||||
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
||||
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
param.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(0, param, pg)
|
||||
|
||||
|
||||
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(-1, param, pg)
|
||||
```
|
||||
|
||||
写一个获得随机输入的函数:
|
||||
|
||||
```python
|
||||
|
@ -200,7 +165,7 @@ def get_data(batch_size, seq_len, vocab_size):
|
|||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.zero import ColoInitContext
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.booster.plugin import GeminiPlugin
|
||||
|
||||
def main():
|
||||
|
@ -216,17 +181,13 @@ def main():
|
|||
optimizer = HybridAdam(model.parameters(), lr=0.001)
|
||||
|
||||
torch.manual_seed(123)
|
||||
default_pg = ProcessGroup(tp_degree=args.tp_degree)
|
||||
default_dist_spec = ShardSpec([-1], [args.tp_degree])
|
||||
# build GPT model
|
||||
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
|
||||
with ColoInitContext(default_device=torch.device('cuda')):
|
||||
model = gpt2_medium(checkpoint=True)
|
||||
pg = default_pg
|
||||
# Tensor Parallelism (TP)
|
||||
tensor_parallelize(model, pg)
|
||||
|
||||
# Gemini + ZeRO DP, Note it must be used after TP
|
||||
plugin = GeminiPlugin(placement_policy='cuda', max_norm=1.0, initial_scale=2**5)
|
||||
|
||||
# Gemini + ZeRO DP
|
||||
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wra
|
|||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.zero import GeminiOptimizer
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -46,7 +46,7 @@ def main():
|
|||
args.local_rank = -1
|
||||
args.log_interval = 1
|
||||
else:
|
||||
colossalai.launch_from_torch(config={}) #args.colossal_config
|
||||
colossalai.launch_from_torch(config={}) # args.colossal_config
|
||||
args.local_rank = int(os.environ["LOCAL_RANK"])
|
||||
logger.info(
|
||||
f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
|
||||
|
@ -123,7 +123,8 @@ def main():
|
|||
get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
|
||||
|
||||
# 144003367 is is the length of the entire dataset
|
||||
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader)
|
||||
# len(dataloader)
|
||||
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size
|
||||
total_steps = steps_per_epoch * args.epoch
|
||||
|
||||
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
|
||||
|
|
|
@ -20,6 +20,5 @@ for plugin in "gemini"; do
|
|||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--test_run=True \
|
||||
--num_class_images=200 \
|
||||
--placement="auto" # "cuda"
|
||||
--num_class_images=200
|
||||
done
|
||||
|
|
|
@ -2,9 +2,9 @@ import argparse
|
|||
import hashlib
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -19,6 +19,8 @@ from tqdm.auto import tqdm
|
|||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
|
@ -26,8 +28,6 @@ from colossalai.nn.optimizer import HybridAdam
|
|||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from colossalai.zero.gemini import get_static_torch_model
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
|
||||
disable_existing_loggers()
|
||||
logger = get_dist_logger()
|
||||
|
@ -138,10 +138,10 @@ def parse_args(input_args=None):
|
|||
" resolution"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placement",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
|
||||
"--offload_optim_frac",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Fraction of optimizer states to be offloaded. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
|
@ -461,18 +461,17 @@ def main(args):
|
|||
revision=args.revision,
|
||||
)
|
||||
|
||||
|
||||
if args.externel_unet_path is None:
|
||||
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
else:
|
||||
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
|
||||
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
@ -491,30 +490,31 @@ def main(args):
|
|||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
|
||||
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
|
||||
# config optimizer for colossalai zero
|
||||
optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
|
||||
optimizer = HybridAdam(unet.parameters(),
|
||||
lr=args.learning_rate,
|
||||
initial_scale=2**5,
|
||||
clipping_norm=args.max_grad_norm)
|
||||
|
||||
# load noise_scheduler
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
# prepare dataset
|
||||
logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
|
||||
train_dataset = DreamBoothDataset(
|
||||
instance_data_root=args.instance_data_dir,
|
||||
instance_prompt=args.instance_prompt,
|
||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||
class_prompt=args.class_prompt,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
test=args.test_run
|
||||
)
|
||||
train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir,
|
||||
instance_prompt=args.instance_prompt,
|
||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||
class_prompt=args.class_prompt,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
test=args.test_run)
|
||||
|
||||
def collate_fn(examples):
|
||||
input_ids = [example["instance_prompt_ids"] for example in examples]
|
||||
|
@ -690,6 +690,7 @@ def main(args):
|
|||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
|
|
@ -2,9 +2,9 @@ import argparse
|
|||
import hashlib
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -21,6 +21,8 @@ from tqdm.auto import tqdm
|
|||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
|
@ -28,8 +30,6 @@ from colossalai.nn.optimizer import HybridAdam
|
|||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
|
||||
from colossalai.zero.gemini import get_static_torch_model
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
|
||||
disable_existing_loggers()
|
||||
logger = get_dist_logger()
|
||||
|
@ -459,18 +459,17 @@ def main(args):
|
|||
revision=args.revision,
|
||||
)
|
||||
|
||||
|
||||
if args.externel_unet_path is None:
|
||||
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
else:
|
||||
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
|
||||
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
|
@ -490,8 +489,7 @@ def main(args):
|
|||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim)
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
||||
|
||||
unet.set_attn_processor(lora_attn_procs)
|
||||
lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
|
@ -513,14 +511,17 @@ def main(args):
|
|||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
|
||||
plugin = GeminiPlugin(strict_ddp_mode=True, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
|
||||
# config optimizer for colossalai zero
|
||||
optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
|
||||
optimizer = HybridAdam(unet.parameters(),
|
||||
lr=args.learning_rate,
|
||||
initial_scale=2**5,
|
||||
clipping_norm=args.max_grad_norm)
|
||||
|
||||
# load noise_scheduler
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
@ -711,6 +712,7 @@ def main(args):
|
|||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
|
|
@ -49,8 +49,8 @@ python eval.py -c ./ckpt-low_level_zero -e 80
|
|||
|
||||
Expected accuracy performance will be:
|
||||
|
||||
| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero |
|
||||
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- |
|
||||
| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% |
|
||||
| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | Booster Gemini |
|
||||
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | -------------- |
|
||||
| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | 84.60% |
|
||||
|
||||
**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
|
||||
|
|
|
@ -104,7 +104,7 @@ def main():
|
|||
'--plugin',
|
||||
type=str,
|
||||
default='torch_ddp',
|
||||
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'],
|
||||
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero', 'gemini'],
|
||||
help="plugin to use")
|
||||
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
|
||||
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
|
||||
|
@ -141,7 +141,7 @@ def main():
|
|||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
|
||||
plugin = GeminiPlugin(initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
|
||||
|
|
|
@ -1,19 +1,18 @@
|
|||
import time
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import ViTConfig, ViTForImageClassification
|
||||
import tqdm
|
||||
import transformers
|
||||
from args import parse_benchmark_args
|
||||
from transformers import ViTConfig, ViTForImageClassification
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
from args import parse_benchmark_args
|
||||
|
||||
def format_num(num: int, bytes=False):
|
||||
"""Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
|
||||
|
@ -26,8 +25,13 @@ def format_num(num: int, bytes=False):
|
|||
|
||||
|
||||
def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
|
||||
pixel_values = torch.randn(batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float)
|
||||
labels = torch.randint(0, num_labels, (batch_size, ), device=torch.cuda.current_device(), dtype=torch.int64)
|
||||
pixel_values = torch.randn(batch_size,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=torch.float)
|
||||
labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
|
||||
return pixel_values, labels
|
||||
|
||||
|
||||
|
@ -55,11 +59,11 @@ def main():
|
|||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# Whether to set limit on memory capacity
|
||||
if args.mem_cap > 0:
|
||||
colo_memory_cap(args.mem_cap)
|
||||
|
||||
|
||||
# Build ViT model
|
||||
config = ViTConfig.from_pretrained(args.model_name_or_path)
|
||||
model = ViTForImageClassification(config)
|
||||
|
@ -75,11 +79,7 @@ def main():
|
|||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(device=get_current_device(),
|
||||
placement_policy='cpu',
|
||||
pin_memory=True,
|
||||
strict_ddp_mode=True,
|
||||
initial_scale=2**5)
|
||||
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
|
||||
|
@ -90,16 +90,15 @@ def main():
|
|||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
model, optimizer, _, _, _ = booster.boost(model, optimizer)
|
||||
|
||||
|
||||
# Start training.
|
||||
logger.info(f"Start testing", ranks=[0])
|
||||
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
|
||||
|
||||
|
||||
torch.cuda.synchronize()
|
||||
model.train()
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
for _ in range(args.max_train_steps):
|
||||
|
||||
pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
|
||||
|
@ -111,18 +110,19 @@ def main():
|
|||
|
||||
torch.cuda.synchronize()
|
||||
progress_bar.update(1)
|
||||
|
||||
# Compute Statistics
|
||||
|
||||
# Compute Statistics
|
||||
end_time = time.time()
|
||||
throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
|
||||
max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
|
||||
|
||||
logger.info(f"Testing finished, "
|
||||
f"batch size per gpu: {args.batch_size}, "
|
||||
f"plugin: {args.plugin}, "
|
||||
f"throughput: {throughput}, "
|
||||
f"maximum memory usage per gpu: {max_mem}.",
|
||||
ranks=[0])
|
||||
|
||||
logger.info(
|
||||
f"Testing finished, "
|
||||
f"batch size per gpu: {args.batch_size}, "
|
||||
f"plugin: {args.plugin}, "
|
||||
f"throughput: {throughput}, "
|
||||
f"maximum memory usage per gpu: {max_mem}.",
|
||||
ranks=[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
|
||||
from args import parse_demo_args
|
||||
from data import BeansDataset, beans_collator
|
||||
from tqdm import tqdm
|
||||
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
from args import parse_demo_args
|
||||
from data import BeansDataset, beans_collator
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def move_to_cuda(batch, device):
|
||||
|
@ -22,12 +21,12 @@ def move_to_cuda(batch, device):
|
|||
|
||||
|
||||
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
|
||||
|
||||
|
||||
torch.cuda.synchronize()
|
||||
model.train()
|
||||
|
||||
with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
|
||||
|
||||
|
||||
for batch in pbar:
|
||||
|
||||
# Foward
|
||||
|
@ -47,7 +46,7 @@ def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coor
|
|||
|
||||
@torch.no_grad()
|
||||
def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
|
||||
|
||||
|
||||
model.eval()
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
total_num = torch.zeros(1, device=get_current_device())
|
||||
|
@ -76,9 +75,7 @@ def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
|
|||
print(f"Evaluation result for epoch {epoch + 1}: \
|
||||
average_loss={avg_loss}, \
|
||||
accuracy={accuracy}.")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
|
@ -102,14 +99,13 @@ def main():
|
|||
train_dataset = BeansDataset(image_processor, split='train')
|
||||
eval_dataset = BeansDataset(image_processor, split='validation')
|
||||
|
||||
|
||||
# Load pretrained ViT model
|
||||
config = ViTConfig.from_pretrained(args.model_name_or_path)
|
||||
config.num_labels = train_dataset.num_labels
|
||||
config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
|
||||
config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
|
||||
model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
|
||||
config=config,
|
||||
model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
|
||||
config=config,
|
||||
ignore_mismatched_sizes=True)
|
||||
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
|
||||
|
||||
|
@ -123,26 +119,22 @@ def main():
|
|||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(device=get_current_device(),
|
||||
placement_policy='cpu',
|
||||
pin_memory=True,
|
||||
strict_ddp_mode=True,
|
||||
initial_scale=2**5)
|
||||
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
|
||||
|
||||
# Prepare dataloader
|
||||
train_dataloader = plugin.prepare_dataloader(train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=beans_collator)
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=beans_collator)
|
||||
eval_dataloader = plugin.prepare_dataloader(eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=beans_collator)
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=beans_collator)
|
||||
|
||||
# Set optimizer
|
||||
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
|
||||
|
@ -156,11 +148,11 @@ def main():
|
|||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=train_dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
|
||||
model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=train_dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
|
||||
# Finetuning
|
||||
logger.info(f"Start finetuning", ranks=[0])
|
||||
for epoch in range(args.num_epoch):
|
||||
|
@ -174,4 +166,4 @@ def main():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
|
@ -7,6 +7,14 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
|
|||
bash test_ci.sh
|
||||
```
|
||||
|
||||
### Results on 2-GPU
|
||||
|
||||
| Plugin | Accuracy | F1-score |
|
||||
| -------------- | -------- | -------- |
|
||||
| torch_ddp | 84.4% | 88.6% |
|
||||
| torch_ddp_fp16 | 84.7% | 88.8% |
|
||||
| gemini | 84.0% | 88.4% |
|
||||
|
||||
## Benchmark
|
||||
```
|
||||
bash benchmark.sh
|
||||
|
@ -14,9 +22,9 @@ bash benchmark.sh
|
|||
|
||||
Now include these metrics in benchmark: CUDA mem occupy, throughput and the number of model parameters. If you have custom metrics, you can add them to benchmark_util.
|
||||
|
||||
## Results
|
||||
### Results
|
||||
|
||||
### Bert
|
||||
#### Bert
|
||||
|
||||
| | max cuda mem | throughput(sample/s) | params |
|
||||
| :-----| -----------: | :--------: | :----: |
|
||||
|
@ -25,10 +33,10 @@ Now include these metrics in benchmark: CUDA mem occupy, throughput and the numb
|
|||
| gemini | 11.0 GB | 12.9 | 82M |
|
||||
| low_level_zero | 11.29 G | 14.7 | 82M |
|
||||
|
||||
### AlBert
|
||||
#### AlBert
|
||||
| | max cuda mem | throughput(sample/s) | params |
|
||||
| :-----| -----------: | :--------: | :----: |
|
||||
| ddp | OOM | | |
|
||||
| ddp_fp16 | OOM | | |
|
||||
| gemini | 69.39 G | 1.3 | 208M |
|
||||
| low_level_zero | 56.89 G | 1.4 | 208M |
|
||||
| low_level_zero | 56.89 G | 1.4 | 208M |
|
||||
|
|
|
@ -38,8 +38,8 @@ def move_to_cuda(batch):
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
|
||||
eval_splits: List[str], coordinator: DistCoordinator):
|
||||
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
|
||||
task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
|
||||
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
|
||||
model.eval()
|
||||
|
||||
|
@ -142,7 +142,7 @@ def main():
|
|||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
|
||||
plugin = GeminiPlugin(initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
|
||||
|
@ -208,7 +208,7 @@ def main():
|
|||
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
|
||||
|
||||
results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
|
||||
coordinator)
|
||||
coordinator)
|
||||
|
||||
if coordinator.is_master():
|
||||
print(results)
|
||||
|
|
|
@ -4,9 +4,6 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
|
|||
|
||||
# The following options only valid when DISTPLAN="colossalai"
|
||||
export GPUNUM=${GPUNUM:-1}
|
||||
export TPDEGREE=${TPDEGREE:-1}
|
||||
export PLACEMENT=${PLACEMENT:-"cpu"}
|
||||
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
|
||||
export BATCH_SIZE=${BATCH_SIZE:-16}
|
||||
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
|
||||
export TRAIN_STEP=${TRAIN_STEP:-10}
|
||||
|
@ -21,11 +18,8 @@ fi
|
|||
mkdir -p gemini_logs
|
||||
|
||||
torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
|
||||
--tp_degree=${TPDEGREE} \
|
||||
--model_type=${MODEL_TYPE} \
|
||||
--batch_size=${BATCH_SIZE} \
|
||||
--placement=${PLACEMENT} \
|
||||
${USE_SHARD_INIT} \
|
||||
--distplan=${DISTPLAN} \
|
||||
--train_step=${TRAIN_STEP} \
|
||||
2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
|
||||
|
|
|
@ -6,29 +6,17 @@ for MODEL_TYPE in "gpt2_medium"; do
|
|||
for DISTPLAN in "CAI_Gemini"; do
|
||||
for BATCH_SIZE in 2; do
|
||||
for GPUNUM in 1 4; do
|
||||
for TPDEGREE in 1 2; do
|
||||
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
|
||||
continue
|
||||
fi
|
||||
for PLACEMENT in "cpu" "auto"; do
|
||||
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
|
||||
bash ./run_gemini.sh
|
||||
done
|
||||
done
|
||||
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
|
||||
bash ./run_gemini.sh
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
for DISTPLAN in "zero1" "zero2"; do
|
||||
for DISTPLAN in "CAI_ZeRO2" "CAI_ZeRO1"; do
|
||||
for BATCH_SIZE in 2; do
|
||||
for GPUNUM in 1 4; do
|
||||
for TPDEGREE in 1; do
|
||||
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
|
||||
continue
|
||||
fi
|
||||
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
|
||||
bash ./run_gemini.sh
|
||||
done
|
||||
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
|
||||
bash ./run_gemini.sh
|
||||
done
|
||||
done
|
||||
done
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from time import time
|
||||
|
||||
|
@ -13,11 +14,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
|
||||
CAI_VERSION = colossalai.__version__
|
||||
|
||||
|
@ -30,24 +30,6 @@ def parse_args():
|
|||
default='CAI_Gemini',
|
||||
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp_degree",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placement",
|
||||
type=str,
|
||||
default='cpu',
|
||||
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shardinit",
|
||||
action='store_true',
|
||||
help=
|
||||
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
|
@ -71,20 +53,6 @@ def parse_args():
|
|||
return args
|
||||
|
||||
|
||||
# Parameter Sharding Strategies for Tensor Parallelism
|
||||
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
||||
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
param.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(0, param, pg)
|
||||
|
||||
|
||||
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(-1, param, pg)
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -140,47 +108,6 @@ def set_cpu_maximum_parallelism():
|
|||
print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
|
||||
|
||||
|
||||
# Tensor Parallel
|
||||
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||
"""tensor_parallelize
|
||||
Sharding the Model Parameters.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): a torch module to be sharded
|
||||
"""
|
||||
for mn, module in model.named_modules():
|
||||
for pn, param in module.named_parameters(recurse=False):
|
||||
# NOTE() a param maybe shared by two modules
|
||||
if hasattr(param, 'visited'):
|
||||
continue
|
||||
|
||||
# if shard init, then convert param to replica and use the dp-only ProcessGroup
|
||||
param: ColoParameter = param
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
param.set_process_group(pg)
|
||||
|
||||
# shard it w.r.t tp pattern
|
||||
if 'mlp.c_fc' in mn:
|
||||
if 'weight' in pn or 'bias' in pn:
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
# keep the shape of the output from c_fc
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'mlp.c_proj' in mn:
|
||||
if 'weight' in pn:
|
||||
split_param_row_tp1d(param, pg) # row slice
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'wte' in mn or 'wpe' in mn:
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
param.visited = True
|
||||
|
||||
|
||||
def main():
|
||||
# version check
|
||||
# this example is supposed to work for versions greater than 0.2.0
|
||||
|
@ -213,30 +140,13 @@ def main():
|
|||
|
||||
# build criterion
|
||||
criterion = GPTLMLoss()
|
||||
|
||||
torch.manual_seed(123)
|
||||
if args.distplan.startswith("CAI"):
|
||||
# all param must use the same process group.
|
||||
world_size = torch.distributed.get_world_size()
|
||||
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
|
||||
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
|
||||
|
||||
if args.shardinit and args.distplan != "CAI_Gemini":
|
||||
raise RuntimeError("You can only use shardinit with CAI_Gemini")
|
||||
|
||||
ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
|
||||
# build GPT model
|
||||
with ColoInitContext(device=get_current_device(),
|
||||
dtype=torch.half,
|
||||
default_dist_spec=default_dist_spec,
|
||||
default_pg=shard_pg):
|
||||
with ctx:
|
||||
model = model_builder(args.model_type)(checkpoint=True)
|
||||
|
||||
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
|
||||
# Tensor Parallelism (TP)
|
||||
# You should notice that v0.1.10 is not compatible with TP degree > 1
|
||||
if args.tp_degree > 1:
|
||||
tensor_parallelize(model, tp_pg)
|
||||
|
||||
# assign running configurations
|
||||
if args.distplan == "CAI_ZeRO1":
|
||||
zero_stage = 1
|
||||
|
@ -254,13 +164,7 @@ def main():
|
|||
overlap_communication=True,
|
||||
verbose=True)
|
||||
elif args.distplan == "CAI_Gemini":
|
||||
plugin = GeminiPlugin(device=get_current_device(),
|
||||
placement_policy=args.placement,
|
||||
pin_memory=True,
|
||||
strict_ddp_mode=args.tp_degree == 1,
|
||||
search_range_m=128,
|
||||
hidden_dim=model.config.n_embd,
|
||||
gpu_margin_mem_ratio=0.)
|
||||
plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
|
|
|
@ -1,22 +1,18 @@
|
|||
import time
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
import transformers
|
||||
from args import parse_benchmark_args
|
||||
from transformers import AutoConfig, OPTForCausalLM
|
||||
from transformers.utils.versions import require_version
|
||||
import tqdm
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.tensor import ProcessGroup, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
from args import parse_benchmark_args
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
|
||||
|
||||
|
@ -61,11 +57,11 @@ def main():
|
|||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# Whether to set limit of memory capacity
|
||||
if args.mem_cap > 0:
|
||||
colo_memory_cap(args.mem_cap)
|
||||
|
||||
|
||||
# Build OPT model
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
model = OPTForCausalLM(config=config)
|
||||
|
@ -81,11 +77,7 @@ def main():
|
|||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(device=get_current_device(),
|
||||
placement_policy='cpu',
|
||||
pin_memory=True,
|
||||
strict_ddp_mode=True,
|
||||
initial_scale=2**5)
|
||||
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
|
||||
|
@ -96,18 +88,18 @@ def main():
|
|||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
model, optimizer, _, _, _ = booster.boost(model, optimizer)
|
||||
|
||||
|
||||
SEQ_LEN = 1024
|
||||
VOCAB_SIZE = 50257
|
||||
|
||||
# Start training.
|
||||
logger.info(f"Start testing", ranks=[0])
|
||||
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
|
||||
|
||||
|
||||
torch.cuda.synchronize()
|
||||
model.train()
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
for _ in range(args.max_train_steps):
|
||||
|
||||
input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
|
||||
|
@ -119,18 +111,19 @@ def main():
|
|||
|
||||
torch.cuda.synchronize()
|
||||
progress_bar.update(1)
|
||||
|
||||
# Compute Statistics
|
||||
|
||||
# Compute Statistics
|
||||
end_time = time.time()
|
||||
throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
|
||||
max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
|
||||
|
||||
logger.info(f"Testing finished, "
|
||||
f"batch size per gpu: {args.batch_size}, "
|
||||
f"plugin: {args.plugin}, "
|
||||
f"throughput: {throughput}, "
|
||||
f"maximum memory usage per gpu: {max_mem}.",
|
||||
ranks=[0])
|
||||
|
||||
logger.info(
|
||||
f"Testing finished, "
|
||||
f"batch size per gpu: {args.batch_size}, "
|
||||
f"plugin: {args.plugin}, "
|
||||
f"throughput: {throughput}, "
|
||||
f"maximum memory usage per gpu: {max_mem}.",
|
||||
ranks=[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,25 +1,20 @@
|
|||
import time
|
||||
|
||||
import torch
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
from transformers.utils.versions import require_version
|
||||
from args import parse_demo_args
|
||||
from data import NetflixDataset, netflix_collator
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.tensor import ProcessGroup, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
from args import parse_demo_args
|
||||
from data import NetflixDataset, netflix_collator
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
|
||||
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
|
||||
|
@ -30,18 +25,18 @@ def move_to_cuda(batch, device):
|
|||
|
||||
|
||||
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
|
||||
|
||||
|
||||
torch.cuda.synchronize()
|
||||
model.train()
|
||||
|
||||
with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
|
||||
|
||||
|
||||
for batch in pbar:
|
||||
|
||||
# Forward
|
||||
optimizer.zero_grad()
|
||||
batch = move_to_cuda(batch, torch.cuda.current_device())
|
||||
|
||||
|
||||
outputs = model(use_cache=False, **batch)
|
||||
loss = outputs['loss']
|
||||
|
||||
|
@ -72,7 +67,7 @@ def main():
|
|||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# Build OPT model
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
|
||||
|
@ -88,43 +83,35 @@ def main():
|
|||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(device=get_current_device(),
|
||||
placement_policy='cpu',
|
||||
pin_memory=True,
|
||||
strict_ddp_mode=True,
|
||||
initial_scale=2**5)
|
||||
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
dataset = NetflixDataset(tokenizer)
|
||||
dataloader = plugin.prepare_dataloader(dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=netflix_collator)
|
||||
|
||||
|
||||
# Set optimizer
|
||||
optimizer = HybridAdam(model.parameters(),
|
||||
lr=(args.learning_rate * world_size),
|
||||
weight_decay=args.weight_decay)
|
||||
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
|
||||
|
||||
# Set lr scheduler
|
||||
total_steps = len(dataloader) * args.num_epoch
|
||||
num_warmup_steps = int(args.warmup_ratio * total_steps)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=len(dataloader) * args.num_epoch
|
||||
)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=len(dataloader) * args.num_epoch)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=dataloader,
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
|
||||
# Start finetuning
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import gzip
|
||||
import random
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from time import time
|
||||
|
||||
|
@ -8,20 +8,17 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import tqdm
|
||||
from packaging import version
|
||||
|
||||
from colossalai.nn import HybridAdam
|
||||
from palm_pytorch import PaLM
|
||||
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||
from colossalai.utils import MultiTimer, get_current_device
|
||||
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
# constants
|
||||
|
||||
|
@ -44,23 +41,10 @@ def parse_args():
|
|||
help="The distributed plan [colossalai, pytorch].",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp_degree",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placement",
|
||||
type=str,
|
||||
default='cpu',
|
||||
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shardinit",
|
||||
type=bool,
|
||||
default=False,
|
||||
help=
|
||||
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
|
||||
"--offload_optim_frac",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
|
||||
)
|
||||
parser.add_argument('-p',
|
||||
'--plugin',
|
||||
|
@ -111,51 +95,6 @@ def get_model_size(model: nn.Module):
|
|||
return total_numel
|
||||
|
||||
|
||||
|
||||
|
||||
# Parameter Sharding Strategies for Tensor Parallelism
|
||||
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
||||
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
param.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(0, param, pg)
|
||||
|
||||
|
||||
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(-1, param, pg)
|
||||
|
||||
|
||||
# Tensor Parallel
|
||||
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||
"""tensor_parallelize
|
||||
Sharding the Model Parameters.
|
||||
Args:
|
||||
model (torch.nn.Module): a torch module to be sharded
|
||||
"""
|
||||
for mn, module in model.named_modules():
|
||||
for pn, param in module.named_parameters(recurse=False):
|
||||
if hasattr(param, 'visited'):
|
||||
continue
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
if 'net.0' in mn:
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
elif 'to_q' in mn:
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
elif 'to_kv' in mn:
|
||||
split_param_row_tp1d(param, pg) # row slice
|
||||
elif 'to_out' in mn:
|
||||
split_param_row_tp1d(param, pg) # row slice
|
||||
elif '1.1' in mn:
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
elif '1.2' in mn:
|
||||
split_param_row_tp1d(param, pg) # row slice
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
param.visited = True
|
||||
|
||||
|
||||
args = parse_args()
|
||||
if args.distplan not in ["colossalai", "pytorch"]:
|
||||
raise TypeError(f"{args.distplan} is error")
|
||||
|
@ -212,23 +151,18 @@ if args.distplan == "colossalai":
|
|||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
|
||||
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
logger.info(f"plugin: {plugin}")
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
|
||||
default_pg = ProcessGroup(tp_degree=args.tp_degree)
|
||||
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
|
||||
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
|
||||
ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()
|
||||
|
||||
with ctx:
|
||||
model = PaLM(num_tokens=50304, dim=4096, depth=64)
|
||||
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
|
||||
|
||||
pg = default_pg
|
||||
tensor_parallelize(model, pg)
|
||||
|
||||
# optimizer
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
|
||||
|
|
|
@ -3,5 +3,5 @@ torch >= 1.8.1
|
|||
datasets >= 1.8.0
|
||||
sentencepiece != 0.1.92
|
||||
protobuf
|
||||
accelerate == 0.13.2
|
||||
accelerate
|
||||
transformers
|
||||
|
|
|
@ -30,7 +30,7 @@ from itertools import chain
|
|||
import datasets
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
import transformers.utils.logging as logging
|
||||
from accelerate.utils import set_seed
|
||||
from context import barrier_context
|
||||
from datasets import load_dataset
|
||||
|
@ -57,7 +57,7 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger
|
|||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.utils import get_current_device, get_dataloader
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
|
||||
from colossalai.zero import GeminiOptimizer
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -292,10 +292,10 @@ def main():
|
|||
|
||||
if is_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
logging.set_verbosity_error()
|
||||
|
||||
if args.mem_cap > 0:
|
||||
colo_memory_cap(args.mem_cap)
|
||||
|
@ -391,16 +391,28 @@ def main():
|
|||
else:
|
||||
init_dev = get_current_device()
|
||||
|
||||
cai_version = colossalai.__version__
|
||||
logger.info(f'using Colossal-AI version {cai_version}')
|
||||
# build model
|
||||
if version.parse(cai_version) >= version.parse("0.3.1"):
|
||||
from contextlib import nullcontext
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
ctx = LazyInitContext(
|
||||
default_device=init_dev
|
||||
) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext()
|
||||
else:
|
||||
from colossalai.zero import ColoInitContext
|
||||
ctx = ColoInitContext(device=init_dev)
|
||||
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
|
||||
# currently, there has a bug in pretrained opt-13b
|
||||
# we can not import it until huggingface fix it
|
||||
logger.info("Train a new model from scratch", ranks=[0])
|
||||
with ColoInitContext(device=init_dev):
|
||||
with ctx:
|
||||
model = OPTForCausalLM(config)
|
||||
else:
|
||||
logger.info("Finetune a pre-trained model", ranks=[0])
|
||||
with ColoInitContext(device=init_dev):
|
||||
with ctx:
|
||||
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
|
@ -410,9 +422,10 @@ def main():
|
|||
model.gradient_checkpointing_enable()
|
||||
|
||||
PLACEMENT_POLICY = 'auto'
|
||||
cai_version = colossalai.__version__
|
||||
logger.info(f'using Colossal-AI version {cai_version}')
|
||||
if version.parse(cai_version) > version.parse("0.1.10"):
|
||||
if version.parse(cai_version) >= version.parse("0.3.1"):
|
||||
from colossalai.zero import GeminiDDP
|
||||
model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True)
|
||||
elif version.parse(cai_version) > version.parse("0.1.10"):
|
||||
try:
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
except ImportError:
|
||||
|
@ -536,7 +549,6 @@ def main():
|
|||
]
|
||||
|
||||
optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
|
@ -551,6 +563,7 @@ def main():
|
|||
num_warmup_steps=args.num_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
optimizer = GeminiOptimizer(optimizer, model, initial_scale=2**14)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
|
|
@ -4,9 +4,9 @@ set -xue
|
|||
|
||||
pip install -r requirements.txt
|
||||
|
||||
BS=8
|
||||
BS=4
|
||||
MEMCAP=0
|
||||
GPUNUM=2
|
||||
GPUNUM=4
|
||||
MODLE="facebook/opt-125m"
|
||||
|
||||
torchrun \
|
||||
|
|
|
@ -4,4 +4,5 @@ markers =
|
|||
gpu: tests which requires a single GPU
|
||||
dist: tests which are run in a multi-GPU or multi-machine environment
|
||||
experiment: tests for experimental features
|
||||
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe
|
||||
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx
|
||||
|
||||
|
|
|
@ -17,6 +17,13 @@ def data_gen_fn():
|
|||
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
def data_gen_for_pretrain():
|
||||
inputs = data_gen_fn()
|
||||
inputs['labels'] = inputs['input_ids'].clone()
|
||||
inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
|
||||
return inputs
|
||||
|
||||
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
config = transformers.AlbertConfig(embedding_size=128,
|
||||
|
@ -26,14 +33,14 @@ config = transformers.AlbertConfig(embedding_size=128,
|
|||
intermediate_size=256)
|
||||
|
||||
model_zoo.register(name='transformers_albert',
|
||||
model_fn=lambda: transformers.AlbertModel(config),
|
||||
model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_albert_for_pretraining',
|
||||
model_fn=lambda: transformers.AlbertForPreTraining(config),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
data_gen_fn=data_gen_for_pretrain,
|
||||
output_transform_fn=lambda x: dict(loss=x.loss),
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_albert_for_masked_lm',
|
||||
model_fn=lambda: transformers.AlbertForMaskedLM(config),
|
||||
|
|
|
@ -113,6 +113,7 @@ def data_gen_for_qa():
|
|||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss funciton
|
||||
|
||||
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
|
||||
))
|
||||
loss_fn = lambda x: x.loss
|
||||
|
@ -126,7 +127,7 @@ config = transformers.BertConfig(hidden_size=128,
|
|||
|
||||
# register the BERT variants
|
||||
model_zoo.register(name='transformers_bert',
|
||||
model_fn=lambda: transformers.BertModel(config),
|
||||
model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_bert_model,
|
||||
|
|
|
@ -57,6 +57,12 @@ def data_gen_for_sequence_classification():
|
|||
return data
|
||||
|
||||
|
||||
def date_gen_for_double_heads():
|
||||
data = data_gen_for_lm()
|
||||
data['mc_labels'] = torch.zeros(data['input_ids'].shape[0], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
# define output transform function
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
|
@ -94,8 +100,8 @@ model_zoo.register(name='transformers_gpt_lm',
|
|||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_double_heads',
|
||||
model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
data_gen_fn=date_gen_for_double_heads,
|
||||
output_transform_fn=lambda x: dict(loss=x.loss + x.mc_loss),
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_for_question_answering',
|
||||
|
|
|
@ -12,19 +12,16 @@ from colossalai.lazy.lazy_init import LazyInitContext
|
|||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
|
||||
try:
|
||||
if init_method == 'colo':
|
||||
ctx = ColoInitContext()
|
||||
elif init_method == 'lazy':
|
||||
if init_method == 'lazy':
|
||||
ctx = LazyInitContext()
|
||||
else:
|
||||
ctx = nullcontext()
|
||||
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
|
||||
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
|
||||
booster = Booster(plugin=plugin)
|
||||
with ctx:
|
||||
model = model_fn()
|
||||
|
@ -50,6 +47,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
|
|||
optimizer.step()
|
||||
|
||||
except Exception as e:
|
||||
# raise e
|
||||
return repr(e)
|
||||
|
||||
|
||||
|
@ -57,8 +55,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
|
|||
# @parameterize('init_method', ['lazy', 'none', 'colo'])
|
||||
|
||||
|
||||
@parameterize('subset', ['torchvision', 'transformers', 'diffusers'])
|
||||
@parameterize('init_method', ['none'])
|
||||
def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
|
||||
def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool = True):
|
||||
"""check gemini plugin over model zoo
|
||||
|
||||
Args:
|
||||
|
@ -71,29 +70,23 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
|
|||
passed_models = []
|
||||
failed_info = {} # (model_name, error) pair
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
|
||||
# These models lead to CUDA error
|
||||
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
|
||||
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
|
||||
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
|
||||
'torchvision_convnext_base'):
|
||||
continue
|
||||
# These models are not compatible with gemini
|
||||
if name in [
|
||||
'diffusers_clip_vision_model', 'timm_resnet', 'timm_beit', 'timm_beitv2', 'timm_eca_nfnet',
|
||||
'timm_efficientformer', 'timm_hrnet_w18_small', 'timm_nf_ecaresnet101', 'timm_nf_regnet_b0',
|
||||
'timm_skresnet18', 'timm_wide_resnet50_2', 'timm_convit', 'timm_dm_nfnet', 'timm_swin_transformer',
|
||||
'torchaudio_conformer', 'torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron',
|
||||
'deepfm_interactionarch', 'deepfm_simpledeepfmnn', 'dlrm', 'dlrm_interactionarch',
|
||||
'torchvision_googlenet', 'torchvision_inception_v3', 'torchvision_mobilenet_v3_small',
|
||||
'torchvision_resnet18', 'torchvision_resnext50_32x4d', 'torchvision_wide_resnet50_2',
|
||||
'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert',
|
||||
'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining',
|
||||
'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base',
|
||||
'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model',
|
||||
'transformers_vit', 'transformers_vit_for_masked_image_modeling',
|
||||
'transformers_vit_for_image_classification', 'transformers_chatglm',
|
||||
'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
|
||||
'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
|
||||
'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
|
||||
'timm_convit',
|
||||
'timm_dm_nfnet',
|
||||
'torchvision_vit_b_16',
|
||||
'transformers_t5',
|
||||
'transformers_t5_for_conditional_generation',
|
||||
'transformers_t5_encoder_model', # does not support apex rmsnorm
|
||||
'transformers_chatglm',
|
||||
'transformers_sam',
|
||||
'transformers_vit'
|
||||
]:
|
||||
continue
|
||||
|
||||
|
@ -105,7 +98,6 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
|
|||
]:
|
||||
continue
|
||||
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if err is None:
|
||||
passed_models.append(name)
|
||||
|
|
|
@ -18,12 +18,45 @@ from colossalai.testing import (
|
|||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
MODEL_PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 1.0
|
||||
}, # zero3
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.5
|
||||
}, # zero3-half
|
||||
]
|
||||
|
||||
OPTIM_PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 1.0
|
||||
}, # zero2-offload
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.5
|
||||
}, # zero2-offload-half
|
||||
]
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('placement_config', MODEL_PLACEMENT_CONFIGS)
|
||||
@parameterize('model_name', ['transformers_bert_for_sequence_classification'])
|
||||
@parameterize('use_safetensors', [False, True])
|
||||
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
|
||||
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
|
||||
from transformers import BertForSequenceClassification
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
bert_model = model_fn()
|
||||
|
@ -32,7 +65,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
|
|||
pretrained_path = os.path.join(tempdir, 'pretrained')
|
||||
bert_model.config.save_pretrained(save_directory=pretrained_path)
|
||||
|
||||
plugin = GeminiPlugin(placement_policy=placement_policy)
|
||||
plugin = GeminiPlugin(**placement_config)
|
||||
booster = Booster(plugin=plugin)
|
||||
bert_model, _, _, _, _ = booster.boost(bert_model)
|
||||
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
|
||||
|
@ -46,19 +79,19 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
|
|||
dist.barrier()
|
||||
|
||||
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
|
||||
check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32),
|
||||
check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
|
||||
new_bert_model.state_dict(), False)
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('placement_config', OPTIM_PLACEMENT_CONFIGS)
|
||||
@parameterize('shard', [False, True])
|
||||
@parameterize('model_name', ['transformers_gpt'])
|
||||
@parameterize('size_per_shard', [32])
|
||||
def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
|
||||
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
criterion = lambda x: x.mean()
|
||||
plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
|
||||
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14))
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
model = model_fn()
|
||||
|
@ -87,12 +120,11 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_sha
|
|||
dist.barrier()
|
||||
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
|
||||
new_model.unwrap().state_dict(only_rank_0=False), False)
|
||||
check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
|
||||
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
|
||||
new_optimizer.unwrap().state_dict(only_rank_0=False), False)
|
||||
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False),
|
||||
False)
|
||||
|
||||
# Check the new model/optimizer can successfully run.
|
||||
data = data_gen_fn()
|
||||
|
|
|
@ -60,12 +60,11 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
|
|||
new_booster.load_model(new_model, model_ckpt_path, strict=True)
|
||||
|
||||
# Add prefix to get aligned with pytorch parameter names.
|
||||
check_state_dict_equal(
|
||||
model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
|
||||
new_model.state_dict(), False)
|
||||
check_state_dict_equal(model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
|
||||
new_model.state_dict(), False)
|
||||
|
||||
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
|
||||
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
|
||||
|
||||
# Check the new model/optimizer can successfully run.
|
||||
data = data_gen_fn()
|
||||
|
@ -124,13 +123,12 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
|
|||
new_booster.load_model(new_model, model_ckpt_path, strict=True)
|
||||
|
||||
# Add prefix to get aligned with pytorch parameter names.
|
||||
check_state_dict_equal(
|
||||
new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
|
||||
model.state_dict(), False)
|
||||
check_state_dict_equal(new_model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
|
||||
model.state_dict(), False)
|
||||
|
||||
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
old_state_dict = optimizer.state_dict()
|
||||
new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
|
||||
new_state_dict = new_optimizer.state_dict(only_rank_0=False)
|
||||
|
||||
# Comparison of param_groups needs special care here,
|
||||
# since not all hyperparameters in Adam are used by HybridAdam
|
||||
|
@ -138,7 +136,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
|
|||
for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
|
||||
for k in hyperparameters_to_examine:
|
||||
assert k in old_group and k in new_group, \
|
||||
f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
|
||||
f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
|
||||
assert old_group[k] == new_group[k]
|
||||
check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)
|
||||
|
||||
|
|
|
@ -1,104 +0,0 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
||||
import colossalai
|
||||
from colossalai.amp import AMP_TYPE
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.utils import get_dataloader
|
||||
|
||||
disable_existing_loggers()
|
||||
BATCH_SIZE = 4
|
||||
NUM_EPOCHS = 10
|
||||
WARMUP_EPOCHS = 5
|
||||
CONFIG = dict(NUM_MICRO_BATCHES=2,
|
||||
parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')),
|
||||
fp16=dict(mode=AMP_TYPE.NAIVE),
|
||||
gradient_accumulation=2)
|
||||
|
||||
|
||||
def run_trainer(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
disable_existing_loggers()
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
|
||||
pipelinable = PipelinableContext()
|
||||
try:
|
||||
from titans.model.vit import vit_tiny_patch4_32
|
||||
except ImportError:
|
||||
logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
|
||||
logger.warning('please install titan from https://github.com/hpcaitech/Titans')
|
||||
return
|
||||
with pipelinable:
|
||||
model = vit_tiny_patch4_32()
|
||||
pipelinable.to_layer_list()
|
||||
pipelinable.policy = "uniform"
|
||||
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||
|
||||
# create dataloaders
|
||||
root = Path(os.environ['DATA'])
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4, pad_if_needed=True),
|
||||
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
|
||||
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
|
||||
|
||||
# create loss function
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
# create optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0)
|
||||
|
||||
# create lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)
|
||||
|
||||
# initialize
|
||||
engine, train_dataloader, *_ = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader)
|
||||
|
||||
engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES)
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
trainer = Trainer(engine=engine, logger=logger)
|
||||
|
||||
hook_list = [
|
||||
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
|
||||
]
|
||||
|
||||
trainer.fit(train_dataloader=train_dataloader,
|
||||
max_steps=2,
|
||||
epochs=NUM_EPOCHS,
|
||||
hooks=hook_list,
|
||||
display_progress=True)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_hybrid_parallel():
|
||||
spawn(run_trainer, 2)
|
||||
disable_existing_loggers()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_hybrid_parallel()
|
|
@ -1,92 +0,0 @@
|
|||
import os
|
||||
import random
|
||||
from typing import Callable, Type
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.parallel import ColoDDP
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
||||
pg = ProcessGroup()
|
||||
return ColoDDP(module, process_group=pg)
|
||||
|
||||
|
||||
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
|
||||
chunk_config, *_ = search_chunk_configuration(module, 4, 1024)
|
||||
chunk_manager = ChunkManager(chunk_config)
|
||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
||||
return ZeroDDP(module, gemini_manager)
|
||||
|
||||
|
||||
class Net(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = torch.nn.Linear(3, 3, bias=False)
|
||||
self.fc2 = torch.nn.Linear(3, 1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.fc1(x))
|
||||
|
||||
|
||||
def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = Net().cuda()
|
||||
w1 = model.fc1.weight
|
||||
w2 = model.fc2.weight
|
||||
ddp_cls.set_params_to_ignore([w2])
|
||||
model = init_ddp_func(model)
|
||||
x = torch.rand(2, 3, device=get_current_device())
|
||||
logits = model(x)
|
||||
loss = torch.sum(logits)
|
||||
model.backward(loss)
|
||||
|
||||
if ddp_cls is ZeroDDP:
|
||||
w1s_grad = w1
|
||||
else:
|
||||
w1s_grad = w1.grad
|
||||
|
||||
w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(w1_grads, w1s_grad)
|
||||
assert torch.equal(w1_grads[0], w1_grads[1])
|
||||
w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(w2_grads, w2.grad)
|
||||
assert not torch.equal(w2_grads[0], w2_grads[1])
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
set_seed(dist.get_rank())
|
||||
run_fwd_bwd(ColoDDP, init_ddp)
|
||||
run_fwd_bwd(ZeroDDP, init_ddpv2)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ddp_ignore_params(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ddp_ignore_params(2)
|
|
@ -1,67 +0,0 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.parallel import ColoDDP
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
|
||||
for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
|
||||
assert k1 == k2
|
||||
|
||||
if t1.device != t2.device:
|
||||
temp_t2 = t2.to(t1.device)
|
||||
else:
|
||||
temp_t2 = t2
|
||||
|
||||
assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2)
|
||||
|
||||
|
||||
def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
||||
pg = ProcessGroup()
|
||||
return ColoDDP(module, process_group=pg)
|
||||
|
||||
|
||||
def run_ddp_state_dict():
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
torch_model = model_builder().cuda()
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = init_ddp(model)
|
||||
torch_state_dict = torch_model.state_dict()
|
||||
|
||||
for param in model.parameters():
|
||||
if isinstance(param, ColoParameter):
|
||||
assert param.get_process_group() is not None
|
||||
model.load_state_dict(torch_state_dict)
|
||||
|
||||
for param in model.parameters():
|
||||
if isinstance(param, ColoParameter):
|
||||
assert param.get_process_group() is not None
|
||||
|
||||
state_dict = model.state_dict()
|
||||
check_state_dict_equal(torch_state_dict, state_dict)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_ddp_state_dict()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_state_dict(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_state_dict(2)
|
|
@ -1,47 +0,0 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.parallel.reducer import Reducer
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
REDUCE_CNT = 0
|
||||
|
||||
|
||||
def check_eq(grad, grad_clone):
|
||||
global REDUCE_CNT
|
||||
print(f'Rank{dist.get_rank()} check {REDUCE_CNT}')
|
||||
REDUCE_CNT += 1
|
||||
assert torch.allclose(grad, grad_clone)
|
||||
|
||||
|
||||
def run_reducer():
|
||||
grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)]
|
||||
grads_clone = [g.clone().detach() for g in grads]
|
||||
for g in grads:
|
||||
dist.all_reduce(g)
|
||||
reducer = Reducer(bucket_size_mb=1)
|
||||
for g, g_clone in zip(grads, grads_clone):
|
||||
reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g))
|
||||
reducer.flush()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_reducer()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_reducer(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_reducer(2)
|
|
@ -1,73 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
"""
|
||||
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
|
||||
Basically works like a linear layer but the weights are transposed.
|
||||
Args:
|
||||
nf (`int`): The number of output features.
|
||||
nx (`int`): The number of input features.
|
||||
"""
|
||||
|
||||
def __init__(self, nf, nx):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
w = torch.empty(nx, nf)
|
||||
nn.init.normal_(w, std=0.02)
|
||||
self.weight = nn.Parameter(w)
|
||||
self.bias = nn.Parameter(torch.ones(nf))
|
||||
|
||||
def forward(self, x):
|
||||
size_out = x.size()[:-1] + (self.nf,)
|
||||
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
||||
x = x.view(size_out)
|
||||
return x
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, split_bias):
|
||||
model = Conv1D(4, 16).cuda()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
|
||||
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
|
||||
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
|
||||
|
||||
spec_init_func(weight, pg)
|
||||
if split_bias:
|
||||
spec_init_func(bias, pg)
|
||||
|
||||
x = torch.rand(2, 16).cuda()
|
||||
out = model(x)
|
||||
colo_out = torch.addmm(bias, x, weight)
|
||||
colo_out = colo_out.to_replicate()
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False)
|
||||
run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_addmm_1d(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_addmm_1d(4)
|
|
@ -1,43 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
|
||||
model = torch.nn.EmbeddingBag(10, 4).cuda()
|
||||
weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg))
|
||||
|
||||
spec_init_func(weight, pg)
|
||||
|
||||
inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
|
||||
offsets = torch.tensor([0, 4]).cuda()
|
||||
out = model(inputs, offsets=offsets)
|
||||
colo_out = F.embedding_bag(inputs, weight, offsets=offsets)
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_with_spec(split_param_col_tp1d)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_embedding_bag_1d(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_embedding_bag_1d(4)
|
|
@ -1,44 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, pg: ProcessGroup):
|
||||
model = torch.nn.Embedding(12, 32).cuda()
|
||||
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
|
||||
|
||||
spec_init_func(weight, pg)
|
||||
|
||||
x = torch.tensor((0, 3, 6, 9)).cuda()
|
||||
out = model(x)
|
||||
colo_out = F.embedding(x, weight)
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
# compare grad inside a TP group
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
# config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
run_with_spec(split_param_row_tp1d, pg)
|
||||
run_with_spec(split_param_col_tp1d, pg)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_embedding_1d(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_embedding_1d(4)
|
|
@ -1,48 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, split_bias):
|
||||
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
|
||||
model = torch.nn.Linear(4, 8).cuda()
|
||||
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
|
||||
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
|
||||
|
||||
spec_init_func(weight, pg)
|
||||
if split_bias:
|
||||
spec_init_func(bias, pg)
|
||||
|
||||
x = torch.rand(2, 4).cuda()
|
||||
out = model(x)
|
||||
colo_out = F.linear(x, weight, bias)
|
||||
colo_out = colo_out.to_replicate()
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=False)
|
||||
run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=True)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_1d(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_linear_1d(4)
|
|
@ -1,48 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def check_cross_entropy():
|
||||
input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
|
||||
input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
|
||||
with torch.no_grad():
|
||||
input_ct.copy_(input_t)
|
||||
|
||||
target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
|
||||
input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
|
||||
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
output = F.cross_entropy(input_t, target)
|
||||
output_colo = F.cross_entropy(input_shard, target)
|
||||
assert torch.allclose(output_colo, output)
|
||||
|
||||
output.backward()
|
||||
output_colo.backward()
|
||||
|
||||
assert torch.allclose(input_t.grad, input_ct.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_cross_entropy()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_loss_func(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_loss_func(1)
|
|
@ -1,87 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def _run_layer_norm():
|
||||
ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
|
||||
|
||||
input_t = torch.randn(3, 2, device=get_current_device())
|
||||
|
||||
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
|
||||
input_t_colo = ColoTensor.from_torch_tensor(input_t.clone().detach(), ColoTensorSpec(pg))
|
||||
|
||||
# prepare colossalai LN
|
||||
weight = ColoTensor(Parameter(ln_op.weight.detach()), ColoTensorSpec(pg))
|
||||
bias = ColoTensor(Parameter(ln_op.bias.detach()), ColoTensorSpec(pg))
|
||||
|
||||
output = ln_op(input_t)
|
||||
output_colo = F.layer_norm(input_t_colo, ln_op.normalized_shape, weight, bias, ln_op.eps)
|
||||
|
||||
assert torch.allclose(output_colo, output)
|
||||
|
||||
torch.mean(output).backward()
|
||||
torch.mean(output_colo).backward()
|
||||
|
||||
assert torch.allclose(ln_op.weight.grad, weight.grad)
|
||||
|
||||
|
||||
def check_spec_eq(tensor, other):
|
||||
assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
|
||||
for k in dir(tensor.dist_spec):
|
||||
if not k.startswith('__'):
|
||||
assert hasattr(other.dist_spec, k), f"{k}"
|
||||
assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k)
|
||||
|
||||
|
||||
def check_element_wise_ops():
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
t = torch.rand(2, 2)
|
||||
x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))
|
||||
|
||||
check_spec_eq(x, x.cuda())
|
||||
assert torch.equal(x.cuda(), t.cuda())
|
||||
check_spec_eq(x, torch.abs(x))
|
||||
assert torch.equal(torch.abs(x), torch.abs(t))
|
||||
check_spec_eq(x, F.sigmoid(x))
|
||||
assert torch.equal(F.sigmoid(x), F.sigmoid(t))
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_element_wise_ops()
|
||||
_run_layer_norm()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_element_wise_ops(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
def run_dist2(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
_run_layer_norm()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ln(world_size):
|
||||
spawn(run_dist2, world_size)
|
||||
|
||||
|
||||
def check_all():
|
||||
test_element_wise_ops(2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_all()
|
|
@ -1,97 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.tensor.distspec import DistPlacementPattern
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d
|
||||
|
||||
|
||||
def exam_view_core(pg):
|
||||
# the case of replicated ColoTensors
|
||||
x = torch.randn(4, 4).cuda()
|
||||
x_colo = ColoTensor(x, ColoTensorSpec(pg))
|
||||
|
||||
y = x.view(2, -1, 2)
|
||||
y_colo = x_colo.view(2, -1, 2)
|
||||
|
||||
assert torch.all(y == y_colo)
|
||||
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
|
||||
# the perfect case of col-sliced ColoTensors
|
||||
split_param_col_tp1d(x_colo, pg)
|
||||
|
||||
z = x.view(torch.Size((2, 1, 2, -1)))
|
||||
z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
|
||||
if dist.get_rank() == 0:
|
||||
z = z[:, :, :, 0:2]
|
||||
else:
|
||||
z = z[:, :, :, 2:]
|
||||
assert torch.all(z == z_colo)
|
||||
assert z_colo.dist_spec == x_colo.dist_spec
|
||||
# the perfect case of row-sliced ColoTensors
|
||||
split_param_row_tp1d(x_colo, pg)
|
||||
|
||||
z = x.view(torch.Size((-1, 2, 2)))
|
||||
z_colo = x_colo.view(torch.Size((-1, 2, 2)))
|
||||
if dist.get_rank() == 0:
|
||||
z = z[0:2, :, :]
|
||||
else:
|
||||
z = z[2:, :, :]
|
||||
assert torch.all(z == z_colo)
|
||||
assert z_colo.dist_spec == x_colo.dist_spec
|
||||
# the normal case of row-sliced ColoTensors
|
||||
z = x.view(-1, 2, 2, 2)
|
||||
z_colo = x_colo.view(-1, 2, 2, 2)
|
||||
assert torch.all(z == z_colo)
|
||||
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
|
||||
|
||||
|
||||
def exam_view_autograd(pg):
|
||||
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
|
||||
y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
|
||||
with torch.no_grad():
|
||||
y.copy_(x)
|
||||
y = ColoTensor(y, ColoTensorSpec(pg))
|
||||
y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
|
||||
|
||||
xx = x.view(2, 2, -1)
|
||||
yy_slice = y_slice.view(2, 2, -1)
|
||||
yy = yy_slice.to_replicate()
|
||||
grad = torch.randn(2, 2, 4, device=get_current_device())
|
||||
|
||||
xx.backward(grad)
|
||||
yy.backward(grad)
|
||||
assert torch.all(x.grad == y.grad)
|
||||
|
||||
|
||||
def exam_view_errors(pg):
|
||||
x = torch.randn(8, 2, device=get_current_device())
|
||||
x = ColoTensor(x, ColoTensorSpec(pg))
|
||||
split_param_row_tp1d(x, pg)
|
||||
|
||||
x.view('a', 'b', 'c')
|
||||
x.view(8, -1)
|
||||
x.view([-2, -2, -2])
|
||||
x.view((-1, -1, -1))
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
|
||||
exam_view_core(pg)
|
||||
exam_view_autograd(pg)
|
||||
# exam_view_errors(pg)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_view(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_view(2)
|
|
@ -1,3 +1,4 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
|
@ -48,6 +49,7 @@ def run_pipelinable(rank, world_size, port):
|
|||
assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="this is useless")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pipelinable():
|
||||
spawn(run_pipelinable, 1)
|
||||
|
|
|
@ -127,6 +127,10 @@ def check_gpt2(rank, world_size, port):
|
|||
run_gpt2_test()
|
||||
|
||||
|
||||
# TODO(ver217): fix this
|
||||
|
||||
|
||||
@pytest.mark.skip("this will stuck in CI")
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
|
|
@ -1,153 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
from numpy import allclose
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def _run_tensor_indexing():
|
||||
pg = ProcessGroup()
|
||||
torch_t = torch.randn(2, 3)
|
||||
colo_t = ColoTensor(torch_t, ColoTensorSpec(pg))
|
||||
assert allclose(torch_t[:, 1], colo_t[:, 1])
|
||||
|
||||
|
||||
def _run_wrapped_tensor_func():
|
||||
pg = ProcessGroup()
|
||||
t_ref = torch.randn(4, 5)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
|
||||
|
||||
# non-func attr
|
||||
assert t.is_cuda == t_ref.is_cuda
|
||||
|
||||
# return 1 torch.Tensor
|
||||
t_abs = t.abs()
|
||||
assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())
|
||||
|
||||
# return 1 non-torch.Tensor
|
||||
assert t.dim() == t_ref.dim()
|
||||
|
||||
# return >1 torch.Tensor
|
||||
assert isinstance(t, ColoTensor)
|
||||
t_split1, t_split2 = t.split(2)
|
||||
assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"
|
||||
|
||||
|
||||
def _run_operand(world_size):
|
||||
pg = ProcessGroup()
|
||||
t_ref = torch.randn(4, 5)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
|
||||
|
||||
t_ref_res = t_ref + t_ref
|
||||
t_res = t + t
|
||||
|
||||
assert isinstance(t_res, ColoTensor)
|
||||
assert torch.allclose(t_ref_res, t_res)
|
||||
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
|
||||
t.set_dist_spec(ShardSpec([0], [world_size]))
|
||||
t_new = torch.zeros_like(t)
|
||||
assert isinstance(t_new, ColoTensor)
|
||||
assert t_new.is_sharded()
|
||||
|
||||
|
||||
#### Test Distributed init a Colotensor
|
||||
|
||||
|
||||
def _run_view(world_size):
|
||||
t_ref = torch.randn(4, 5)
|
||||
rank = gpc.get_global_rank()
|
||||
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
|
||||
t = ColoTensor.from_torch_tensor(
|
||||
t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))
|
||||
|
||||
assert t.size_global()[0] == 4 * world_size
|
||||
assert t.size_global(1) == 5
|
||||
assert t.size_global() == torch.Size([4 * world_size, 5])
|
||||
|
||||
t = t.view(4 * 5 * world_size)
|
||||
assert t.shape == torch.Size([4 * 5 * world_size])
|
||||
|
||||
|
||||
def _run_tensor_shard_init(world_size):
|
||||
t_ref = torch.randn(4, 5)
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
|
||||
tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
t.set_dist_spec(ReplicaSpec())
|
||||
|
||||
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
|
||||
|
||||
|
||||
def _run_tensor_replicated_init(world_size):
|
||||
t_ref = torch.randn(4 * world_size, 5)
|
||||
pg = ProcessGroup()
|
||||
spec = ColoTensorSpec(pg)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), spec)
|
||||
|
||||
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
|
||||
|
||||
|
||||
def _run_process_group(world_size):
|
||||
pg1 = ProcessGroup()
|
||||
pg2 = ProcessGroup()
|
||||
assert pg1 == pg2
|
||||
|
||||
|
||||
def _run_redistributed(world_size):
|
||||
if world_size != 4:
|
||||
return
|
||||
pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
|
||||
pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
|
||||
|
||||
spec1 = ColoTensorSpec(pg1)
|
||||
t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
|
||||
t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
|
||||
assert t1.is_sharded()
|
||||
t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
|
||||
assert t1.is_sharded()
|
||||
pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
|
||||
t1 = t1.redistribute(ReplicaSpec(), pg3)
|
||||
assert t1.is_replicate()
|
||||
|
||||
|
||||
def _run_set_tensor_spec(world_size):
|
||||
if world_size != 4:
|
||||
return
|
||||
pg = ProcessGroup(tp_degree=2, dp_degree=2)
|
||||
spec1 = ColoTensorSpec(pg)
|
||||
t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
|
||||
|
||||
dist_spec2 = ShardSpec([-1], [pg.tp_world_size()])
|
||||
assert t1.is_replicate()
|
||||
t1.set_dist_spec(dist_spec2)
|
||||
assert t1.is_shard_1dcol()
|
||||
|
||||
|
||||
def run_dist_tests(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
_run_tensor_shard_init(world_size)
|
||||
_run_tensor_replicated_init(world_size)
|
||||
_run_view(world_size)
|
||||
_run_process_group(world_size)
|
||||
_run_tensor_indexing()
|
||||
_run_operand(world_size)
|
||||
_run_wrapped_tensor_func()
|
||||
_run_redistributed(world_size)
|
||||
_run_set_tensor_spec(world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_dist_cases(world_size):
|
||||
spawn(run_dist_tests, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dist_cases(4)
|
|
@ -1,148 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import (
|
||||
debug_print,
|
||||
set_seed,
|
||||
split_param_col_tp1d,
|
||||
split_param_row_tp1d,
|
||||
tensor_equal,
|
||||
tensor_shard_equal,
|
||||
)
|
||||
|
||||
|
||||
def init_1d_row_spec(model, pg: ProcessGroup):
|
||||
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
for n, p in model.named_parameters():
|
||||
p.set_process_group(pg)
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_tensor_spec(*tensor_spec)
|
||||
|
||||
|
||||
def init_1d_col_spec(model, pg: ProcessGroup):
|
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
for n, p in model.named_parameters():
|
||||
p.set_process_group(pg)
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
p.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def init_megatron_spec(model, pg: ProcessGroup):
|
||||
for mn, module in model.named_modules():
|
||||
# debug_print([0], mn)
|
||||
for pn, param in module.named_parameters(recurse=False):
|
||||
# debug_print([0], '\t', pn, param.compute_spec, param.shape)
|
||||
param.set_process_group(pg)
|
||||
|
||||
if 'mlp.c_fc' in mn:
|
||||
if 'weight' in pn or 'bias' in pn:
|
||||
split_param_col_tp1d(param, pg)
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
else:
|
||||
raise RuntimeError
|
||||
elif 'mlp.c_proj' in mn:
|
||||
if 'weight' in pn:
|
||||
split_param_row_tp1d(param, pg)
|
||||
else:
|
||||
assert 'bias' in pn
|
||||
elif 'wte' in mn or 'wpe' in mn:
|
||||
assert 'weight' in pn
|
||||
split_param_col_tp1d(param, pg)
|
||||
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||
split_param_col_tp1d(param, pg)
|
||||
# debug_print([0], '\t', param.compute_spec, param.shape)
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model, pg: ProcessGroup):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
|
||||
assert pg.tp_world_size() is not None
|
||||
assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def check_grad_equal(model, torch_model, pg: ProcessGroup):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_gpt(init_spec_func, use_ddp):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
# build a PG with TP and DP hybrid
|
||||
pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))
|
||||
|
||||
# set seed make processes of the same tp group use the same seed
|
||||
# set_seed(pg.tp_local_rank())
|
||||
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
# make sure torch_model and model has the same parameter values
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = model.cuda()
|
||||
torch_model = model_builder().cuda()
|
||||
|
||||
if use_ddp:
|
||||
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
|
||||
model = ColoDDP(model, process_group=pg)
|
||||
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p)
|
||||
|
||||
init_spec_func(model, pg)
|
||||
|
||||
check_param_equal(model, torch_model, pg)
|
||||
|
||||
# close the dropout in eval mode
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
set_seed(pg.dp_local_rank())
|
||||
torch.distributed.barrier()
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
|
||||
logits = model(colo_input)
|
||||
torch_logits = torch_model(input_ids)
|
||||
assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
|
||||
loss = criterion(logits, input_ids)
|
||||
torch_loss = criterion(torch_logits, input_ids)
|
||||
if use_ddp:
|
||||
model.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
torch_loss.backward()
|
||||
check_grad_equal(model, torch_model, pg)
|
||||
if i > 0:
|
||||
break
|
||||
set_seed(313)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp):
|
||||
if use_ddp and world_size == 1:
|
||||
return
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
# Comments below tests for speed concern
|
||||
# run_gpt(init_1d_row_spec, use_ddp)
|
||||
# run_gpt(init_1d_col_spec, use_ddp)
|
||||
run_gpt(init_megatron_spec, use_ddp)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@pytest.mark.parametrize('use_ddp', [False, True])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gpt(world_size, use_ddp):
|
||||
spawn(run_dist, world_size, use_ddp=use_ddp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt(4, use_ddp=False)
|
|
@ -1,334 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.tensor import ColoTensor, ProcessGroup
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import (
|
||||
check_equal,
|
||||
set_seed,
|
||||
split_param_col_tp1d,
|
||||
split_param_row_tp1d,
|
||||
tensor_shard_equal,
|
||||
)
|
||||
|
||||
|
||||
def run_1d_hybrid_tp(model_name):
|
||||
# A simple net with two stacked nn.Linear
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
if rank == 0:
|
||||
model_torch = model_builder(checkpoint=True)
|
||||
model_torch = model_torch.cuda()
|
||||
|
||||
optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1))
|
||||
|
||||
# Make two models have the same init params
|
||||
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
|
||||
p2.data.copy_(p1.data)
|
||||
else:
|
||||
model_torch = None
|
||||
optimizer_torch = None
|
||||
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
if 'bert' == model_name:
|
||||
for name, p in model.named_parameters():
|
||||
if not isinstance(p, ColoTensor):
|
||||
continue
|
||||
|
||||
# num_class = type_vocab_size = 2 | (8, 2)
|
||||
if 'classifier' in name and 'weight' in name:
|
||||
split_param_col_tp1d(p, pg)
|
||||
# num_class = vocab_size = 30524 | (30524, 8)
|
||||
elif 'word_embeddings' in name and 'weight' in name:
|
||||
split_param_row_tp1d(p, pg)
|
||||
# num_class = seq_len = 512 | (512, 8)
|
||||
elif 'position_embeddings' in name and 'weight' in name:
|
||||
split_param_row_tp1d(p, pg)
|
||||
# num_class = type_vocab_size = 2 | (2, 8)
|
||||
elif 'token_type_embeddings' in name and 'weight' in name:
|
||||
split_param_col_tp1d(p, pg)
|
||||
|
||||
elif "simple_net" == model_name:
|
||||
# A naive way to set spec for all weights in Linear
|
||||
for name, p in model.named_parameters():
|
||||
if not isinstance(p, ColoTensor):
|
||||
continue
|
||||
if 'embed' in name and 'weight' in name:
|
||||
split_param_col_tp1d(p, pg)
|
||||
if 'proj1' in name and ('weight' in name or 'bias' in name):
|
||||
split_param_row_tp1d(p, pg)
|
||||
if 'proj2' in name and 'weight' in name:
|
||||
split_param_col_tp1d(p, pg)
|
||||
if 'classifier' in name and ('weight' in name or 'bias' in name):
|
||||
split_param_row_tp1d(p, pg)
|
||||
|
||||
model = model.cuda()
|
||||
model.eval()
|
||||
if rank == 0:
|
||||
model_torch.eval()
|
||||
|
||||
colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
|
||||
# Zero grad
|
||||
colo_optimizer.zero_grad()
|
||||
if rank == 0:
|
||||
optimizer_torch.zero_grad()
|
||||
torch.distributed.barrier()
|
||||
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
|
||||
torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
|
||||
|
||||
# Bcast rank0 data to all processes
|
||||
if criterion:
|
||||
output = model(data)
|
||||
loss = criterion(output, label)
|
||||
else:
|
||||
output = model(data, label)
|
||||
loss = output
|
||||
|
||||
# Test output
|
||||
if rank == 0:
|
||||
if criterion:
|
||||
output_torch = model_torch(data)
|
||||
loss_torch = criterion(output_torch, label)
|
||||
else:
|
||||
output_torch = model_torch(data, label)
|
||||
loss_torch = output_torch
|
||||
assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
|
||||
torch.distributed.barrier()
|
||||
|
||||
loss.backward()
|
||||
colo_optimizer.step()
|
||||
|
||||
if rank == 0:
|
||||
loss_torch.backward()
|
||||
optimizer_torch.step()
|
||||
|
||||
with torch.no_grad():
|
||||
# check param
|
||||
for p, torch_p in zip(model.parameters(), model_torch.parameters()):
|
||||
assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
|
||||
torch.distributed.barrier()
|
||||
if i > 5:
|
||||
break
|
||||
|
||||
|
||||
# Test the overrided parameters() and named_parameters() member functions
|
||||
def test_model_parameters():
|
||||
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
|
||||
# build a module with 2 Linear, 4 parameters in total.
|
||||
class Net(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
|
||||
self.extra_param = torch.nn.Parameter(torch.randn(2))
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = Net()
|
||||
|
||||
param_cnt = 0
|
||||
for name, p in model.named_parameters():
|
||||
param_cnt += 1
|
||||
assert param_cnt == 5
|
||||
|
||||
for name, colo_p in model.named_parameters():
|
||||
assert colo_p.is_model_data()
|
||||
|
||||
param_cnt = 0
|
||||
for name, p in model.named_parameters(recurse=False):
|
||||
param_cnt += 1
|
||||
assert param_cnt == 1
|
||||
|
||||
param_cnt = 0
|
||||
for p in model.fcs[0].parameters(recurse=False):
|
||||
param_cnt += 1
|
||||
assert param_cnt == 2
|
||||
|
||||
|
||||
def test_colo_optimizer():
|
||||
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
colo_optimizer.zero_grad()
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
# Bcast rank0 data to all processes
|
||||
if criterion:
|
||||
output = model(data)
|
||||
loss = criterion(output, label)
|
||||
else:
|
||||
output = model(data, label)
|
||||
loss = output
|
||||
|
||||
loss.backward()
|
||||
colo_optimizer.step()
|
||||
|
||||
if i > 5:
|
||||
break
|
||||
|
||||
|
||||
def run_1d_row_tp(model_name: str):
|
||||
# A simple net with two stacked nn.Linear
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
|
||||
set_seed(1)
|
||||
if rank == 0:
|
||||
model_torch = model_builder(checkpoint=True)
|
||||
model_torch = model_torch.cuda()
|
||||
|
||||
# A naive way to set spec for all weights in Linear
|
||||
for mo_name, module in model.named_modules():
|
||||
# print(mo_name)
|
||||
for pa_name, param in module.named_parameters(recurse=False):
|
||||
# print('\t', pa_name, param.shape)
|
||||
if not isinstance(param, ColoTensor):
|
||||
continue
|
||||
if 'weight' in pa_name:
|
||||
if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name:
|
||||
split_param_row_tp1d(param, pg)
|
||||
elif 'LayerNorm' not in mo_name and 'ln' not in mo_name:
|
||||
split_param_col_tp1d(param, pg)
|
||||
|
||||
model = model.cuda()
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
|
||||
torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
|
||||
|
||||
# Bcast rank0 data to all processes
|
||||
if criterion:
|
||||
output = model(data)
|
||||
loss = criterion(output, label)
|
||||
else:
|
||||
output = model(data, label)
|
||||
loss = output
|
||||
|
||||
# For reference
|
||||
if rank == 0:
|
||||
if criterion:
|
||||
output_torch = model_torch(data)
|
||||
loss_torch = criterion(output_torch, label)
|
||||
else:
|
||||
output_torch = model_torch(data, label)
|
||||
loss_torch = output_torch
|
||||
assert torch.allclose(loss, loss_torch, rtol=1e-2)
|
||||
torch.distributed.barrier()
|
||||
|
||||
loss.backward()
|
||||
|
||||
if rank == 0:
|
||||
loss_torch.backward()
|
||||
torch.distributed.barrier()
|
||||
|
||||
if i > 5:
|
||||
break
|
||||
|
||||
|
||||
def _run_pretrain_load():
|
||||
from transformers import BertForMaskedLM
|
||||
set_seed(1)
|
||||
model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
||||
|
||||
model_pretrained = model_pretrained.cuda()
|
||||
model = model.cuda()
|
||||
|
||||
dict_pretrained = {}
|
||||
dict_col = {}
|
||||
c_ref = 0
|
||||
for name, param in model_pretrained.named_parameters():
|
||||
dict_pretrained[name] = param
|
||||
c_ref += 1
|
||||
c1 = 0
|
||||
c2 = 0
|
||||
for name, param in model.named_parameters():
|
||||
if isinstance(param, ColoParameter):
|
||||
c1 += 1
|
||||
else:
|
||||
c2 += 1
|
||||
dict_col[name] = param
|
||||
assert c_ref == c1
|
||||
assert c2 == 0
|
||||
if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
|
||||
assert model.cls.predictions.decoder.bias is model.cls.predictions.bias
|
||||
|
||||
for name, param in dict_pretrained.items():
|
||||
check_equal(param, dict_col[name])
|
||||
|
||||
|
||||
def run_model_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
# Comment below test for speed consideration
|
||||
# for name in ['bert', 'simple_net']:
|
||||
# run_1d_row_tp(name)
|
||||
for name in ['bert', 'simple_net']:
|
||||
run_1d_hybrid_tp(name)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_model(world_size):
|
||||
spawn(run_model_dist, world_size)
|
||||
|
||||
|
||||
def run_pretrain_load_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
_run_pretrain_load()
|
||||
|
||||
|
||||
# The test case has to download huggingface pretrained models from the internet
|
||||
# So we manually trigger the test.
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pretrain_load(world_size):
|
||||
spawn(run_pretrain_load_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_model_parameters()
|
||||
# test_colo_optimizer()
|
||||
test_model(4)
|
||||
# test_pretrain_load(4)
|
|
@ -1,227 +0,0 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.parallel.layers import check_colo_module, init_colo_module
|
||||
from colossalai.tensor import (
|
||||
ColoTensor,
|
||||
ColoTensorSpec,
|
||||
ComputePattern,
|
||||
ComputeSpec,
|
||||
ProcessGroup,
|
||||
ReplicaSpec,
|
||||
ShardSpec,
|
||||
distspec,
|
||||
)
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def run_model_with_spec(mode, model_name):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
rank = pg.rank()
|
||||
|
||||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=False)
|
||||
|
||||
if rank == 0:
|
||||
model_seq = model_builder(checkpoint=False)
|
||||
model_seq = model_seq.cuda()
|
||||
|
||||
# Make two models have the same init params
|
||||
for p1, p2 in zip(model.parameters(), model_seq.parameters()):
|
||||
p2.data.copy_(p1.data)
|
||||
|
||||
compute_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
# Not all layers in Bert can be mod by 4.
|
||||
# e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
|
||||
if 'bert' == model_name:
|
||||
if 'col' == mode:
|
||||
init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row')
|
||||
elif 'row' == mode:
|
||||
init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col')
|
||||
init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
elif 'simple_net' == model_name:
|
||||
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
|
||||
model = model.cuda()
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
|
||||
torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
|
||||
|
||||
if criterion:
|
||||
output = model(data)
|
||||
loss = criterion(output, label)
|
||||
else:
|
||||
output = model(data, label)
|
||||
loss = output
|
||||
|
||||
# For reference
|
||||
if rank == 0:
|
||||
if criterion:
|
||||
output_seq = model_seq(data)
|
||||
loss_seq = criterion(output_seq, label)
|
||||
else:
|
||||
output_seq = model_seq(data, label)
|
||||
loss_seq = output_seq
|
||||
|
||||
if rank == 0:
|
||||
with torch.no_grad():
|
||||
assert torch.allclose(loss, loss_seq, rtol=1e-2)
|
||||
|
||||
loss.backward()
|
||||
|
||||
if rank == 0:
|
||||
loss_seq.backward()
|
||||
|
||||
with torch.no_grad():
|
||||
# check param
|
||||
for p1, p2 in zip(model.parameters(), model_seq.parameters()):
|
||||
if p1.size() == p2.size():
|
||||
assert torch.allclose(p1, p2)
|
||||
else:
|
||||
if p1.size(-1) < p2.size(-1): # col
|
||||
world_size = p2.size(-1) // p1.size(-1)
|
||||
split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
|
||||
|
||||
elif p1.size(0) < p2.size(0): # row
|
||||
world_size = p2.size(0) // p1.size(0)
|
||||
split_p2 = torch.chunk(p2, world_size, dim=0)[0]
|
||||
|
||||
assert torch.allclose(p1, split_p2)
|
||||
|
||||
if i > 3:
|
||||
break
|
||||
|
||||
|
||||
def run_linear_with_spec(mode):
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = torch.nn.Linear(4, 8)
|
||||
|
||||
model_handy = deepcopy(model)
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
compute_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
|
||||
x = torch.rand(2, 4).cuda()
|
||||
colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))
|
||||
|
||||
out = model(x)
|
||||
colo_out = model_handy(colo_x)
|
||||
assert tensor_equal(out, colo_out)
|
||||
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
|
||||
assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_check_shared_param():
|
||||
from transformers import BertConfig, BertForMaskedLM
|
||||
hidden_dim = 8
|
||||
num_head = 4
|
||||
sequence_length = 12
|
||||
num_layer = 2
|
||||
vocab_size = 24
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
rank = pg.rank()
|
||||
|
||||
config = BertConfig(vocab_size=vocab_size,
|
||||
hidden_size=hidden_dim,
|
||||
intermediate_size=hidden_dim * 4,
|
||||
num_attention_heads=num_head,
|
||||
max_position_embeddings=sequence_length,
|
||||
num_hidden_layers=num_layer,
|
||||
hidden_dropout_prob=0.,
|
||||
attention_probs_dropout_prob=0.)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = BertForMaskedLM(config)
|
||||
|
||||
model = model.cuda()
|
||||
compute_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
# model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
|
||||
assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
|
||||
# They are all Linear, so both row is allowed. This should pass check.
|
||||
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
|
||||
# This should be detected by check because you can not set weight as row while set bias as col.
|
||||
col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
# TODO(jiaruifang) optimize this line
|
||||
if not model.cls.predictions.bias.has_initialized:
|
||||
model.cls.predictions.bias.pg = pg
|
||||
model.cls.predictions.bias.dist_spec = ReplicaSpec()
|
||||
model.cls.predictions.bias.has_initialized = True
|
||||
model.cls.predictions.bias.set_tensor_spec(*col_spec)
|
||||
try:
|
||||
check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False)
|
||||
except Exception as e:
|
||||
assert 'incorrectly sharded' in str(e)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_linear_with_spec('col')
|
||||
run_linear_with_spec('row')
|
||||
|
||||
|
||||
def run_dist_model(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
for model_name in ['simple_net', 'bert']:
|
||||
run_model_with_spec('col', model_name)
|
||||
run_model_with_spec('row', model_name)
|
||||
|
||||
|
||||
def run_dist_check(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_check_shared_param()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@pytest.mark.skip("for higher testing speed")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_module_linear_1d(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@pytest.mark.skip("for higher testing speed")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_module_model(world_size):
|
||||
spawn(run_dist_model, world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@pytest.mark.skip("for higher testing speed")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_module_check(world_size):
|
||||
spawn(run_dist_check, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_module_linear_1d(4)
|
|
@ -1,41 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
|
||||
from tests.test_tensor.common_utils import tensor_shard_equal
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, dp_degree, tp_degree):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
|
||||
x = torch.randn(4, 4)
|
||||
param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
|
||||
spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
|
||||
param.set_tensor_spec(*spec)
|
||||
|
||||
gather_tensor(param)
|
||||
if dist.get_rank() == 0:
|
||||
assert torch.all(x == param)
|
||||
else:
|
||||
assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
|
||||
dist.barrier()
|
||||
|
||||
scatter_tensor(param, spec[0])
|
||||
assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
|
||||
assert param.requires_grad is True
|
||||
dist.barrier()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_checkpoint(world_size):
|
||||
spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_checkpoint(world_size=4)
|
|
@ -1,64 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import (
|
||||
ColoParameter,
|
||||
ColoTensorSpec,
|
||||
ComputePattern,
|
||||
ComputeSpec,
|
||||
ProcessGroup,
|
||||
ReplicaSpec,
|
||||
ShardSpec,
|
||||
)
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
|
||||
|
||||
def run_colo_init_context(rank: int, world_size: int, port: int):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# make sure seed of each process is the same, so the params are consistent among processes and the params are exactly replicated.
|
||||
set_seed(42)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
# keep parameters replicated during init
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model1 = model_builder()
|
||||
|
||||
# shard the parameters during init
|
||||
set_seed(42)
|
||||
shard_spec = ReplicaSpec()
|
||||
|
||||
# If using ShardSpec, the assertations will failed.
|
||||
# But it is not a bug, the initialized values are not consist with the original one.
|
||||
# shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
|
||||
default_pg = ProcessGroup(tp_degree=world_size)
|
||||
with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec):
|
||||
model2 = model_builder()
|
||||
|
||||
# reshard both models
|
||||
new_shard = ShardSpec(dims=[-1], num_partitions=[world_size])
|
||||
for p1, p2 in zip(model1.parameters(), model2.parameters()):
|
||||
p1: ColoParameter = p1
|
||||
p1.set_process_group(ProcessGroup(tp_degree=world_size))
|
||||
p1.set_dist_spec(new_shard)
|
||||
p2.set_dist_spec(new_shard)
|
||||
|
||||
for p1, p2 in zip(model1.parameters(), model2.parameters()):
|
||||
assert (torch.allclose(p1, p2))
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_colo_init_context(world_size):
|
||||
spawn(run_colo_init_context, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_colo_init_context(2)
|
|
@ -1,232 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.nn._ops._utils import gather_forward_split_backward
|
||||
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# create mlp vars
|
||||
x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
|
||||
w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
|
||||
b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()
|
||||
|
||||
# run normal forward
|
||||
out = F.linear(x, w, b)
|
||||
|
||||
# create mesh meta
|
||||
# the mesh is in the following topo
|
||||
# [[0, 1],
|
||||
# [2, 3]]
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
row_id = rank // 2
|
||||
column_id = rank % 2
|
||||
|
||||
# create pg
|
||||
row_process_group = None
|
||||
col_process_group = None
|
||||
row_to_ranks = {0: [0, 1], 1: [2, 3]}
|
||||
col_to_ranks = {0: [0, 2], 1: [1, 3]}
|
||||
|
||||
for idx in range(2):
|
||||
# row ranks
|
||||
row_ranks = row_to_ranks[idx]
|
||||
row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2)
|
||||
|
||||
# col ranks
|
||||
col_ranks = col_to_ranks[idx]
|
||||
col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2)
|
||||
|
||||
if rank in row_ranks:
|
||||
row_process_group = row_pg
|
||||
|
||||
if rank in col_ranks:
|
||||
col_process_group = col_pg
|
||||
|
||||
########################
|
||||
# RRR x RS0 -> RRS0 #
|
||||
########################
|
||||
# w will be transposed in F.linear
|
||||
x_replica = x.detach().clone()
|
||||
w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id]
|
||||
b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id]
|
||||
|
||||
# adding sharding spec
|
||||
x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={})
|
||||
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]})
|
||||
b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]})
|
||||
|
||||
# check sharding spec
|
||||
assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]"
|
||||
assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]"
|
||||
assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]"
|
||||
|
||||
w_shard.pg_axis0 = col_process_group
|
||||
w_shard.pg_axis1 = row_process_group
|
||||
|
||||
out_shard = F.linear(x_replica, w_shard, b_shard)
|
||||
assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
|
||||
|
||||
# each row only has a mini-batch
|
||||
expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id]
|
||||
assert torch.allclose(out_shard, expected_out_shard)
|
||||
|
||||
########################
|
||||
# S0RR x RS1 -> S0RS1 #
|
||||
########################
|
||||
# w will be transposed in F.linear
|
||||
x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id]
|
||||
w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id]
|
||||
b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id]
|
||||
|
||||
# adding sharding spec
|
||||
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]})
|
||||
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]})
|
||||
b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})
|
||||
|
||||
# check sharding spec
|
||||
assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]"
|
||||
assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]"
|
||||
assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"
|
||||
|
||||
w_shard.pg_axis0 = col_process_group
|
||||
w_shard.pg_axis1 = row_process_group
|
||||
|
||||
out_shard = F.linear(x_shard, w_shard, b_shard)
|
||||
|
||||
# each row only has a mini-batch
|
||||
expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
|
||||
expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id]
|
||||
assert torch.allclose(out_shard, expected_out_shard)
|
||||
|
||||
########################
|
||||
# S0RS1 x S1R -> S0RR #
|
||||
########################
|
||||
# w will be transposed in F.linear
|
||||
x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id]
|
||||
x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
|
||||
w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
|
||||
b_replica = b.clone()
|
||||
|
||||
# adding sharding spec
|
||||
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]})
|
||||
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
|
||||
b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
|
||||
|
||||
# check sharding spec
|
||||
assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]"
|
||||
assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
|
||||
assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
|
||||
|
||||
w_shard.pg_axis0 = col_process_group
|
||||
w_shard.pg_axis1 = row_process_group
|
||||
|
||||
out_shard = F.linear(x_shard, w_shard, b_replica)
|
||||
|
||||
# each row only has a mini-batch
|
||||
expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
|
||||
assert torch.allclose(out_shard, expected_out_shard)
|
||||
|
||||
########################
|
||||
# RRS0 x S0R -> RRR #
|
||||
########################
|
||||
# w will be transposed in F.linear
|
||||
x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
|
||||
w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
|
||||
b_replica = b.clone()
|
||||
|
||||
# adding sharding spec
|
||||
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
|
||||
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]})
|
||||
b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
|
||||
|
||||
# check sharding spec
|
||||
assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
|
||||
assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]"
|
||||
assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
|
||||
|
||||
w_shard.pg_axis0 = col_process_group
|
||||
w_shard.pg_axis1 = row_process_group
|
||||
|
||||
out_shard = F.linear(x_shard, w_shard, b_replica)
|
||||
|
||||
# each row only has a mini-batch
|
||||
expected_out_shard = out
|
||||
assert torch.allclose(out_shard, expected_out_shard)
|
||||
|
||||
########################
|
||||
# RS0S1 x S1R -> RS0R #
|
||||
########################
|
||||
# w will be transposed in F.linear
|
||||
x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id]
|
||||
x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
|
||||
w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
|
||||
b_replica = b.clone()
|
||||
|
||||
# adding sharding spec
|
||||
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]})
|
||||
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
|
||||
b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
|
||||
|
||||
# check sharding spec
|
||||
assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]"
|
||||
assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
|
||||
assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
|
||||
|
||||
w_shard.pg_axis0 = col_process_group
|
||||
w_shard.pg_axis1 = row_process_group
|
||||
|
||||
out_shard = F.linear(x_shard, w_shard, b_replica)
|
||||
|
||||
# each row only has a mini-batch
|
||||
expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id]
|
||||
assert torch.allclose(out_shard, expected_out_shard)
|
||||
|
||||
########################
|
||||
# RRS0 x S0S1 -> RRS1 #
|
||||
########################
|
||||
# w will be transposed in F.linear
|
||||
x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
|
||||
w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
|
||||
w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id]
|
||||
b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id]
|
||||
|
||||
# adding sharding spec
|
||||
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
|
||||
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]})
|
||||
b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})
|
||||
|
||||
# check sharding spec
|
||||
assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
|
||||
assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]"
|
||||
assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"
|
||||
|
||||
w_shard.pg_axis0 = col_process_group
|
||||
w_shard.pg_axis1 = row_process_group
|
||||
|
||||
out_shard = F.linear(x_shard, w_shard, b_shard)
|
||||
|
||||
# each row only has a mini-batch
|
||||
expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id]
|
||||
assert torch.allclose(out_shard, expected_out_shard)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_sharded_mlp(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sharded_mlp(4)
|
|
@ -1,143 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import colossalai
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP
|
||||
from colossalai.zero.gemini import search_chunk_configuration
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
|
||||
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
|
||||
|
||||
|
||||
def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
torch_dict = torch_model.state_dict()
|
||||
|
||||
for key, value in torch_dict.items():
|
||||
# key is 'module.model.PARAMETER', so we truncate it
|
||||
key = key[7:]
|
||||
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
|
||||
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
|
||||
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
|
||||
assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
|
||||
"parameter '{}' has problem.".format(key)
|
||||
|
||||
|
||||
def run_fwd_bwd(model, criterion, optimizer, input_ids):
|
||||
optimizer.zero_grad()
|
||||
logits = model(input_ids)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
optimizer.backward(loss)
|
||||
return logits
|
||||
|
||||
|
||||
def init_1d_row_spec(model, pg: ProcessGroup):
|
||||
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
for n, p in model.named_parameters():
|
||||
p.set_process_group(pg)
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def init_1d_col_spec(model, pg: ProcessGroup):
|
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
for n, p in model.named_parameters():
|
||||
p.set_process_group(pg)
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
p.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
def run_gpt(placement_policy, tp_init_spec_func=None):
|
||||
set_seed(42)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = model.cuda()
|
||||
torch_model = model_builder().cuda()
|
||||
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
# world size, dp = 2, tp =2, construct a hybrid parallelism.
|
||||
if world_size == 4:
|
||||
pg = ProcessGroup(tp_degree=2)
|
||||
else:
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
|
||||
if tp_init_spec_func:
|
||||
tp_init_spec_func(model, pg)
|
||||
|
||||
dp_world_size = pg.dp_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[dp_world_size]['chunk_size'] = 5000
|
||||
config_dict[dp_world_size]['keep_gathered'] = False
|
||||
if placement_policy != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
|
||||
model = GeminiDDP(model, init_device, placement_policy, True, False)
|
||||
# The same as the following 3 lines
|
||||
# chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
# gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
# model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
|
||||
# The same as the following 2 lines
|
||||
# optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
# zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
|
||||
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
|
||||
|
||||
check_param(model, torch_model, pg)
|
||||
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
|
||||
set_seed(pg.dp_local_rank())
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
|
||||
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo)
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
|
||||
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
|
||||
|
||||
zero_optim.step()
|
||||
torch_optim.step()
|
||||
check_param(model, torch_model, pg)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
if world_size == 4:
|
||||
run_gpt(tp_init_spec_func=init_megatron_spec)
|
||||
else:
|
||||
run_gpt(tp_init_spec_func=init_1d_col_spec)
|
||||
run_gpt(tp_init_spec_func=init_1d_row_spec)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gpt(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt(4)
|
|
@ -1,206 +0,0 @@
|
|||
import os
|
||||
import shutil
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
|
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
weight.set_process_group(pg)
|
||||
weight.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def init_1d_col_linear(weight, pg):
|
||||
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
weight.set_process_group(pg)
|
||||
weight.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def init_1d_row_embedding(weight, pg):
|
||||
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
weight.set_process_group(pg)
|
||||
weight.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def init_1d_col_embedding(weight, pg):
|
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
weight.set_process_group(pg)
|
||||
weight.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
|
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
for name, p in model.named_parameters():
|
||||
if not isinstance(p, ColoTensor):
|
||||
continue
|
||||
if 'embed' in name and 'weight' in name:
|
||||
init_1d_col_embedding(p, pg)
|
||||
if 'proj1' in name and ('weight' in name or 'bias' in name):
|
||||
init_1d_col_linear(p, pg)
|
||||
if 'proj2' in name and 'weight' in name:
|
||||
init_1d_row_linear(p, pg)
|
||||
if 'classifier' in name and ('weight' in name or 'bias' in name):
|
||||
init_1d_col_linear(p, pg)
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
|
||||
assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)
|
||||
|
||||
|
||||
def remove(path):
|
||||
""" param <path> could either be relative or absolute. """
|
||||
if os.path.isfile(path) or os.path.islink(path):
|
||||
os.remove(path)
|
||||
elif os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
raise ValueError("file {} is not a file or dir.".format(path))
|
||||
|
||||
|
||||
def compare_optims(optim1, optim2):
|
||||
state1 = optim1.state_dict()['state']
|
||||
state2 = optim2.state_dict()['state']
|
||||
for k, p1 in state1.items():
|
||||
if k not in state2:
|
||||
continue
|
||||
p2 = state2[k]
|
||||
for n, t1 in p1.items():
|
||||
if n not in p2:
|
||||
continue
|
||||
t2 = p2[n]
|
||||
if isinstance(t1, ColoTensor):
|
||||
assert isinstance(t2, ColoTensor)
|
||||
assert torch.allclose(t1, t2, rtol=0, atol=0)
|
||||
|
||||
|
||||
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
# set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
if use_mp_reload:
|
||||
if 'bert' == model_name:
|
||||
for name, p in model.named_parameters():
|
||||
if not isinstance(p, ColoTensor):
|
||||
continue
|
||||
# num_class = type_vocab_size = 2 | (8, 2)
|
||||
if 'classifier' in name and 'weight' in name:
|
||||
init_1d_row_linear(p, pg)
|
||||
# num_class = vocab_size = 30524 | (30524, 8)
|
||||
elif 'word_embeddings' in name and 'weight' in name:
|
||||
init_1d_row_embedding(p, pg)
|
||||
# num_class = seq_len = 512 | (512, 8)
|
||||
elif 'position_embeddings' in name and 'weight' in name:
|
||||
init_1d_row_embedding(p, pg)
|
||||
# num_class = type_vocab_size = 2 | (2, 8)
|
||||
elif 'token_type_embeddings' in name and 'weight' in name:
|
||||
init_1d_col_embedding(p, pg)
|
||||
elif p.process_group.tp_world_size() == 1:
|
||||
p.set_process_group(pg)
|
||||
elif "simple_net" == model_name:
|
||||
init_spec_func(model, pg)
|
||||
|
||||
model_reload = deepcopy(model)
|
||||
model = model.cuda()
|
||||
model.eval()
|
||||
|
||||
model_reload = model_reload.cuda()
|
||||
model_reload.eval()
|
||||
|
||||
opt_class = torch.optim.Adam
|
||||
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
|
||||
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
|
||||
# Zero grad
|
||||
colo_optimizer.zero_grad()
|
||||
colo_optimizer_reload.zero_grad()
|
||||
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
|
||||
dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())
|
||||
|
||||
# Bcast rank0 data to all processes
|
||||
if criterion:
|
||||
output = model(data)
|
||||
output_reload = model_reload(data)
|
||||
loss = criterion(output, label)
|
||||
loss_reload = criterion(output_reload, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss_reload = model_reload(data, label)
|
||||
|
||||
loss.backward()
|
||||
loss_reload.backward()
|
||||
|
||||
colo_optimizer.step()
|
||||
colo_optimizer_reload.step()
|
||||
|
||||
if i > 2:
|
||||
break
|
||||
|
||||
if not os.path.isdir('./checkpoint') and rank == 0:
|
||||
os.mkdir('./checkpoint')
|
||||
dist.barrier()
|
||||
|
||||
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
|
||||
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
|
||||
|
||||
check_param_equal(model, model_reload)
|
||||
compare_optims(colo_optimizer, colo_optimizer_reload)
|
||||
|
||||
if rank == 0:
|
||||
remove('./checkpoint')
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
|
||||
# the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
|
||||
for model_name in ['bert']:
|
||||
_run_checkpoint(model_name,
|
||||
init_1d_row_for_linear_weight_spec,
|
||||
use_ddp,
|
||||
use_mp_reload,
|
||||
test_scheduler=test_scheduler,
|
||||
pg=pg)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@pytest.mark.parametrize('use_ddp', [False])
|
||||
@pytest.mark.parametrize('use_mp_reload', [True, False])
|
||||
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
|
||||
spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")
|
|
@ -66,6 +66,7 @@ def run_dist(rank, world_size, port):
|
|||
run_grad_clip_norm(world_size=world_size)
|
||||
|
||||
|
||||
@pytest.mark.skip("this need to be updated")
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import pytest
|
||||
import torch
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.zero.gemini.chunk import ChunkManager
|
||||
from tests.test_tensor.common_utils import debug_print
|
||||
|
@ -15,19 +16,18 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
|
|||
@parameterize('keep_gathered', [True, False])
|
||||
@parameterize('pin_memory', [True, False])
|
||||
def exam_chunk_memory(keep_gathered, pin_memory):
|
||||
pg = ProcessGroup()
|
||||
|
||||
debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
|
||||
|
||||
params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
|
||||
params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)]
|
||||
config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
|
||||
|
||||
chunk_manager = ChunkManager(config)
|
||||
assert chunk_manager.total_mem['cpu'] == 0
|
||||
assert chunk_manager.total_mem['cuda'] == 0
|
||||
|
||||
process_group = _get_default_group()
|
||||
for p in params:
|
||||
chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory)
|
||||
chunk_manager.register_tensor(p, 'param', 2, process_group, pin_memory=pin_memory)
|
||||
chunk_manager.close_all_groups()
|
||||
assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
|
||||
assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.gemini import TensorState
|
||||
|
@ -36,7 +36,7 @@ def check_equal(param, param_cp):
|
|||
@parameterize('pin_memory', [True, False])
|
||||
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ColoProcessGroup()
|
||||
pg = _get_default_group()
|
||||
my_chunk = Chunk(chunk_size=1024,
|
||||
process_group=pg,
|
||||
dtype=torch.float32,
|
||||
|
|
|
@ -1,23 +1,40 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from tests.components_to_test import run_fwd, run_fwd_bwd
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||
from tests.components_to_test import run_fwd_bwd
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
|
||||
PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 1.0
|
||||
}, # zero3
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.5
|
||||
}, # zero3-half
|
||||
{
|
||||
'placement_policy': 'auto'
|
||||
}
|
||||
]
|
||||
|
||||
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
|
||||
def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
|
||||
chunk_manager = model.chunk_manager
|
||||
param_list = [p for p in model.parameters()]
|
||||
chunk_list = chunk_manager.get_chunks(param_list)
|
||||
|
@ -28,12 +45,12 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
|
|||
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||
@parameterize('placement_config', PLACEMENT_CONFIGS)
|
||||
@parameterize('keep_gather', [False, True])
|
||||
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
|
||||
@parameterize('use_grad_checkpoint', [False, True])
|
||||
def exam_gpt_fwd_bwd(
|
||||
placement_policy,
|
||||
placement_config,
|
||||
keep_gather,
|
||||
model_name: str,
|
||||
use_grad_checkpoint: bool = False,
|
||||
|
@ -43,8 +60,7 @@ def exam_gpt_fwd_bwd(
|
|||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
set_seed(42)
|
||||
with ColoInitContext(device=init_device):
|
||||
model = model_builder(use_grad_checkpoint)
|
||||
model = model_builder(use_grad_checkpoint)
|
||||
|
||||
set_seed(42)
|
||||
torch_model = model_builder(use_grad_checkpoint).cuda()
|
||||
|
@ -55,19 +71,17 @@ def exam_gpt_fwd_bwd(
|
|||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gather
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config)
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
|
||||
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
|
||||
|
||||
pg = ProcessGroup()
|
||||
rank = dist.get_rank()
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
|
||||
torch_model = DDP(torch_model, device_ids=[rank])
|
||||
|
||||
set_seed(pg.dp_local_rank())
|
||||
set_seed(rank)
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
# you can only test a single fwd + bwd.
|
||||
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
|
||||
|
@ -89,65 +103,10 @@ def exam_gpt_fwd_bwd(
|
|||
check_grad(model, torch_model)
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('keep_gather', [False, True])
|
||||
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
|
||||
@parameterize('scatter_after_inference', [False, True])
|
||||
def exam_gpt_inference(
|
||||
placement_policy,
|
||||
keep_gather,
|
||||
model_name: str,
|
||||
scatter_after_inference: bool = False,
|
||||
):
|
||||
init_device = get_current_device()
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
set_seed(42)
|
||||
with ColoInitContext(device=init_device):
|
||||
model = model_builder()
|
||||
|
||||
set_seed(42)
|
||||
torch_model = model_builder().cuda()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gather
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True, scatter_after_inference=scatter_after_inference)
|
||||
|
||||
pg = ProcessGroup()
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
|
||||
|
||||
set_seed(pg.dp_local_rank())
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
# you can only test a single fwd + bwd.
|
||||
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
|
||||
if i > 0:
|
||||
break
|
||||
with torch.no_grad():
|
||||
input_ids, label = input_ids.cuda(), label.cuda()
|
||||
|
||||
torch_loss = run_fwd(torch_model, input_ids, label, criterion)
|
||||
loss = run_fwd(model, input_ids, label, criterion)
|
||||
|
||||
assert torch.equal(torch_loss, loss)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_gpt_fwd_bwd()
|
||||
exam_gpt_inference()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.zero import GeminiDDP
|
||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
|
||||
from tests.components_to_test import run_fwd_bwd
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
@ -24,8 +23,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
|
|||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device='cpu'):
|
||||
model = model_builder(use_grad_checkpoint)
|
||||
model = model_builder(use_grad_checkpoint).cuda()
|
||||
|
||||
print(f'model_name {model_name}')
|
||||
runtime_mem_tracer = RuntimeMemTracer(model)
|
||||
|
@ -59,12 +57,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
|
|||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gather
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model = GeminiDDP(model,
|
||||
chunk_config_dict=config_dict,
|
||||
placement_policy=placement_policy,
|
||||
pin_memory=True,
|
||||
memstats=memstats)
|
||||
|
||||
pg = ProcessGroup()
|
||||
set_seed(pg.dp_local_rank())
|
||||
set_seed(dist.get_rank())
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
# you can only test a single fwd + bwd.
|
||||
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
|
||||
|
@ -76,7 +75,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
|
|||
set_seed(42)
|
||||
loss = run_fwd_bwd(model, input_ids, label, criterion, model)
|
||||
|
||||
gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
|
||||
gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
|
||||
|
||||
# print('gemini non model data:', gemini_non_model_data)
|
||||
|
||||
|
@ -90,6 +89,7 @@ def run_dist(rank, world_size, port):
|
|||
run_gemini_use_rmt()
|
||||
|
||||
|
||||
@pytest.mark.skip("this is not used")
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, GeminiDDP
|
||||
from colossalai.zero.gemini.utils import get_static_torch_model
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2'])
|
||||
def run_convert_torch_module(model_name: str):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, _, _, _, _ = get_components_func()
|
||||
|
||||
with ColoInitContext(device=torch.device("cpu")):
|
||||
model = model_builder(checkpoint=False)
|
||||
model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
|
||||
pytorch_model = get_static_torch_model(model, only_rank_0=False)
|
||||
|
||||
for n, p in pytorch_model.named_parameters():
|
||||
assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}"
|
||||
|
||||
# get the static model should not change the original model
|
||||
for n, p in model.named_parameters():
|
||||
assert isinstance(p, ColoParameter)
|
||||
|
||||
for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()):
|
||||
assert pn == cn
|
||||
assert id(pm) != id(cm)
|
||||
for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)):
|
||||
assert id(pp) != id(cp)
|
||||
assert pp.shape == cp.shape
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_convert_torch_module()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_convert_torch_module(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_convert_torch_module(2)
|
|
@ -8,16 +8,38 @@ import colossalai
|
|||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||
from tests.components_to_test import run_fwd_bwd
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
|
||||
PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.0,
|
||||
'offload_param_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 1.0,
|
||||
'offload_param_frac': 0.0
|
||||
}, # zero2-offload
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.5,
|
||||
'offload_param_frac': 0.0
|
||||
}, # zero2-offload-half
|
||||
{
|
||||
'placement_policy': 'auto'
|
||||
}
|
||||
]
|
||||
|
||||
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
|
||||
def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
torch_dict = torch_model.state_dict()
|
||||
|
||||
|
@ -30,9 +52,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
|||
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||
@parameterize('placement_config', PLACEMENT_CONFIGS)
|
||||
@parameterize('model_name', ['gpt2'])
|
||||
def exam_grad_clipping(placement_policy, model_name: str):
|
||||
def exam_grad_clipping(placement_config, model_name: str):
|
||||
set_seed(1912)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
@ -43,9 +65,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
|
|||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
|
||||
|
||||
init_dev = get_current_device()
|
||||
with ColoInitContext(device=init_dev):
|
||||
model = model_builder()
|
||||
model = model_builder()
|
||||
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
p.data.copy_(torch_p.data)
|
||||
|
@ -54,16 +74,19 @@ def exam_grad_clipping(placement_policy, model_name: str):
|
|||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
if placement_policy != 'cuda':
|
||||
if placement_config['placement_policy'] != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
model = GeminiDDP(model,
|
||||
chunk_config_dict=config_dict,
|
||||
chunk_init_device=init_device,
|
||||
pin_memory=True,
|
||||
**placement_config)
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
|
||||
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
|
||||
|
||||
model.train()
|
||||
torch_model.train()
|
||||
|
|
|
@ -11,15 +11,32 @@ from colossalai.amp import convert_to_apex_amp
|
|||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||
from tests.components_to_test import run_fwd_bwd
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import debug_print, set_seed
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
|
||||
PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 1.0
|
||||
}, # zero3
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.5
|
||||
}, # zero3-half
|
||||
{
|
||||
'placement_policy': 'auto'
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
torch_dict = torch_model.state_dict()
|
||||
|
||||
|
@ -32,35 +49,24 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
|||
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
|
||||
|
||||
|
||||
def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
|
||||
def multi_chunk_init(model: torch.nn.Module, placement_config: dict):
|
||||
world_size = dist.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
if placement_policy != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model = GeminiDDP(model, config_dict, pin_memory=True, **placement_config)
|
||||
return model
|
||||
|
||||
|
||||
def single_chunk_init(model: torch.nn.Module, placement_policy: str):
|
||||
gemini_config = dict(
|
||||
device=get_current_device(),
|
||||
placement_policy=placement_policy,
|
||||
pin_memory=True,
|
||||
)
|
||||
model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config)
|
||||
def single_chunk_init(model: torch.nn.Module, placement_config: dict):
|
||||
model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config)
|
||||
return model
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||
@parameterize('placement_config', PLACEMENT_CONFIGS)
|
||||
@parameterize('model_name', ['gpt2'])
|
||||
@parameterize('model_init_func', [single_chunk_init, multi_chunk_init])
|
||||
def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable):
|
||||
def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
|
||||
set_seed(19360226)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
@ -70,17 +76,15 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
|
|||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
|
||||
|
||||
init_dev = get_current_device()
|
||||
with ColoInitContext(device=init_dev):
|
||||
model = model_builder()
|
||||
model = model_builder().to(init_dev)
|
||||
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
p.data.copy_(torch_p.data)
|
||||
|
||||
model = model_init_func(model, placement_policy)
|
||||
model = model_init_func(model, placement_config)
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
|
||||
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
|
||||
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
|
@ -95,7 +99,7 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
|
|||
torch_optim.zero_grad()
|
||||
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
|
||||
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
|
||||
assert_close(torch_loss, loss)
|
||||
assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5)
|
||||
zero_optim.step()
|
||||
torch_optim.step()
|
||||
check_param(model, torch_model)
|
||||
|
|
|
@ -9,12 +9,46 @@ from colossalai.amp import convert_to_apex_amp
|
|||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||
from tests.components_to_test import run_fwd_bwd
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import debug_print, set_seed
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
|
||||
PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 1.0
|
||||
}, # zero2-offload
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.5
|
||||
}, # zero2-offload-half
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 1.0
|
||||
}, # zero3
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.5
|
||||
}, # zero3-half
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 1.0,
|
||||
'offload_optim_frac': 1.0,
|
||||
'offload_param_frac': 1.0
|
||||
}, # zero3-offload-all
|
||||
{
|
||||
'placement_policy': 'auto'
|
||||
}
|
||||
]
|
||||
|
||||
# this model is large enough to slice to chunks
|
||||
TEST_MODELS = ['gpt2']
|
||||
|
@ -29,7 +63,7 @@ BF16_IGNORED_KEYS = [
|
|||
]
|
||||
|
||||
|
||||
def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
|
||||
def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
|
||||
zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
|
||||
torch_dict = torch_model.state_dict()
|
||||
|
||||
|
@ -51,10 +85,10 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype
|
|||
msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||
@parameterize('placement_config', PLACEMENT_CONFIGS)
|
||||
@parameterize('model_name', TEST_MODELS)
|
||||
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
|
||||
def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
|
||||
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
|
||||
set_seed(42)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
@ -65,9 +99,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
|
|||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
|
||||
|
||||
init_dev = get_current_device()
|
||||
with ColoInitContext(device=init_dev):
|
||||
model = model_builder()
|
||||
model = model_builder().cuda()
|
||||
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
p.data.copy_(torch_p.data)
|
||||
|
@ -76,16 +108,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
|
|||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
if placement_policy != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
|
||||
model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
|
||||
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
|
||||
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
|
@ -109,10 +135,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
|
|||
check_param(model, torch_model, mixed_precision)
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||
@parameterize('placement_config', PLACEMENT_CONFIGS)
|
||||
@parameterize('model_name', EXAMPLE_MODELS)
|
||||
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
|
||||
def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
|
||||
def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
|
||||
set_seed(2008)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
@ -123,18 +149,19 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
|
|||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
|
||||
|
||||
init_dev = get_current_device()
|
||||
with ColoInitContext(device=init_dev):
|
||||
model = model_builder()
|
||||
model = model_builder().cuda()
|
||||
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
p.data.copy_(torch_p.data)
|
||||
|
||||
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
|
||||
model = GeminiDDP(model,
|
||||
chunk_init_device=get_current_device(),
|
||||
search_range_m=1,
|
||||
pin_memory=True,
|
||||
mixed_precision=mixed_precision,
|
||||
**placement_config)
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
|
||||
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2)
|
||||
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
from colossalai.zero import ColoInitContext
|
||||
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
|
||||
from tests.components_to_test import run_fwd_bwd
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
@pytest.mark.skip("this is not used")
|
||||
@clear_cache_before_run()
|
||||
def test_runtime_mem_tracer():
|
||||
test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert']
|
||||
|
@ -18,8 +19,7 @@ def test_runtime_mem_tracer():
|
|||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, _, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device='cpu'):
|
||||
model = model_builder(checkpoint=False)
|
||||
model = model_builder(checkpoint=False).cuda()
|
||||
|
||||
model_bk = deepcopy(model)
|
||||
runtime_mem_tracer = RuntimeMemTracer(model)
|
||||
|
|
|
@ -2,33 +2,20 @@ import pytest
|
|||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
def init_1d_row_spec(model, pg: ProcessGroup):
|
||||
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_process_group(pg)
|
||||
p.set_tensor_spec(*tensor_spec)
|
||||
|
||||
|
||||
def exam_search_chunk_size():
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg_tp = ProcessGroup(tp_degree=world_size)
|
||||
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
# make sure torch_model and model has the same parameter values
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
init_1d_row_spec(model, pg_tp)
|
||||
model = model_builder()
|
||||
config_dict, *_ = search_chunk_configuration(model,
|
||||
search_range_m=1,
|
||||
search_interval=16,
|
||||
|
@ -37,57 +24,19 @@ def exam_search_chunk_size():
|
|||
|
||||
for key in config_dict:
|
||||
chunk_size = config_dict[key]['chunk_size']
|
||||
if world_size == 1:
|
||||
if world_size == 1 or True:
|
||||
assert chunk_size == 31616
|
||||
else:
|
||||
assert chunk_size == 1024
|
||||
|
||||
|
||||
def exam_search_strict_ddp():
|
||||
world_size = torch.distributed.get_world_size()
|
||||
default_shard_pg = ProcessGroup(tp_degree=world_size)
|
||||
default_shard_spec = ShardSpec([-1], [world_size])
|
||||
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
# get the chunk configuration over replicated models
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
ddp_model = model_builder()
|
||||
re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
|
||||
search_range_m=1,
|
||||
search_interval=16,
|
||||
min_chunk_size_m=0,
|
||||
filter_exlarge_params=True,
|
||||
strict_ddp_flag=False)
|
||||
# get the chunk configuration over sharded ddp models
|
||||
with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
|
||||
default_dist_spec=default_shard_spec):
|
||||
sharded_ddp_model = model_builder()
|
||||
sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
|
||||
search_range_m=1,
|
||||
search_interval=16,
|
||||
min_chunk_size_m=0,
|
||||
filter_exlarge_params=True,
|
||||
strict_ddp_flag=True)
|
||||
assert re_dict == sh_dict
|
||||
for key in re_dict:
|
||||
assert re_dict[key] == sh_dict[key]
|
||||
|
||||
assert re_total == sh_total
|
||||
assert re_wasted == sh_wasted
|
||||
|
||||
|
||||
def exam_chunk_manager():
|
||||
world_size = torch.distributed.get_world_size()
|
||||
default_shard_pg = ProcessGroup(tp_degree=world_size)
|
||||
default_shard_spec = ShardSpec([-1], [world_size])
|
||||
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
|
||||
default_dist_spec=default_shard_spec):
|
||||
sharded_ddp_model = model_builder()
|
||||
sharded_ddp_model = model_builder()
|
||||
chunk_manager = init_chunk_manager(sharded_ddp_model,
|
||||
get_current_device(),
|
||||
hidden_dim=16,
|
||||
|
@ -103,7 +52,6 @@ def exam_chunk_manager():
|
|||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_search_chunk_size()
|
||||
exam_search_strict_ddp()
|
||||
exam_chunk_manager()
|
||||
|
||||
|
||||
|
|
|
@ -4,31 +4,46 @@ from torch.testing import assert_close
|
|||
|
||||
import colossalai
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.zero import GeminiDDP
|
||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import debug_print, set_seed
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
|
||||
PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 1.0
|
||||
}, # zero3
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.5
|
||||
}, # zero3-half
|
||||
{
|
||||
'placement_policy': 'auto'
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def ignore_the_first_parameter(model: torch.nn.Module):
|
||||
for name, param in model.named_parameters():
|
||||
print(f"parameter `{name}` is set ignored")
|
||||
ZeroDDP.set_params_to_ignore([param])
|
||||
GeminiDDP.set_params_to_ignore([param])
|
||||
return
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
@parameterize('placement_config', PLACEMENT_CONFIGS)
|
||||
@parameterize('keep_gathered', [True, False])
|
||||
@parameterize('model_name', ['gpt2', 'bert'])
|
||||
def exam_state_dict(placement_policy, keep_gathered, model_name: str):
|
||||
def exam_state_dict(placement_config, keep_gathered, model_name: str):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = model_builder()
|
||||
|
||||
torch_model = model_builder()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
|
@ -38,9 +53,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
|
|||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
|
||||
model.train()
|
||||
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
|
@ -52,16 +65,15 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
|
|||
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
@parameterize('placement_config', PLACEMENT_CONFIGS)
|
||||
@parameterize('keep_gathered', [True, False])
|
||||
@parameterize('model_name', ['gpt2', 'bert'])
|
||||
def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
|
||||
def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = model_builder()
|
||||
|
||||
set_seed(451)
|
||||
torch_model = model_builder() # get a different model
|
||||
|
@ -71,13 +83,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
|
|||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
|
||||
if placement_policy != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
|
||||
|
||||
torch_dict = torch_model.state_dict()
|
||||
model.load_state_dict(torch_dict, strict=False)
|
||||
|
@ -89,11 +95,37 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
|
|||
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
|
||||
|
||||
|
||||
@parameterize('placement_config', PLACEMENT_CONFIGS)
|
||||
@parameterize('model_name', ['gpt2', 'bert'])
|
||||
def exam_state_dict_shard(placement_config, model_name: str):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
model = model_builder()
|
||||
|
||||
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
|
||||
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
model = GeminiDDP(model, config_dict, **placement_config)
|
||||
model.train()
|
||||
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
accumulated_keys = set()
|
||||
# ensure number of shards > 1
|
||||
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
|
||||
for key, value in shard.items():
|
||||
assert key not in accumulated_keys, f"key `{key}` is duplicated."
|
||||
accumulated_keys.add(key)
|
||||
assert key in zero_dict, f"{key} not in ZeRO dictionary."
|
||||
assert torch.equal(value, zero_dict[key]), f"{key} not equal."
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_state_dict()
|
||||
exam_load_state_dict()
|
||||
exam_state_dict_shard()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('model_name', ['gpt2', 'bert'])
|
||||
def exam_state_dict(placement_policy, model_name: str):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
|
||||
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
|
||||
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
model.train()
|
||||
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
accumulated_keys = set()
|
||||
# ensure number of shards > 1
|
||||
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
|
||||
for key, value in shard.items():
|
||||
assert key not in accumulated_keys, f"key `{key}` is duplicated."
|
||||
accumulated_keys.add(key)
|
||||
assert key in zero_dict, f"{key} not in ZeRO dictionary."
|
||||
assert torch.equal(value, zero_dict[key]), f"{key} not equal."
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_state_dict()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_ddp_state_dict_shard(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_ddp_state_dict_shard(1)
|
|
@ -5,42 +5,53 @@ import torch.distributed as dist
|
|||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import debug_print, set_seed
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
|
||||
PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 1.0
|
||||
}, # zero2-offload
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.5
|
||||
}, # zero2-offload-half
|
||||
{
|
||||
'placement_policy': 'auto'
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
@parameterize('placement_config', PLACEMENT_CONFIGS)
|
||||
@parameterize('keep_gathered', [True, False])
|
||||
def exam_zero_optim_state_dict(placement_policy, keep_gathered):
|
||||
def exam_zero_optim_state_dict(placement_config, keep_gathered):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = model_builder()
|
||||
|
||||
set_seed(451)
|
||||
torch_model = model_builder() # get a different model
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
|
||||
if placement_policy != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
|
||||
|
||||
optimizer = HybridAdam(model.parameters())
|
||||
optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
|
||||
optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
|
||||
|
||||
set_seed(dist.get_rank() * 3 + 128)
|
||||
model.train()
|
||||
|
|
|
@ -1,55 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.testing import spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer
|
||||
|
||||
|
||||
class MlpModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MlpModel, self).__init__()
|
||||
self.linear1 = nn.Linear(128, 256)
|
||||
self.linear2 = nn.Linear(256, 512)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
def exam_zero_init():
|
||||
dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2)
|
||||
model1 = MlpModel().cuda()
|
||||
with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg):
|
||||
model2 = MlpModel()
|
||||
optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1))
|
||||
optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1))
|
||||
|
||||
assert optimizer1._local_rank == optimizer2._local_rank
|
||||
assert optimizer1._world_size == optimizer2._world_size
|
||||
|
||||
mp_group1 = optimizer1.tp_pg
|
||||
mp_group2 = optimizer2.tp_pg
|
||||
assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
|
||||
assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config_dict = dict(parallel=dict(data=2, tensor=dict(size=2, mode='1d')))
|
||||
colossalai.launch(config=config_dict, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
exam_zero_init()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_zero_init():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_init()
|
|
@ -85,6 +85,7 @@ def run_dist(rank, world_size, port):
|
|||
exam_zero_with_tp()
|
||||
|
||||
|
||||
@pytest.mark.skip('this will be rewritten by shardformer')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_with_tp():
|
||||
|
|
Loading…
Reference in New Issue