[gemini] improve compatibility and add static placement policy (#4479)

* [gemini] remove distributed-related part from colotensor (#4379)

* [gemini] remove process group dependency

* [gemini] remove tp part from colo tensor

* [gemini] patch inplace op

* [gemini] fix param op hook and update tests

* [test] remove useless tests

* [test] remove useless tests

* [misc] fix requirements

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [misc] update requirements

* [gemini] refactor gemini optimizer and gemini ddp (#4398)

* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example

* [gemini] add static placement policy (#4443)

* [gemini] add static placement policy

* [gemini] fix param offload

* [test] update gemini tests

* [plugin] update gemini plugin

* [plugin] update gemini plugin docstr

* [misc] fix flash attn requirement

* [test] fix gemini checkpoint io test

* [example] update resnet example result (#4457)

* [example] update bert example result (#4458)

* [doc] update gemini doc (#4468)

* [example] update gemini related examples (#4473)

* [example] update gpt example

* [example] update dreambooth example

* [example] update vit

* [example] update opt

* [example] update palm

* [example] update vit and opt benchmark

* [hotfix] fix bert in model zoo (#4480)

* [hotfix] fix bert in model zoo

* [test] remove chatglm gemini test

* [test] remove sam gemini test

* [test] remove vit gemini test

* [hotfix] fix opt tutorial example (#4497)

* [hotfix] fix opt tutorial example

* [hotfix] fix opt tutorial example
pull/4504/head
Hongxin Liu 2023-08-24 09:29:25 +08:00 committed by GitHub
parent 285fe7ba71
commit 27061426f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
82 changed files with 1008 additions and 4036 deletions

View File

@ -1,13 +1,11 @@
import gc
import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
@ -16,7 +14,6 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralC
from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
load_shard_state_dict,
save_state_dict,
save_state_dict_shards,
@ -24,8 +21,7 @@ from colossalai.checkpoint_io.utils import (
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini import ZeroOptimizer
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
from .dp_plugin_base import DPPluginBase
@ -132,11 +128,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
As there is communication when getting state dict, this must be called on all processes.
"""
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.unwrap()
assert isinstance(optimizer, ZeroOptimizer)
assert isinstance(optimizer, GeminiOptimizer)
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@ -183,11 +175,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.unwrap()
assert isinstance(optimizer, ZeroOptimizer)
assert isinstance(optimizer, GeminiOptimizer)
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
@ -220,47 +208,6 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
super().save_lr_scheduler(lr_scheduler, checkpoint)
class GeminiModel(ModelWrapper):
def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
super().__init__(module)
self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)
def unwrap(self):
# as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
return self.module
class GeminiOptimizer(OptimizerWrapper):
def __init__(self,
module: GeminiDDP,
optimizer: Optimizer,
zero_optim_config: dict,
optim_kwargs: dict,
verbose: bool = False) -> None:
optimizer = zero_optim_wrapper(module,
optimizer,
optim_config=zero_optim_config,
**optim_kwargs,
verbose=verbose)
super().__init__(optimizer)
def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')
class GeminiPlugin(DPPluginBase):
"""
Plugin for Gemini.
@ -277,8 +224,20 @@ class GeminiPlugin(DPPluginBase):
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
Args:
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
chunk_config_dict (dict, optional): chunk configuration dictionary.
chunk_init_device (torch.device, optional): device to initialize the chunk.
placement_policy (str, optional): "static" or "auto". Defaults to "static".
shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
When using static placement, we recommend that users tune `shard_param_frac` first and then `offload_optim_frac`.
Defaults to 0.0.
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
@ -310,8 +269,14 @@ class GeminiPlugin(DPPluginBase):
def __init__(
self,
device: Optional[torch.device] = None,
placement_policy: str = "cpu",
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
@ -335,8 +300,14 @@ class GeminiPlugin(DPPluginBase):
super().__init__()
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
self.gemini_config = dict(
device=(device or get_current_device()),
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=strict_ddp_mode,
@ -393,12 +364,15 @@ class GeminiPlugin(DPPluginBase):
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
# wrap the model with Gemini
model = GeminiModel(model, self.gemini_config, self.verbose)
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)
optimizer = GeminiOptimizer(optimizer,
model.unwrap(),
**self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler

View File

@ -3,9 +3,15 @@ from typing import Optional
import torch
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.tensor.tensor_spec import ColoTensorSpec
from .colo_tensor import _convert_output
# Dunder functions that are exempt from the "no hook" rule below.
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}


def is_no_hook_op(func) -> bool:
    """Tell whether ``func`` is a dunder function that is not white-listed.

    ``ColoParameter.__torch_function__`` only runs the parameter op hooks
    for functions where this returns ``False``.

    Args:
        func: the callable received by ``__torch_function__``; it must
            expose a ``__name__`` attribute.

    Returns:
        bool: ``True`` if ``func.__name__`` starts with ``'__'`` and ``func``
        is not a member of ``WHITE_LIST_FUNCS``; ``False`` otherwise.
    """
    # The name check comes first so the white-list lookup only happens for dunders.
    if not func.__name__.startswith('__'):
        return False
    return func not in WHITE_LIST_FUNCS
def filter_colo_parameters(*args, **kwargs):
@ -41,53 +47,25 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
"""
def __new__(cls,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: ColoTensorSpec = None) -> 'ColoParameter':
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
def __init__(self,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: ColoTensorSpec = None) -> None:
ColoTensor.__init__(self, data, spec)
self._type = TensorType.MODEL
# a list contains modules sharing this ColoParameter with others.
self._shared_param_modules = []
@property
def shared_param_modules(self):
return self._shared_param_modules
@staticmethod
def from_torch_tensor(tensor: torch.Tensor,
requires_grad: bool = True,
spec: ColoTensorSpec = None) -> 'ColoParameter':
tensor = tensor.as_subclass(ColoParameter)
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
return tensor
def __repr__(self):
return super(ColoParameter, self).__repr__()
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
if ColoParamOpHookManager.has_hook():
if not func.__name__.startswith('__'):
if kwargs is None:
kwargs = {}
params = filter_colo_parameters(*args, **kwargs)
if len(params) > 0:
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)
return ret
if kwargs is None:
kwargs = {}
if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
params = filter_colo_parameters(*args, **kwargs)
if len(params) > 0:
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)
return _convert_output(ret, func)
return super().__torch_function__(func, types, args, kwargs)
def __deepcopy__(self, memo):
@ -96,9 +74,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoParameter(data,
self.requires_grad,
spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
tensor = ColoParameter(data, self.requires_grad)
memo[id(self)] = tensor
return tensor

View File

@ -1,17 +1,14 @@
import operator
from copy import copy
from functools import lru_cache, reduce
from typing import Callable, Optional, Set
from functools import lru_cache
from typing import Callable, Set
import torch
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
from colossalai.tensor.tensor_spec import ColoTensorSpec
from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS
INPALCE_MAPPING = {
torch.Tensor.add_: torch.Tensor.add,
torch.Tensor.sub_: torch.Tensor.sub,
torch.Tensor.mul_: torch.Tensor.mul,
torch.Tensor.div_: torch.Tensor.div
}
@lru_cache(None)
@ -25,61 +22,37 @@ def _get_my_nowrap_functions() -> Set[Callable]:
}
def _convert_output(output, colo_spec: ColoTensorSpec):
if type(output) == torch.Tensor:
return ColoTensor.from_torch_tensor(output, colo_spec)
def _convert(output):
if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
output.__class__ = ColoTensor
elif isinstance(output, (list, tuple)):
return type(output)(_convert_output(o, colo_spec) for o in output)
else:
output = type(output)(_convert(o) for o in output)
return output
def _convert_output(output, func):
if func in _get_my_nowrap_functions():
return output
def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
for elem in args:
if isinstance(elem, ColoTensor):
pg = elem.get_process_group()
dp = elem.dist_spec
return ColoTensorSpec(pg, dp)
elif isinstance(elem, (list, tuple)):
spec = _get_spec_from_args(elem, {})
if spec is not None:
return spec
for k, v in kwargs.items():
if isinstance(v, ColoTensor):
pg = v.get_process_group()
dp = v.dist_spec
return ColoTensorSpec(pg, dp)
return None
return _convert(output)
class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
The Colotensor can be initialized with a PyTorch tensor in the following ways.
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
>>> # The tensor passed in is a tensor after sharding but not a global tensor.
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
It is only used to trigger the torch function hook.
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
"""
torch_major = int(torch.__version__.split('.')[0])
torch_minor = int(torch.__version__.split('.')[1])
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
def __new__(cls, data: torch.Tensor) -> 'ColoTensor':
"""
The signature of the __new__ has to be consistent with the torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
spec (TensorSpec, optional): the tensor spec of initialization.
Returns:
ColoTensor: a ColoTensor that wraps the data.
@ -88,86 +61,6 @@ class ColoTensor(torch.Tensor):
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, data.requires_grad)
def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
# If not set spec, use a DP process group and replicate dist spec
if spec is None:
self.has_initialized = False
self.dist_spec = ReplicaSpec()
self.compute_spec = None
self.process_group = ProcessGroup()
else:
self.has_initialized = True
self.dist_spec = spec.dist_attr
self.compute_spec = spec.compute_attr
if spec.pg is None:
self.process_group = ProcessGroup()
else:
self.process_group = spec.pg
self._type = TensorType.NONMODEL
def has_compute_spec(self) -> bool:
return self.compute_spec is not None
def is_model_data(self) -> bool:
return self._type == TensorType.MODEL
def get_process_group(self) -> 'ProcessGroup':
return self.process_group
def set_process_group(self, pg: ProcessGroup):
"""set_process_group
change the pg of the ColoTensor. Note that the valid use cases is limited.
It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica.
Args:
pg (ProcessGroup): target pg
"""
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
# if the new pg is the same as the old pg, just returns
if self.process_group == pg:
return
assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
"Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
assert self.dist_spec.placement.value == 'r', \
"Can not set_process_group on a ColoTensor whose dist spec is not Replica"
self.process_group = pg
def get_tp_world_size(self) -> int:
return self.process_group.tp_world_size()
def get_dp_world_size(self) -> int:
"""get_dp_world_size
get the dp world size of the tensor.
Returns:
int: dp world size
"""
return self.process_group.dp_world_size()
def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec
set dist spec and change the payloads.
Args:
dist_spec (_DistSpec): target dist spec.
"""
assert isinstance(dist_spec, _DistSpec)
assert self.process_group is not None
self._redistribute(dist_spec)
def set_tensor_spec(self, dist_spec, compute_spec):
if dist_spec is not None:
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
self.set_dist_spec(dist_spec)
if compute_spec is not None:
self.compute_spec = compute_spec
def has_compute_pattern(self, compute_pattern):
return self.compute_spec.compute_pattern == compute_pattern
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
@ -175,9 +68,6 @@ class ColoTensor(torch.Tensor):
if not all(issubclass(cls, t) for t in types):
return NotImplemented
global _COLOSSAL_OPS
if func in _COLOSSAL_OPS:
func = _COLOSSAL_OPS[func]
if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
# in order to trigger pre-op hook in the forward of checkpoint module
@ -189,94 +79,16 @@ class ColoTensor(torch.Tensor):
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
return backward_tensor.backward(**tensor_kwargs)
# replace the in-place function
if func in INPALCE_MAPPING:
func = INPALCE_MAPPING[func]
# set the 'inplace' kwargs to False
if 'inplace' in kwargs:
kwargs['inplace'] = False
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in _get_my_nowrap_functions():
return ret
else:
colo_spec = _get_spec_from_args(args, kwargs)
return _convert_output(ret, colo_spec)
def __repr__(self):
output_list = [super(ColoTensor, self).__repr__()]
output_list.append(str(self.process_group))
output_list.append(str(self.dist_spec))
if self.compute_spec is not None:
output_list.append(str(self.compute_spec))
return "\n".join(output_list)
def _redistribute(self, dist_spec: _DistSpec) -> None:
"""_redistribute
Note the function will not handle the logic of backward propagation!
It is used during model tensor initializations as an internal function.
Args:
dist_spec (_DistSpec): the target dist. spec.
"""
assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
self.dist_spec = dist_spec
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
"""redistribute
Redistribute the tensor among processes. The rule is like this:
1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
DP process group not changed.
2. If the pg is not not None and not equal to the current process group.
First, convert the tensor as replicated among the TP process group.
Second, reset the process group to the new pg.
Third, convert the tensor (new replicated both among the tp process group) to the new dist_spec.
Args:
dist_spec (_DistSpec): the new dist spec.
pg (Optional[ProcessGroup], optional): the new process group . Defaults to None.
Returns:
ColoTensor: a redistributed colotensor
"""
if pg is not None and pg != self.get_process_group():
# if the pg is not equal, convert the current tensor to replicated
handled = self.redistribute(ReplicaSpec())
else:
handled = self
pg = self.process_group
ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
def to_replicate_(self):
"""to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE
"""
self._redistribute(dist_spec=ReplicaSpec())
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting dist spec of the tensor to ReplicaSpec()
"""
return self.redistribute(ReplicaSpec())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
"""from_torch_tensor
A static method builds a `ColoTensor` from a PyTorch Tensor.
Args:
tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank not a global tensor.
spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.
Returns:
ColoTensor: a ColoTensor
"""
tensor = tensor.as_subclass(ColoTensor)
tensor.__init__(tensor, spec=spec)
return tensor
return _convert_output(ret, func)
def __deepcopy__(self, memo):
if id(self) in memo:
@ -284,60 +96,6 @@ class ColoTensor(torch.Tensor):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
tensor = ColoTensor(data)
memo[id(self)] = tensor
return tensor
# override builtin functions which must use tensor in replicate placement #
def size_local(self, *args) -> torch.Size:
with torch._C.DisableTorchFunction():
return super().size(*args)
def size_global(self, *args) -> torch.Size:
"""size_global
override the torch building size()
the shape passed in must be in a replicate placement.
Returns:
torch.Size: the global tensor shape
"""
if self.is_replicate():
return self.size_local(*args)
spec = self.dist_spec
dims = spec.dims
num_partitions = spec.num_partitions
# import inspect
# print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
size_list = list(self.size_local())
for dim, num_partition in zip(dims, num_partitions):
size_list[dim] *= num_partition
if args == ():
return torch.Size(size_list)
else:
return size_list[args[0]]
def numel_global(self):
"""Returns the number of elements in the tensor when it's replicated.
"""
return reduce(operator.mul, self.size_global(), 1)
# Some API for dist spec check
def is_replicate(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \
or (self.process_group.tp_world_size() == 1)
def is_shard_1dcol(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
def is_shard_1drow(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
def is_sharded(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD

View File

@ -3,9 +3,7 @@ from contextlib import contextmanager
from typing import Any, List, Tuple
import torch
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.tensor_spec import ColoTensorSpec
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
class ColoParamOpHook(ABC):
@ -82,26 +80,18 @@ class ColoParamOpHookManager:
@staticmethod
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
ColoParamOpHookManager._trigger_pre_forward(params)
grad_args, rear_args = _get_grad_args(*args)
colo_info = _get_colo_tensors_info(*grad_args)
rets = PreFwdPostBwd.apply(params, *grad_args)
update_args = _update_colo_tensors(colo_info, *rets)
if rear_args is None:
return update_args
else:
arg_zero = (tuple(update_args),)
return arg_zero + rear_args
# An autograd function can only recognize torch.Tensor arguments, so we have to flatten the inputs.
# If any one of the inputs requires grad, all of the outputs are treated as requiring grad
# and will have a grad_fn even if the corresponding input does not require grad.
# Therefore we extract the tensors that require grad into a flat list and merge them back afterwards.
grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
return _merge_args(new_grad_args, other_args, grad_flags, spec)
@staticmethod
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
ColoParamOpHookManager._trigger_post_forward(params)
colo_info = _get_colo_tensors_info(arg)
ret = PostFwdPreBwd.apply(params, arg)
res = _update_colo_tensors(colo_info, ret)
if len(res) == 1:
return res[0]
else:
return res
return PostFwdPreBwd.apply(params, arg)
@staticmethod
def has_hook() -> bool:
@ -141,57 +131,24 @@ def _is_grad_tensor(obj) -> bool:
return False
def _has_grad_tensor(obj) -> bool:
if isinstance(obj, tuple) or isinstance(obj, list):
for x in obj:
if _has_grad_tensor(x):
return True
return False
elif isinstance(obj, dict):
for x in obj.values():
if _has_grad_tensor(x):
return True
return False
else:
return _is_grad_tensor(obj)
def _get_grad_args(*args):
# if there is no grad tensors, do nothing
if not _has_grad_tensor(args):
return args, None
# returns the identical args if there is a grad tensor
for obj in args:
if _is_grad_tensor(obj):
return args, None
# otherwise, the first argument should be a tuple of grad tensors
# if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
arg_zero = args[0]
if not isinstance(arg_zero, tuple):
raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
check_grad_flag = False
for obj in arg_zero:
check_grad_flag |= _is_grad_tensor(obj)
if not check_grad_flag:
raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
return arg_zero, args[1:]
def _get_colo_tensors_info(*args) -> list:
info = []
for arg in args:
if isinstance(arg, ColoTensor):
info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
flat_args, spec = tree_flatten(args)
grad_args = []
other_args = []
grad_flags = []
for arg in flat_args:
flag = _is_grad_tensor(arg)
grad_flags.append(flag)
if flag:
grad_args.append(arg)
else:
info.append(None)
return info
other_args.append(arg)
assert len(grad_args) > 0
return grad_args, other_args, grad_flags, spec
def _update_colo_tensors(info, *args) -> list:
ret = []
for t_info, arg in zip(info, args):
if t_info is not None:
t_cls, spec = t_info
arg = t_cls.from_torch_tensor(arg, spec=spec)
ret.append(arg)
return ret
def _merge_args(grad_args, other_args, grad_flags, spec):
    """Interleave two flat argument lists and rebuild the original structure.

    Args:
        grad_args: iterable of leaves consumed, in order, at every position
            where the matching flag is truthy.
        other_args: iterable of leaves consumed, in order, at every position
            where the matching flag is falsy.
        grad_flags: one flag per flattened leaf, selecting which of the two
            iterables supplies that leaf.
        spec: tree spec handed to ``tree_unflatten`` together with the
            merged flat list.

    Returns:
        The value returned by ``tree_unflatten(merged_leaves, spec)``.
    """
    pending_grad = iter(grad_args)
    pending_other = iter(other_args)
    merged = []
    for takes_grad in grad_flags:
        # Each flag picks the stream that owns the next leaf in flat order.
        source = pending_grad if takes_grad else pending_other
        merged.append(next(source))
    return tree_unflatten(merged, spec)

View File

@ -2,8 +2,7 @@ from .gemini import (
ColoInitContext,
GeminiAdamOptimizer,
GeminiDDP,
ZeroDDP,
ZeroOptimizer,
GeminiOptimizer,
get_static_torch_model,
post_process_colo_init_ctx,
)
@ -11,6 +10,6 @@ from .low_level import LowLevelZeroOptimizer
from .wrapper import zero_model_wrapper, zero_optim_wrapper
__all__ = [
'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
]

View File

@ -1,11 +1,11 @@
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .gemini_ddp import GeminiDDP, ZeroDDP
from .gemini_ddp import GeminiDDP
from .gemini_mgr import GeminiManager
from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
from .utils import get_static_torch_model
__all__ = [
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP',
'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
]

View File

@ -4,8 +4,8 @@ from typing import Dict, List, Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import get_current_device
@ -55,7 +55,7 @@ class Chunk:
def __init__(self,
chunk_size: int,
process_group: ColoProcessGroup,
process_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
@ -69,7 +69,7 @@ class Chunk:
Args:
chunk_size (int): the number of elements in the chunk
process_group (ColoProcessGroup): the process group of this chunk
process_group (ProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
The default value is None, which is the current GPU
@ -83,7 +83,7 @@ class Chunk:
self.chunk_size = chunk_size
self.utilized_size = 0
self.torch_pg = process_group.dp_process_group()
self.torch_pg = process_group
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
@ -218,7 +218,7 @@ class Chunk:
return False
else:
return self.tensor_state_cnter[TensorState.HOLD] + \
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
@property
def can_reduce(self):

View File

@ -2,8 +2,9 @@ from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device
from .chunk import Chunk, ChunkFullError, TensorState
@ -27,16 +28,17 @@ class ChunkManager:
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
v['init_device'] = self.device
self.chunk_groups: Dict[str, Deque] = dict()
self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
self.accessed_chunks: Set[Chunk] = set()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def register_tensor(self,
tensor: ColoTensor,
tensor: torch.Tensor,
group_type: str,
config_key: int,
process_group: ProcessGroup,
cpu_offload: bool = False,
pin_memory: bool = False) -> None:
"""
@ -51,7 +53,7 @@ class ChunkManager:
pin_memory: whether the chunk is pinned in the cpu memory
"""
assert tensor not in self.tensor_chunk_map
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
assert isinstance(tensor, torch.Tensor), "Please feed Tensor to this ChunkManager"
assert config_key in self.dp_degree_chunk_size_dict
chunk_size = self.dp_degree_chunk_size_dict[config_key]
@ -73,12 +75,12 @@ class ChunkManager:
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
dp_size = tensor.get_dp_world_size()
dp_size = dist.get_world_size(process_group)
chunk_size = chunk_size + (-chunk_size % dp_size)
chunk = Chunk(
chunk_size=chunk_size,
process_group=tensor.process_group,
process_group=process_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
@ -220,7 +222,7 @@ class ChunkManager:
msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
def __get_chunk_group(self, group_name: str) -> Deque:
def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
"""Register a chunk group.
"""
if group_name not in self.chunk_groups:

View File

@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.utils import is_ddp_ignored
@ -59,7 +60,7 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return left + acc
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
def _tensor_numel(local_param: ColoParameter) -> int:
"""_tensor_numel
Get the number of elements of a tensor.
@ -71,15 +72,12 @@ def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
Returns:
int: the number of elements.
"""
if strict_ddp_flag and type(local_param) is ColoParameter:
return local_param.numel_global()
else:
# if local_param is not ColoParameter, we assume it's replicated
return local_param.numel()
# TODO(ver217): support dtensor here
return local_param.numel()
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
@ -97,13 +95,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
# assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if is_ddp_ignored(param):
continue
if strict_ddp_flag or type(param) is not ColoParameter:
# if model is not initialized with ColoInitContext, we assume it's replicated
# TODO(ver217): integrate DTensor
param_key = dist.get_world_size()
else:
param_key = param.process_group.dp_world_size()
param_key = dist.get_world_size(process_group)
if param_key not in params_dict:
params_dict[param_key] = []
@ -119,6 +111,7 @@ def search_chunk_configuration(
min_chunk_size_m: float = 32,
filter_exlarge_params: bool = True,
strict_ddp_flag: bool = False,
process_group: Optional[ProcessGroup] = None,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
"""search_chunk_configuration
@ -149,7 +142,7 @@ def search_chunk_configuration(
min_chunk_size = round(min_chunk_size_m * 1024**2)
assert search_range >= 0
params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
params_dict = classify_params_by_dp_degree(param_order, process_group)
size_lcm = np.lcm.reduce(list(params_dict.keys()))
config_dict: Dict[int, Dict] = dict()
total_param_size = 0
@ -157,7 +150,7 @@ def search_chunk_configuration(
size_dict: Dict[int, List[int]] = dict()
for dp_degree in params_dict:
params_list = params_dict[dp_degree]
size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
size_list = [_tensor_numel(p) for p in params_list]
group_acc_size = sum(size_list)
total_param_size += group_acc_size

View File

@ -2,19 +2,20 @@ import itertools
from collections import OrderedDict
from contextlib import nullcontext
from functools import partial
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.data_parallel import _cast_float, free_storage
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
@ -30,14 +31,13 @@ except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
__all__ = [
'ZeroDDP',
'GeminiDDP',
]
class ZeroDDP(ColoDDP):
"""ZeRO DDP for ColoTensor.
Warning: Nested ZeroDDP is not supported now.
class GeminiDDP(ModelWrapper):
"""ZeRO DDP.
Warning: Nested GeminiDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
@ -54,20 +54,54 @@ class ZeroDDP(ColoDDP):
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
"""
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16) -> None:
def __init__(
self,
module: torch.nn.Module,
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device('cpu'),
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
search_range_m: int = 32, # chunk search options
hidden_dim: Optional[int] = None, # chunk search options
min_chunk_size_m: float = 32, # chunk search options
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None, # genimi memory stats
verbose: bool = False) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
else:
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
search_range_m = 32
self.chunk_manager = init_chunk_manager(model=module,
init_device=chunk_init_device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
verbose=verbose)
self.gemini_manager = GeminiManager(placement_policy,
self.chunk_manager,
memstats,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = list()
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@ -75,6 +109,7 @@ class ZeroDDP(ColoDDP):
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision
self.dp_process_group = process_group or _get_default_group()
self._logger = get_dist_logger()
@ -88,20 +123,67 @@ class ZeroDDP(ColoDDP):
for p in module.parameters():
param_order.append(p)
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
super().__init__(module, process_group=ColoProcessGroup())
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
super().__init__(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
self._cast_buffers()
# register grad hook
for p in module.parameters():
if is_ddp_ignored(p):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
def parameters(self, recurse: bool = True):
    """Expose the parameters of the wrapped module.

    Args:
        recurse (bool): passed through positionally to ``self.module.parameters``.

    Returns:
        Whatever ``self.module.parameters`` returns for the given ``recurse``.
    """
    wrapped = self.module
    return wrapped.parameters(recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True):
    """Expose the ``(name, parameter)`` pairs of the wrapped module.

    Args:
        prefix (str): passed through positionally to ``self.module.named_parameters``.
        recurse (bool): passed through positionally to ``self.module.named_parameters``.

    Returns:
        Whatever ``self.module.named_parameters`` returns for these arguments.
    """
    wrapped = self.module
    return wrapped.named_parameters(prefix, recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True):
    """Expose the ``(name, buffer)`` pairs of the wrapped module.

    Args:
        prefix (str): passed through positionally to ``self.module.named_buffers``.
        recurse (bool): passed through positionally to ``self.module.named_buffers``.

    Returns:
        Whatever ``self.module.named_buffers`` returns for these arguments.
    """
    wrapped = self.module
    return wrapped.named_buffers(prefix, recurse)
def named_children(self):
    """Expose the direct children of the wrapped module.

    Returns:
        Whatever ``self.module.named_children()`` returns.
    """
    wrapped = self.module
    return wrapped.named_children()
def named_modules(self,
                  memo: Optional[Set[torch.nn.Module]] = None,
                  prefix: str = '',
                  remove_duplicate: bool = True):
    """Expose the ``(name, module)`` pairs of the wrapped module.

    All three arguments are passed through positionally, in the same order,
    to ``self.module.named_modules``.

    Args:
        memo (Optional[Set[torch.nn.Module]]): forwarded as the first argument.
        prefix (str): forwarded as the second argument.
        remove_duplicate (bool): forwarded as the third argument.

    Returns:
        Whatever ``self.module.named_modules`` returns for these arguments.
    """
    wrapped = self.module
    return wrapped.named_modules(memo, prefix, remove_duplicate)
@staticmethod
def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
    """Flag the given parameters so that the DDP wrapper skips them.

    Each tensor receives the attribute ``_ddp_to_ignore = True``. The flag has
    to be set before the wrapper is constructed (presumably because
    ``is_ddp_ignored`` is consulted while the wrapper is being initialized).

    Example:
        >>> params_to_ignore = []
        >>> for p in module.parameters():
        >>>     if should_ignore(p):
        >>>         params_to_ignore.append(p)
        >>> GeminiDDP.set_params_to_ignore(params_to_ignore)
        >>> module = GeminiDDP(module)

    Args:
        params_to_ignore (Iterable[torch.Tensor]): the parameters to be ignored.
    """
    for tensor in params_to_ignore:
        setattr(tensor, '_ddp_to_ignore', True)
def unwrap(self):
    """Return the object to use as the "unwrapped" model.

    Saving/loading of the state dict is overridden on this wrapper, so the
    wrapper itself is handed back rather than the inner module.
    """
    return self
def _get_non_persistent_buffers_set(self,
module,
@ -207,7 +289,7 @@ class ZeroDDP(ColoDDP):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
"The most possible reason is that the model is not compatible with GeminiDDP.\n",
f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
@ -227,6 +309,7 @@ class ZeroDDP(ColoDDP):
self._post_backward()
def grad_handle(self, p, grad):
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
@ -533,7 +616,7 @@ class ZeroDDP(ColoDDP):
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.optim_update()
chunk_16.payload.copy_(chunk_32.payload)
for name, buf in persistent_buffers.items():
if buf is not None:
@ -557,17 +640,11 @@ class ZeroDDP(ColoDDP):
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
ddp_pg = ColoProcessGroup()
dp_world_size = dist.get_world_size(self.dp_process_group)
for p in param_order.generate():
self._preprocess_param(p)
assert type(p) is ColoParameter
# gather sharded parameters in the strict ddp mode
if strict_ddp_mode:
if not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
p.set_process_group(pg=ddp_pg)
# ignore the parameters with no gradient
if not p.requires_grad:
self.set_params_to_ignore([p])
@ -578,38 +655,37 @@ class ZeroDDP(ColoDDP):
continue
# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
fp32_p = p.data.float()
# create a fp16 parameter
p.data = p.data.to(self.mixed_precision)
# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
# move master weights to corresponding device and setup paired chunks
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
if chunk_32.device_type != self.grads_device[p].type:
self.chunk_manager.move_chunk(chunk_32, self.grads_device[p])
def _cast_buffers(self):
for buffer in self.module.buffers():
@ -727,67 +803,3 @@ class _StateDictSharder:
self.current_block[name] = tensor
self.current_block_size += tensor_size
return ret_block, ret_block_size
class GeminiDDP(ZeroDDP):
    """Convenience subclass of ``ZeroDDP`` that builds its own chunk and Gemini managers."""

    def __init__(self,
                 module: torch.nn.Module,
                 device: torch.device,
                 placement_policy: str = "cpu",
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
                 strict_ddp_mode: bool = False,
                 scatter_after_inference: bool = True,
                 search_range_m: int = 32,
                 hidden_dim: Optional[int] = None,
                 min_chunk_size_m: float = 32,
                 memstats: Optional[MemStats] = None,
                 mixed_precision: torch.dtype = torch.float16,
                 verbose: bool = False) -> None:
        """
        A torch.Module wrapper using ZeRO-DP and Gemini.
        ZeRO handles the parallelism, Gemini handles the memory management.
        WARNING: the wrapped module is modified in place!

        Example:
            model is initialized under the context of ColoInitContext
            >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
            >>> logits = model(x)
            >>> loss = criterion(logits, labels)
            >>> model.backward(loss)

        Args:
            module (torch.nn.Module): the model to be wrapped.
            device (torch.device): device to place the model; handed to
                ``init_chunk_manager`` as ``init_device``.
            placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
            pin_memory (bool, optional): use pinned memory on CPU. Defaults to False.
            force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
            strict_ddp_mode (bool, optional): forwarded to ``init_chunk_manager``
                (as ``strict_ddp_flag``) and to ``ZeroDDP``. Defaults to False.
            scatter_after_inference (bool, optional): forwarded to ``ZeroDDP``. Defaults to True.
            search_range_m (int, optional): chunk size searching range divided by 2^20.
                ``None`` is treated as 32. Defaults to 32.
            hidden_dim (int, optional): the hidden dimension of the DNN.
                Providing it speeds up the search; when it is not known in advance
                a default value of 1024 is used.
            min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
                If the aggregate size of parameters is still smaller than the minimum
                chunk size, all parameters are compacted into one small chunk.
            memstats (MemStats, optional): memory statistics collected by a runtime memory tracer.
            mixed_precision (torch.dtype, optional): forwarded to ``ZeroDDP``.
                Defaults to torch.float16.
            verbose (bool, optional): forwarded to ``init_chunk_manager``. Defaults to False.
        """
        # Hotfix for Lightning compatibility: it may hand over ``None`` here.
        search_range_m = 32 if search_range_m is None else search_range_m

        chunk_search_options = dict(model=module,
                                    init_device=device,
                                    hidden_dim=hidden_dim,
                                    search_range_m=search_range_m,
                                    min_chunk_size_m=min_chunk_size_m,
                                    strict_ddp_flag=strict_ddp_mode,
                                    verbose=verbose)
        chunk_manager = init_chunk_manager(**chunk_search_options)
        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)

        # Positional order follows the ZeroDDP constructor.
        super().__init__(module,
                         gemini_manager,
                         pin_memory,
                         force_outputs_fp32,
                         strict_ddp_mode,
                         scatter_after_inference,
                         mixed_precision=mixed_precision)

View File

@ -1,6 +1,6 @@
import functools
from time import time
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
@ -26,7 +26,11 @@ class GeminiManager:
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
"""
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
def __init__(self,
placement_policy: str,
chunk_manager: ChunkManager,
memstats: Optional[MemStats] = None,
**placement_kwargs) -> None:
assert placement_policy in PlacementPolicyFactory.get_policy_names()
self.policy_name = placement_policy
@ -37,7 +41,7 @@ class GeminiManager:
self._memstats = memstats
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
self._memstats) if policy_cls.need_mem_stats else None
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
@ -133,10 +137,6 @@ class GeminiManager:
if self._warmup and self._placement_policy.need_mem_stats:
self._compute_list.append(chunks)
@property
def default_device(self):
    """Default device as reported by the active placement policy's ``get_default_device()``."""
    policy = self._placement_policy
    return policy.get_default_device()
def sample_overall_data(self):
if self._mem_stats_collector:
self._mem_stats_collector.sample_overall_data()
@ -159,6 +159,6 @@ class GeminiManager:
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats
@staticmethod
def get_default_device(policy_name: str) -> torch.device:
    """Look up the default device for a placement policy by name.

    Args:
        policy_name (str): name handed to ``PlacementPolicyFactory.get_default_device``.

    Returns:
        torch.device: the value returned by the factory lookup.
    """
    resolved_device = PlacementPolicyFactory.get_default_device(policy_name)
    return resolved_device
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
                                                                                 torch.device]) -> None:
    """Let the active placement policy decide where each gradient lives.

    Args:
        params (List[torch.Tensor]): handed unchanged to the policy.
        grads_device_map (Dict[torch.Tensor, torch.device]): handed unchanged to the
            policy's ``setup_grads_device``.
    """
    policy = self._placement_policy
    policy.setup_grads_device(params, grads_device_map)

View File

@ -2,7 +2,7 @@
import copy
import math
import warnings
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch
import torch.distributed as dist
@ -11,15 +11,16 @@ from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import ZeroDDP
from .gemini_ddp import GeminiDDP
__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
__all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer']
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
@ -27,7 +28,7 @@ _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
module: ZeroDDP,
module: GeminiDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
@ -46,11 +47,11 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
self.module.overflow_counter = 0
class ZeroOptimizer(ColossalaiOptimizer):
"""A wrapper for optimizer. ``ZeroDDP`` and ``ZeroOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
class GeminiOptimizer(OptimizerWrapper):
"""A wrapper for optimizer. ``GeminiDDP`` and ``GeminiOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
Note:
You must use ``ZeroDDP`` with ``ZeroOptimizer``.
You must use ``GeminiDDP`` with ``GeminiOptimizer``.
Note:
Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`,
@ -58,7 +59,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
Args:
optim (Optimizer): An Optimizer instance.
module (ZeroDDP): A ``ZeroDDP`` instance.
module (GeminiDDP): A ``GeminiDDP`` instance.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
@ -70,15 +71,15 @@ class ZeroOptimizer(ColossalaiOptimizer):
growth_interval (float, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): Hysteresis used by DynamicGradScaler. Defaults to 2.
max_scale (int, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32.
clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
max_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0)
is supported in ZeroOptimizer. Defaults to 2.0.
is supported in GeminiOptimizer. Defaults to 2.0.
verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
"""
def __init__(self,
optim: Optimizer,
module: ZeroDDP,
module: GeminiDDP,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
@ -87,12 +88,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
clipping_norm: float = 0.0,
max_norm: float = 0.0,
norm_type: float = 2.0,
verbose: bool = False,
**defaults: Any):
super().__init__(optim)
assert isinstance(module, ZeroDDP)
assert isinstance(module, GeminiDDP)
assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
f"{_AVAIL_OPTIM_LIST}"
self.module = module
@ -101,8 +102,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm
self.clipping_flag = max_norm > 0.0
self.max_norm = max_norm
self.verbose = verbose
self.param_groups_backup = list()
@ -111,7 +112,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.id_to_fake_params: Dict[int, Parameter] = dict()
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
assert norm_type == 2.0, "GeminiOptimizer only supports L2 norm now"
ddp_param_list = []
for name, param in module.named_parameters():
@ -735,8 +736,19 @@ class ZeroOptimizer(ColossalaiOptimizer):
yield current_block, current_block_size
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
    """Value-based gradient clipping is unsupported here.

    Raises:
        NotImplementedError: always.
    """
    reason = 'Gemini does not support clip_grad_by_value'
    raise NotImplementedError(reason)
class GeminiAdamOptimizer(ZeroOptimizer):
def clip_grad_by_norm(self,
                      max_norm: Union[float, int],
                      norm_type: Union[float, int] = 2,
                      error_if_nonfinite: bool = False,
                      *args,
                      **kwargs) -> torch.Tensor:
    """Warn that norm-based clipping must not be requested through this method.

    No clipping is performed and every argument is ignored; the call only emits a
    warning. NOTE(review): despite the return annotation, nothing is returned
    (the result is ``None``).
    """
    notice = 'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm'
    warnings.warn(notice)
class GeminiAdamOptimizer(GeminiOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)

View File

@ -9,7 +9,7 @@ class MemStats(object):
def __init__(self) -> None:
"""
Store the non model data statistics used for Gemini and ZeroOptimizer.
Store the non model data statistics used for Gemini and GeminiOptimizer.
"""
# (preop_step, List[param])
self._step_param_dict = dict()

View File

@ -1,4 +1,5 @@
import functools
import warnings
from abc import ABC, abstractmethod
from time import time
from typing import Dict, List, Optional, Tuple, Type
@ -7,6 +8,7 @@ import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector
@ -17,7 +19,8 @@ class PlacementPolicy(ABC):
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
**kwargs) -> None:
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
@ -25,57 +28,87 @@ class PlacementPolicy(ABC):
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
raise NotImplementedError
@staticmethod
def get_default_device() -> torch.device:
    """Return the CPU device as the default placement."""
    cpu_device = torch.device('cpu')
    return cpu_device
@abstractmethod
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
                                                                                 torch.device]) -> None:
    """Abstract hook: concrete policies fill ``grads_device_map`` for ``params``.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
class CPUPlacementPolicy(PlacementPolicy):
class StaticPlacementPolicy(PlacementPolicy):
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
shard_param_frac: float = 1.0,
offload_optim_frac: float = 0.0,
offload_param_frac: float = 0.0,
**kwargs) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
offload_param_frac = 0.0
self.shard_param_frac = shard_param_frac
self.offload_optim_frac = offload_optim_frac
self.offload_param_frac = offload_param_frac
# these should be initialized in setup_grads_device
self.keep_gathered_chunk_mem = 0.0
self.keep_cuda_chunk_mem = 0.0
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
volume = 0
start = time()
can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
can_offload_chunk_mem = can_shard_chunk_mem
for chunk in can_evict_chunks:
if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
break
self.chunk_manager.release_chunk(chunk)
# real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem
can_shard_chunk_mem -= chunk.chunk_mem
for chunk in can_evict_chunks:
if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
break
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
volume += chunk.chunk_mem
return volume, time() - start
# real saved mem is shard_mem, for simplicity we use chunk_mem
can_offload_chunk_mem -= chunk.chunk_mem
return 0, 0.0
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
class CUDAPlacementPolicy(PlacementPolicy):
    """Placement policy that never evicts chunks and defaults to the current device."""

    def __init__(self,
                 chunk_manager: ChunkManager,
                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        # Fail fast: this policy is meaningless without a CUDA device.
        cuda_ready = torch.cuda.is_available()
        assert cuda_ready, 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        """Evict nothing; report zero freed volume and zero elapsed time."""
        nothing_evicted = (0, 0)
        return nothing_evicted

    @staticmethod
    def get_default_device() -> torch.device:
        """Return whatever ``get_current_device()`` reports."""
        current = get_current_device()
        return current
offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
offloaded_optim_chunk_mem = 0
chunks = set(self.chunk_manager.get_chunk(p) for p in params)
for chunk in chunks:
params = chunk.get_tensors()
# init offload optim settings
# keep gathered chunks are in CUDA
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
device = get_current_device()
else:
device = torch.device('cpu')
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
offloaded_optim_chunk_mem += chunk.chunk_mem
for p in params:
grads_device_map[p] = device
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
_warmup_non_model_data_ratio: float = 0.8
_steady_cuda_cap_ratio: float = 0.9
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
warmup_non_model_data_ratio: float = 0.8,
steady_cuda_cap_ratio: float = 0.9,
**kwargs) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
def evict_tensors(self,
can_evict_chunks: List[Chunk],
@ -105,11 +138,11 @@ class AutoPlacementPolicy(PlacementPolicy):
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
cuda_capacity *= self._steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
freed_cuda_model_data = 0
@ -145,89 +178,22 @@ class AutoPlacementPolicy(PlacementPolicy):
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
@staticmethod
def set_warmup_non_model_data_ratio(ratio: float) -> None:
    """Set the class-wide ``AutoPlacementPolicy._warmup_non_model_data_ratio``.

    Args:
        ratio (float): converted with ``float()``; must lie strictly between 0 and 1.
    """
    as_float = float(ratio)
    assert 0.0 < as_float < 1.0
    AutoPlacementPolicy._warmup_non_model_data_ratio = as_float
@staticmethod
def set_steady_cuda_cap_ratio(ratio: float) -> None:
    """Set the class-wide ``AutoPlacementPolicy._steady_cuda_cap_ratio``.

    Args:
        ratio (float): converted with ``float()``; must lie strictly between 0 and 1.
    """
    as_float = float(ratio)
    assert 0.0 < as_float < 1.0
    AutoPlacementPolicy._steady_cuda_cap_ratio = as_float
class ConstPlacementPolicy(PlacementPolicy):
    """Placement policy that keeps accessed chunk memory under a fixed, class-wide boundary."""

    need_mem_stats: bool = False
    # Class-wide boundary; changed through ``set_const_memory_boundary``.
    _accessed_memory_boundary = 512 * 1024**2

    def __init__(self,
                 chunk_manager: ChunkManager,
                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self,
                      can_evict_chunks: List[Chunk],
                      cuda_demand: int = 0,
                      warmup: bool = True,
                      compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
                      compute_idx: int = 0,
                      **kwargs) -> Tuple[int, float]:
        """Release and offload chunks until ``cuda_demand`` fits under the boundary.

        Args:
            can_evict_chunks (List[Chunk]): candidates for eviction.
            cuda_demand (int): memory that must become available.
            warmup (bool): when True the candidates are used in the given order;
                otherwise they are ordered by ``_sort_can_evict_chunks``.
            compute_list (Optional[List[Tuple[Chunk, ...]]]): compute schedule, only
                read when ``warmup`` is False.
            compute_idx (int): current position in ``compute_list``.

        Returns:
            Tuple[int, float]: the freed amount (sum of ``chunk_mem`` of the evicted
            chunks) and the elapsed wall time.

        Raises:
            RuntimeError: if evicting every candidate still frees too little.
        """
        begin = time()
        headroom = ConstPlacementPolicy._accessed_memory_boundary - self.chunk_manager.accessed_mem
        released = 0

        if headroom < cuda_demand:
            shortfall = cuda_demand - headroom
            if warmup:
                candidates = can_evict_chunks
            else:
                # Tuples are required because the sorter is memoised with lru_cache.
                candidates = self._sort_can_evict_chunks(tuple(can_evict_chunks), compute_idx, tuple(compute_list))

            for victim in candidates:
                if released >= shortfall:
                    break
                self.chunk_manager.release_chunk(victim)
                self.chunk_manager.move_chunk(victim, torch.device('cpu'))
                released += victim.chunk_mem

            if released < shortfall:
                raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
                                   f"Need {shortfall}, freed {released}")

        return released, time() - begin

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
        """Order chunks so that the one needed furthest in the future comes first.

        Chunks that do not appear after ``compute_idx`` get ``len(compute_list)`` as
        their next-use index and therefore sort to the front.
        """
        horizon = len(compute_list)
        next_use = dict.fromkeys(can_evict_chunks, horizon)
        # Walk backwards so the earliest upcoming use is the value that survives.
        for step in reversed(range(compute_idx + 1, horizon)):
            for chunk in compute_list[step]:
                if chunk in next_use:
                    next_use[chunk] = step
        ranked = sorted(next_use.items(), key=lambda pair: pair[1], reverse=True)
        return [chunk for chunk, _ in ranked]

    @staticmethod
    def set_const_memory_boundary(cuda_memory_mb: int) -> None:
        """Set the class-wide boundary to ``cuda_memory_mb * 1024**2`` (must be positive)."""
        limit = int(cuda_memory_mb * 1024**2)
        assert limit > 0
        ConstPlacementPolicy._accessed_memory_boundary = limit
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
                                                                                torch.device]) -> None:
    """Fill ``grads_device_map`` with the gradient device chosen for each parameter.

    Args:
        params: parameters to look up in ``self.chunk_manager``.
        grads_device_map: mapping updated in place, one entry per parameter.
    """
    for param in params:
        owning_chunk = self.chunk_manager.get_chunk(param)
        # init offload optim settings:
        # chunks flagged ``keep_gathered`` stay on the current device, all others go to CPU.
        target_device = get_current_device() if owning_chunk.keep_gathered else torch.device('cpu')
        grads_device_map[param] = target_device
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {
'cpu': CPUPlacementPolicy,
'cuda': CUDAPlacementPolicy,
'auto': AutoPlacementPolicy,
'const': ConstPlacementPolicy
'static': StaticPlacementPolicy,
}
@staticmethod
@ -239,8 +205,3 @@ class PlacementPolicyFactory:
@staticmethod
def get_policy_names():
    """Return the keys of ``PlacementPolicyFactory.policies`` as a tuple."""
    registry = PlacementPolicyFactory.policies
    # Iterating a dict yields its keys in insertion order.
    return tuple(registry)
@staticmethod
def get_default_device(policy_name: str) -> torch.device:
    """Return ``get_default_device()`` of whatever ``PlacementPolicyFactory.create(policy_name)`` yields."""
    return PlacementPolicyFactory.create(policy_name).get_default_device()

View File

@ -64,13 +64,13 @@ def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"),
dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module:
"""Get a static torch.nn.Module model from the given ZeroDDP module.
You should notice that the original ZeroDDP model is not modified.
"""Get a static torch.nn.Module model from the given GeminiDDP module.
You should notice that the original GeminiDDP model is not modified.
Thus, you can use the original model in further training.
But you should not use the returned torch model to train, this can cause unexpected errors.
Args:
zero_ddp_model (ZeroDDP): a zero ddp model
zero_ddp_model (GeminiDDP): a zero ddp model
device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank0 has the converted torch model
@ -78,8 +78,8 @@ def get_static_torch_model(zero_ddp_model,
Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
from colossalai.zero.gemini.gemini_ddp import ZeroDDP
assert isinstance(zero_ddp_model, ZeroDDP)
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
assert isinstance(zero_ddp_model, GeminiDDP)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
colo_model = zero_ddp_model.module

View File

@ -109,6 +109,6 @@ def zero_optim_wrapper(model: nn.Module,
config_dict['clip_grad_norm'] = max_norm
return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
else:
from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer
config_dict['clipping_norm'] = max_norm
return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose)
return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose)

View File

@ -54,32 +54,38 @@ We also provide a lightweight chunk search mechanism to help users automatically
We will use `GeminiDDP` to use ZeRO with chunk-based memory management. This is our new torch.Module wrapper which uses ZeRO-DP and Gemini. ZeRO is for parallelism and Gemini is for memory management.
Also make sure that your model is initialized under the context of `ColoInitContext`.
Gemini allows LazyInitContext, which can save memory when initializing large models with multi-GPUs.
If your model has `N` billion parameters and your GPU memory is `M` GB, we recommend you use LazyInitContext when `4N >= M`. Otherwise, LazyInitContext is optional.
<!--- doc-test-ignore-start -->
```python
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
with LazyInitContext(default_device=torch.device('cuda')):
model = gpt2_medium(checkpoint=True)
```
<!--- doc-test-ignore-end -->
Define the model parameters as follows:
We provide the user-friendly `Booster` API and recommend using it. If you still want to use the low-level API, read the rest of this section.
Wrap the model with `GeminiDDP`.
<!--- doc-test-ignore-start -->
```python
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = GeminiDDP(model, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m)
```
<!--- doc-test-ignore-end -->
`hidden_dim` is the hidden dimension of the DNN. Providing it speeds up the chunk search; if it is not known before training, the default value 1024 is used. `min_chunk_size_m` is a floating-point number giving the minimum chunk size divided by 2^20 (e.g., if `min_chunk_size_m=2.5`, the minimum chunk size is 2.5*(2^20)). If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk.
Initialization of the optimizer.
<!--- doc-test-ignore-start -->
```python
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
```
<!--- doc-test-ignore-end -->
Training
<!--- doc-test-ignore-start -->
```python
optimizer.zero_grad()
outputs = model(input_ids, attn_mask)
@ -87,6 +93,7 @@ loss = criterion(outputs, input_ids)
optimizer.backward(loss)
optimizer.step()
```
<!--- doc-test-ignore-end -->
> ⚠️ Note: Please do not use `loss.backward()`, the standard way of writing is `optimizer.backward(loss)`.
### Train GPT
@ -142,46 +149,6 @@ class GPTLMLoss(nn.Module):
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
```
Define tensor parallel and parameter sharding strategies for tensor parallelism:
```python
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """Assign a 1D tensor-parallel or replicated spec to every parameter of ``model``.

    Each parameter is handled once: after processing it is tagged with a
    ``visited`` attribute and skipped on later encounters.
    """
    for module_name, submodule in model.named_modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            if hasattr(param, 'visited'):
                continue

            # Start from a replicated spec, then override by module/parameter name.
            param.set_dist_spec(ReplicaSpec())
            is_weight = 'weight' in param_name
            if 'mlp.c_fc' in module_name:
                if is_weight or 'bias' in param_name:
                    split_param_col_tp1d(param, pg)
                    param.compute_spec.set_output_replicate(False)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'mlp.c_proj' in module_name:
                if is_weight:
                    split_param_row_tp1d(param, pg)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif any(key in module_name for key in ('wte', 'wpe', 'c_attn', 'c_proj')):
                # Embeddings and attention projections are all split along the last dim.
                split_param_col_tp1d(param, pg)
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    """Set a ``ShardSpec`` over ``dim`` plus a ``TP1D`` compute spec on ``param``.

    The number of shards is ``pg.tp_world_size()``.
    """
    shard_spec = ShardSpec([dim], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    param.set_tensor_spec(shard_spec, compute_spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along dimension 0 via ``split_param_single_dim_tp1d``."""
    split_param_single_dim_tp1d(dim=0, param=param, pg=pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along its last dimension (-1) via ``split_param_single_dim_tp1d``."""
    split_param_single_dim_tp1d(dim=-1, param=param, pg=pg)
```
Write a function to get random inputs:
@ -198,7 +165,7 @@ Finally, we define a model which uses Gemini + ZeRO DDP and define our training
from colossalai.nn.optimizer import HybridAdam
from colossalai.booster import Booster
from colossalai.zero import ColoInitContext
from colossalai.lazy import LazyInitContext
from colossalai.booster.plugin import GeminiPlugin
def main():
@ -214,17 +181,13 @@ def main():
optimizer = HybridAdam(model.parameters(), lr=0.001)
torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree])
# build GPT model
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
with ColoInitContext(default_device=torch.device('cuda')):
model = gpt2_medium(checkpoint=True)
pg = default_pg
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
plugin = GeminiPlugin(placement_policy='cuda', max_norm=1.0, initial_scale=2**5)
# Gemini + ZeRO DP
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

View File

@ -53,32 +53,37 @@
我们将运用`GeminiDDP`的方式来使用基于Chunk内存管理的ZeRO。这是我们新包装的torch.Module，它使用 ZeRO-DP 和 Gemini，其中 ZeRO 用于并行，Gemini 用于内存管理。
同样需要确保你的模型是在 `ColoInitContext` 的上下文中初始化的。
Gemini支持惰性初始化, 它可以节省多卡初始化大模型时的显存使用.
如果你的模型有 `N` billion 个参数,你的 GPU 内存为 `M` GB, 当 `4N >= M` 时,我们推荐使用 LazyInitContext。否则LazyInitContext 是可选的。
<!--- doc-test-ignore-start -->
```python
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
with LazyInitContext(default_device=torch.device('cuda')):
model = gpt2_medium(checkpoint=True)
```
<!--- doc-test-ignore-end -->
定义模型参数如下:
我们提供了 `Booster` API，它用户友好。我们推荐你使用 `Booster` API。如果您仍然想使用底层 API，您可以继续阅读本节其他内容。
使用 `GeminiDDP` 包装模型。
<!--- doc-test-ignore-start -->
```python
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
model = GeminiDDP(model, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m)
```
<!--- doc-test-ignore-end -->
`hidden_dim`是DNN的隐藏维度。用户可以提供这个参数来加快搜索速度。如果用户在训练前不知道这个参数，也可以，我们将使用默认值 1024。`min_chunk_size_m`是以兆（2^20）为单位的最小块大小。如果参数的总大小仍然小于最小块大小，则所有参数将被压缩为一个小块。
初始化优化器。
<!--- doc-test-ignore-start -->
```python
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
```
<!--- doc-test-ignore-end -->
<!--- doc-test-ignore-start -->
训练
```python
optimizer.zero_grad()
@ -87,6 +92,7 @@ loss = criterion(outputs, input_ids)
optimizer.backward(loss)
optimizer.step()
```
<!--- doc-test-ignore-end -->
> ⚠️ 注意:请不要使用`loss.backward()`,规范写法是`optimizer.backward(loss)`。
### 训练GPT
@ -143,47 +149,6 @@ class GPTLMLoss(nn.Module):
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
```
定义张量并行和参数分片策略:
```python
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """Assign a 1D tensor-parallel or replicated spec to every parameter of ``model``.

    Each parameter is handled once: after processing it is tagged with a
    ``visited`` attribute and skipped on later encounters.
    """
    for module_name, submodule in model.named_modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            if hasattr(param, 'visited'):
                continue

            # Start from a replicated spec, then override by module/parameter name.
            param.set_dist_spec(ReplicaSpec())
            is_weight = 'weight' in param_name
            if 'mlp.c_fc' in module_name:
                if is_weight or 'bias' in param_name:
                    split_param_col_tp1d(param, pg)
                    param.compute_spec.set_output_replicate(False)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'mlp.c_proj' in module_name:
                if is_weight:
                    split_param_row_tp1d(param, pg)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif any(key in module_name for key in ('wte', 'wpe', 'c_attn', 'c_proj')):
                # Embeddings and attention projections are all split along the last dim.
                split_param_col_tp1d(param, pg)
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    """Set a ``ShardSpec`` over ``dim`` plus a ``TP1D`` compute spec on ``param``.

    The number of shards is ``pg.tp_world_size()``.
    """
    shard_spec = ShardSpec([dim], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    param.set_tensor_spec(shard_spec, compute_spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along dimension 0 via ``split_param_single_dim_tp1d``."""
    split_param_single_dim_tp1d(dim=0, param=param, pg=pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along its last dimension (-1) via ``split_param_single_dim_tp1d``."""
    split_param_single_dim_tp1d(dim=-1, param=param, pg=pg)
```
写一个获得随机输入的函数:
```python
@ -200,7 +165,7 @@ def get_data(batch_size, seq_len, vocab_size):
from colossalai.nn.optimizer import HybridAdam
from colossalai.booster import Booster
from colossalai.zero import ColoInitContext
from colossalai.lazy import LazyInitContext
from colossalai.booster.plugin import GeminiPlugin
def main():
@ -216,17 +181,13 @@ def main():
optimizer = HybridAdam(model.parameters(), lr=0.001)
torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree])
# build GPT model
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
with ColoInitContext(default_device=torch.device('cuda')):
model = gpt2_medium(checkpoint=True)
pg = default_pg
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
plugin = GeminiPlugin(placement_policy='cuda', max_norm=1.0, initial_scale=2**5)
# Gemini + ZeRO DP
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

View File

@ -22,7 +22,7 @@ from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wra
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from colossalai.zero import GeminiOptimizer
def main():
@ -46,7 +46,7 @@ def main():
args.local_rank = -1
args.log_interval = 1
else:
colossalai.launch_from_torch(config={}) #args.colossal_config
colossalai.launch_from_torch(config={}) # args.colossal_config
args.local_rank = int(os.environ["LOCAL_RANK"])
logger.info(
f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
@ -123,7 +123,8 @@ def main():
get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
# 144003367 is the length of the entire dataset
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader)
# len(dataloader)
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size
total_steps = steps_per_epoch * args.epoch
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)

View File

@ -20,6 +20,5 @@ for plugin in "gemini"; do
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--test_run=True \
--num_class_images=200 \
--placement="auto" # "cuda"
--num_class_images=200
done

View File

@ -2,9 +2,9 @@ import argparse
import hashlib
import math
import os
import shutil
from pathlib import Path
from typing import Optional
import shutil
import torch
import torch.nn.functional as F
@ -19,6 +19,8 @@ from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
@ -26,8 +28,6 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini import get_static_torch_model
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
disable_existing_loggers()
logger = get_dist_logger()
@ -138,10 +138,10 @@ def parse_args(input_args=None):
" resolution"),
)
parser.add_argument(
"--placement",
type=str,
default="cpu",
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
"--offload_optim_frac",
type=float,
default=1.0,
help="Fraction of optimizer states to be offloaded. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--center_crop",
@ -461,18 +461,17 @@ def main(args):
revision=args.revision,
)
if args.externel_unet_path is None:
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
else:
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
revision=args.revision,
low_cpu_mem_usage=False)
revision=args.revision,
low_cpu_mem_usage=False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@ -491,30 +490,31 @@ def main(args):
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
plugin = LowLevelZeroPlugin(initial_scale=2**5)
booster = Booster(plugin=plugin, **booster_kwargs)
# config optimizer for colossalai zero
optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
optimizer = HybridAdam(unet.parameters(),
lr=args.learning_rate,
initial_scale=2**5,
clipping_norm=args.max_grad_norm)
# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# prepare dataset
logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
test=args.test_run
)
train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
test=args.test_run)
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
@ -690,6 +690,7 @@ def main(args):
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@ -2,9 +2,9 @@ import argparse
import hashlib
import math
import os
import shutil
from pathlib import Path
from typing import Optional
import shutil
import torch
import torch.nn.functional as F
@ -21,6 +21,8 @@ from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
@ -28,8 +30,6 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
disable_existing_loggers()
logger = get_dist_logger()
@ -459,18 +459,17 @@ def main(args):
revision=args.revision,
)
if args.externel_unet_path is None:
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
else:
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
revision=args.revision,
low_cpu_mem_usage=False)
revision=args.revision,
low_cpu_mem_usage=False)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
@ -490,8 +489,7 @@ def main(args):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
@ -513,14 +511,17 @@ def main(args):
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
plugin = GeminiPlugin(strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
plugin = LowLevelZeroPlugin(initial_scale=2**5)
booster = Booster(plugin=plugin, **booster_kwargs)
# config optimizer for colossalai zero
optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
optimizer = HybridAdam(unet.parameters(),
lr=args.learning_rate,
initial_scale=2**5,
clipping_norm=args.max_grad_norm)
# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@ -711,6 +712,7 @@ def main(args):
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@ -49,8 +49,8 @@ python eval.py -c ./ckpt-low_level_zero -e 80
Expected accuracy performance will be:
| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero |
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- |
| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% |
| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | Booster Gemini |
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | -------------- |
| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | 84.60% |
**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**

View File

@ -104,7 +104,7 @@ def main():
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'],
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero', 'gemini'],
help="plugin to use")
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
@ -141,7 +141,7 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)

View File

@ -1,19 +1,18 @@
import time
import torch
import transformers
from transformers import ViTConfig, ViTForImageClassification
import tqdm
import transformers
from args import parse_benchmark_args
from transformers import ViTConfig, ViTForImageClassification
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import get_current_device
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from args import parse_benchmark_args
def format_num(num: int, bytes=False):
"""Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
@ -26,8 +25,13 @@ def format_num(num: int, bytes=False):
def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
pixel_values = torch.randn(batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float)
labels = torch.randint(0, num_labels, (batch_size, ), device=torch.cuda.current_device(), dtype=torch.int64)
pixel_values = torch.randn(batch_size,
num_channels,
height,
width,
device=torch.cuda.current_device(),
dtype=torch.float)
labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
return pixel_values, labels
@ -55,11 +59,11 @@ def main():
transformers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
# Whether to set limit on memory capacity
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)
# Build ViT model
config = ViTConfig.from_pretrained(args.model_name_or_path)
model = ViTForImageClassification(config)
@ -75,11 +79,7 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(),
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
@ -90,16 +90,15 @@ def main():
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, _, _ = booster.boost(model, optimizer)
# Start training.
logger.info(f"Start testing", ranks=[0])
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
torch.cuda.synchronize()
model.train()
start_time = time.time()
for _ in range(args.max_train_steps):
pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
@ -111,18 +110,19 @@ def main():
torch.cuda.synchronize()
progress_bar.update(1)
# Compute Statistics
# Compute Statistics
end_time = time.time()
throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
logger.info(f"Testing finished, "
f"batch size per gpu: {args.batch_size}, "
f"plugin: {args.plugin}, "
f"throughput: {throughput}, "
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])
logger.info(
f"Testing finished, "
f"batch size per gpu: {args.batch_size}, "
f"plugin: {args.plugin}, "
f"throughput: {throughput}, "
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])
if __name__ == "__main__":

View File

@ -1,20 +1,19 @@
import torch
import torch.distributed as dist
import transformers
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
from args import parse_demo_args
from data import BeansDataset, beans_collator
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import get_current_device
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from args import parse_demo_args
from data import BeansDataset, beans_collator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
@ -22,12 +21,12 @@ def move_to_cuda(batch, device):
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
torch.cuda.synchronize()
model.train()
with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
for batch in pbar:
# Forward
@ -47,7 +46,7 @@ def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coor
@torch.no_grad()
def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
model.eval()
accum_loss = torch.zeros(1, device=get_current_device())
total_num = torch.zeros(1, device=get_current_device())
@ -76,9 +75,7 @@ def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
print(f"Evaluation result for epoch {epoch + 1}: \
average_loss={avg_loss}, \
accuracy={accuracy}.")
def main():
@ -102,14 +99,13 @@ def main():
train_dataset = BeansDataset(image_processor, split='train')
eval_dataset = BeansDataset(image_processor, split='validation')
# Load pretrained ViT model
config = ViTConfig.from_pretrained(args.model_name_or_path)
config.num_labels = train_dataset.num_labels
config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
config=config,
model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
config=config,
ignore_mismatched_sizes=True)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
@ -123,26 +119,22 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(),
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Prepare dataloader
train_dataloader = plugin.prepare_dataloader(train_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=beans_collator)
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=beans_collator)
eval_dataloader = plugin.prepare_dataloader(eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=beans_collator)
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=beans_collator)
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
@ -156,11 +148,11 @@ def main():
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer,
dataloader=train_dataloader,
lr_scheduler=lr_scheduler)
model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer,
dataloader=train_dataloader,
lr_scheduler=lr_scheduler)
# Finetuning
logger.info(f"Start finetuning", ranks=[0])
for epoch in range(args.num_epoch):
@ -174,4 +166,4 @@ def main():
if __name__ == "__main__":
main()
main()

View File

@ -7,6 +7,14 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
bash test_ci.sh
```
### Results on 2-GPU
| Plugin | Accuracy | F1-score |
| -------------- | -------- | -------- |
| torch_ddp | 84.4% | 88.6% |
| torch_ddp_fp16 | 84.7% | 88.8% |
| gemini | 84.0% | 88.4% |
## Benchmark
```
bash benchmark.sh
@ -14,9 +22,9 @@ bash benchmark.sh
The benchmark currently reports the following metrics: CUDA memory usage, throughput, and the number of model parameters. If you have custom metrics, you can add them to benchmark_util.
## Results
### Results
### Bert
#### Bert
| | max cuda mem | throughput(sample/s) | params |
| :-----| -----------: | :--------: | :----: |
@ -25,10 +33,10 @@ Now include these metrics in benchmark: CUDA mem occupy, throughput and the numb
| gemini | 11.0 GB | 12.9 | 82M |
| low_level_zero | 11.29 G | 14.7 | 82M |
### AlBert
#### AlBert
| | max cuda mem | throughput(sample/s) | params |
| :-----| -----------: | :--------: | :----: |
| ddp | OOM | | |
| ddp_fp16 | OOM | | |
| gemini | 69.39 G | 1.3 | 208M |
| low_level_zero | 56.89 G | 1.4 | 208M |
| low_level_zero | 56.89 G | 1.4 | 208M |

View File

@ -38,8 +38,8 @@ def move_to_cuda(batch):
@torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
eval_splits: List[str], coordinator: DistCoordinator):
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()
@ -142,7 +142,7 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
@ -208,7 +208,7 @@ def main():
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
coordinator)
if coordinator.is_master():
print(results)

View File

@ -4,9 +4,6 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
# The following options are only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export TRAIN_STEP=${TRAIN_STEP:-10}
@ -21,11 +18,8 @@ fi
mkdir -p gemini_logs
torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
--tp_degree=${TPDEGREE} \
--model_type=${MODEL_TYPE} \
--batch_size=${BATCH_SIZE} \
--placement=${PLACEMENT} \
${USE_SHARD_INIT} \
--distplan=${DISTPLAN} \
--train_step=${TRAIN_STEP} \
2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log

View File

@ -6,29 +6,17 @@ for MODEL_TYPE in "gpt2_medium"; do
for DISTPLAN in "CAI_Gemini"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1 2; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
for PLACEMENT in "cpu" "auto"; do
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./run_gemini.sh
done
done
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
bash ./run_gemini.sh
done
done
done
for DISTPLAN in "zero1" "zero2"; do
for DISTPLAN in "CAI_ZeRO2" "CAI_ZeRO1"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
bash ./run_gemini.sh
done
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
bash ./run_gemini.sh
done
done
done

View File

@ -1,4 +1,5 @@
import os
from contextlib import nullcontext
from functools import partial
from time import time
@ -13,11 +14,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
CAI_VERSION = colossalai.__version__
@ -30,24 +30,6 @@ def parse_args():
default='CAI_Gemini',
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
)
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
action='store_true',
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--batch_size",
type=int,
@ -71,20 +53,6 @@ def parse_args():
return args
# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    """Assign ``param`` a shard spec over ``dim`` together with a TP1D compute spec.

    Args:
        dim: index of the parameter dimension passed to ``ShardSpec``.
        param: parameter whose tensor spec is replaced.
        pg: process group; ``pg.tp_world_size()`` is used as the shard count.
    """
    shard_spec = ShardSpec([dim], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    param.set_tensor_spec(shard_spec, compute_spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Row slice: split ``param`` over its first dimension (``dim=0``)."""
    split_param_single_dim_tp1d(dim=0, param=param, pg=pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Column slice: split ``param`` over its last dimension (``dim=-1``)."""
    split_param_single_dim_tp1d(dim=-1, param=param, pg=pg)
class GPTLMLoss(nn.Module):
def __init__(self):
@ -140,47 +108,6 @@ def set_cpu_maximum_parallelism():
print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """Assign a tensor-parallel layout to every parameter of ``model``.

    Each parameter is first given a replica spec and attached to ``pg``;
    it is then re-sharded according to the name of the module that owns it.

    Args:
        model (torch.nn.Module): a torch module to be sharded
        pg (ProcessGroup): process group passed to every parameter and
            to the split helpers
    """
    # Module-name fragments whose parameters are always column-sliced.
    column_sliced = ('wte', 'wpe', 'c_attn', 'c_proj')

    for module_name, submodule in model.named_modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            # A parameter may be shared by two modules; handle it only once.
            if hasattr(param, 'visited'):
                continue

            # Start from a replicated layout bound to the given process group.
            param.set_dist_spec(ReplicaSpec())
            param.set_process_group(pg)

            # Shard according to the TP pattern of the owning module.
            if 'mlp.c_fc' in module_name:
                if 'weight' in param_name or 'bias' in param_name:
                    split_param_col_tp1d(param, pg)    # column slice
                    # keep the shape of the output from c_fc
                    param.compute_spec.set_output_replicate(False)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'mlp.c_proj' in module_name:
                if 'weight' in param_name:
                    split_param_row_tp1d(param, pg)    # row slice
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif any(fragment in module_name for fragment in column_sliced):
                split_param_col_tp1d(param, pg)    # column slice
            else:
                param.set_dist_spec(ReplicaSpec())

            # Mark as processed so shared parameters are skipped next time.
            param.visited = True
def main():
# version check
# this example is supposed to work for versions greater than 0.2.0
@ -213,30 +140,13 @@ def main():
# build criterion
criterion = GPTLMLoss()
torch.manual_seed(123)
if args.distplan.startswith("CAI"):
# all param must use the same process group.
world_size = torch.distributed.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
if args.shardinit and args.distplan != "CAI_Gemini":
raise RuntimeError("You can only use shardinit with CAI_Gemini")
ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
# build GPT model
with ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg):
with ctx:
model = model_builder(args.model_type)(checkpoint=True)
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
# Tensor Parallelism (TP)
# You should notice that v0.1.10 is not compatible with TP degree > 1
if args.tp_degree > 1:
tensor_parallelize(model, tp_pg)
# assign running configurations
if args.distplan == "CAI_ZeRO1":
zero_stage = 1
@ -254,13 +164,7 @@ def main():
overlap_communication=True,
verbose=True)
elif args.distplan == "CAI_Gemini":
plugin = GeminiPlugin(device=get_current_device(),
placement_policy=args.placement,
pin_memory=True,
strict_ddp_mode=args.tp_degree == 1,
search_range_m=128,
hidden_dim=model.config.n_embd,
gpu_margin_mem_ratio=0.)
plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
else:
raise RuntimeError

View File

@ -1,22 +1,18 @@
import time
import torch
import tqdm
import transformers
from args import parse_benchmark_args
from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version
import tqdm
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from args import parse_benchmark_args
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
@ -61,11 +57,11 @@ def main():
transformers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
# Whether to set limit of memory capacity
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)
# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM(config=config)
@ -81,11 +77,7 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(),
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
@ -96,18 +88,18 @@ def main():
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, _, _ = booster.boost(model, optimizer)
SEQ_LEN = 1024
VOCAB_SIZE = 50257
# Start training.
logger.info(f"Start testing", ranks=[0])
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
torch.cuda.synchronize()
model.train()
start_time = time.time()
for _ in range(args.max_train_steps):
input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
@ -119,18 +111,19 @@ def main():
torch.cuda.synchronize()
progress_bar.update(1)
# Compute Statistics
# Compute Statistics
end_time = time.time()
throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
logger.info(f"Testing finished, "
f"batch size per gpu: {args.batch_size}, "
f"plugin: {args.plugin}, "
f"throughput: {throughput}, "
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])
logger.info(
f"Testing finished, "
f"batch size per gpu: {args.batch_size}, "
f"plugin: {args.plugin}, "
f"throughput: {throughput}, "
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])
if __name__ == "__main__":

View File

@ -1,25 +1,20 @@
import time
import torch
import datasets
import torch
import transformers
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from transformers.utils.versions import require_version
from args import parse_demo_args
from data import NetflixDataset, netflix_collator
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
from transformers.utils.versions import require_version
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from args import parse_demo_args
from data import NetflixDataset, netflix_collator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
@ -30,18 +25,18 @@ def move_to_cuda(batch, device):
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
torch.cuda.synchronize()
model.train()
with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
for batch in pbar:
# Forward
optimizer.zero_grad()
batch = move_to_cuda(batch, torch.cuda.current_device())
outputs = model(use_cache=False, **batch)
loss = outputs['loss']
@ -72,7 +67,7 @@ def main():
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
@ -88,43 +83,35 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(),
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
dataset = NetflixDataset(tokenizer)
dataloader = plugin.prepare_dataloader(dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=netflix_collator)
# Set optimizer
optimizer = HybridAdam(model.parameters(),
lr=(args.learning_rate * world_size),
weight_decay=args.weight_decay)
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
# Set lr scheduler
total_steps = len(dataloader) * args.num_epoch
num_warmup_steps = int(args.warmup_ratio * total_steps)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=len(dataloader) * args.num_epoch
)
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=len(dataloader) * args.num_epoch)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer,
dataloader=dataloader,
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer,
dataloader=dataloader,
lr_scheduler=lr_scheduler)
# Start finetuning

View File

@ -1,5 +1,5 @@
import gzip
import random
from contextlib import nullcontext
from functools import partial
from time import time
@ -8,20 +8,17 @@ import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version
from colossalai.nn import HybridAdam
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device
# constants
@ -44,23 +41,10 @@ def parse_args():
help="The distributed plan [colossalai, pytorch].",
)
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
type=bool,
default=False,
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
"--offload_optim_frac",
type=float,
default=1.0,
help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
)
parser.add_argument('-p',
'--plugin',
@ -111,51 +95,6 @@ def get_model_size(model: nn.Module):
return total_numel
# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    """Assign ``param`` a shard spec over ``dim`` together with a TP1D compute spec.

    Args:
        dim: index of the parameter dimension passed to ``ShardSpec``.
        param: parameter whose tensor spec is replaced.
        pg: process group; ``pg.tp_world_size()`` is used as the shard count.
    """
    shard_spec = ShardSpec([dim], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    param.set_tensor_spec(shard_spec, compute_spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Row slice: split ``param`` over its first dimension (``dim=0``)."""
    split_param_single_dim_tp1d(dim=0, param=param, pg=pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Column slice: split ``param`` over its last dimension (``dim=-1``)."""
    split_param_single_dim_tp1d(dim=-1, param=param, pg=pg)
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """Assign a tensor-parallel layout to every parameter of ``model``.

    Each parameter is first given a replica spec, then re-sharded by the
    first rule whose fragment occurs in the owning module's name.

    Args:
        model (torch.nn.Module): a torch module to be sharded
        pg (ProcessGroup): process group handed to the split helpers
    """
    # Ordered (module-name fragment, split helper) rules; the first match wins.
    rules = (
        ('net.0', split_param_col_tp1d),    # column slice
        ('to_q', split_param_col_tp1d),    # column slice
        ('to_kv', split_param_row_tp1d),    # row slice
        ('to_out', split_param_row_tp1d),    # row slice
        ('1.1', split_param_col_tp1d),    # column slice
        ('1.2', split_param_row_tp1d),    # row slice
    )

    for module_name, submodule in model.named_modules():
        for _, param in submodule.named_parameters(recurse=False):
            # Skip parameters that were already processed.
            if hasattr(param, 'visited'):
                continue

            param.set_dist_spec(ReplicaSpec())
            for fragment, split in rules:
                if fragment in module_name:
                    split(param, pg)
                    break
            else:
                # No rule matched: keep the parameter replicated.
                param.set_dist_spec(ReplicaSpec())

            param.visited = True
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
raise TypeError(f"{args.distplan} is error")
@ -212,23 +151,18 @@ if args.distplan == "colossalai":
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"plugin: {plugin}")
booster = Booster(plugin=plugin, **booster_kwargs)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()
with ctx:
model = PaLM(num_tokens=50304, dim=4096, depth=64)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
pg = default_pg
tensor_parallelize(model, pg)
# optimizer
optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)

View File

@ -3,5 +3,5 @@ torch >= 1.8.1
datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
accelerate == 0.13.2
accelerate
transformers

View File

@ -30,7 +30,7 @@ from itertools import chain
import datasets
import torch
import torch.distributed as dist
import transformers
import transformers.utils.logging as logging
from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset
@ -57,7 +57,7 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero import GeminiOptimizer
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@ -292,10 +292,10 @@ def main():
if is_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logging.set_verbosity_error()
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)
@ -391,16 +391,28 @@ def main():
else:
init_dev = get_current_device()
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
# build model
if version.parse(cai_version) >= version.parse("0.3.1"):
from contextlib import nullcontext
from colossalai.lazy import LazyInitContext
ctx = LazyInitContext(
default_device=init_dev
) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext()
else:
from colossalai.zero import ColoInitContext
ctx = ColoInitContext(device=init_dev)
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
# There is currently a bug in the pretrained opt-13b checkpoint,
# so we cannot load it until Hugging Face fixes it.
logger.info("Train a new model from scratch", ranks=[0])
with ColoInitContext(device=init_dev):
with ctx:
model = OPTForCausalLM(config)
else:
logger.info("Finetune a pre-trained model", ranks=[0])
with ColoInitContext(device=init_dev):
with ctx:
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
@ -410,9 +422,10 @@ def main():
model.gradient_checkpointing_enable()
PLACEMENT_POLICY = 'auto'
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
if version.parse(cai_version) > version.parse("0.1.10"):
if version.parse(cai_version) >= version.parse("0.3.1"):
from colossalai.zero import GeminiDDP
model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True)
elif version.parse(cai_version) > version.parse("0.1.10"):
try:
from colossalai.nn.parallel import GeminiDDP
except ImportError:
@ -536,7 +549,6 @@ def main():
]
optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@ -551,6 +563,7 @@ def main():
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps,
)
optimizer = GeminiOptimizer(optimizer, model, initial_scale=2**14)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

View File

@ -4,9 +4,9 @@ set -xue
pip install -r requirements.txt
BS=8
BS=4
MEMCAP=0
GPUNUM=2
GPUNUM=4
MODLE="facebook/opt-125m"
torchrun \

View File

@ -4,4 +4,5 @@ markers =
gpu: tests which requires a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx

View File

@ -17,6 +17,13 @@ def data_gen_fn():
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def data_gen_for_pretrain():
    """Return the base inputs extended with pretraining labels.

    ``labels`` is a clone of ``input_ids``; ``sentence_order_label`` is an
    int64 zero tensor of length ``BATCH_SIZE``.
    """
    batch = data_gen_fn()
    batch.update(
        labels=batch['input_ids'].clone(),
        sentence_order_label=torch.zeros(BATCH_SIZE, dtype=torch.int64),
    )
    return batch
output_transform_fn = lambda x: x
config = transformers.AlbertConfig(embedding_size=128,
@ -26,14 +33,14 @@ config = transformers.AlbertConfig(embedding_size=128,
intermediate_size=256)
model_zoo.register(name='transformers_albert',
model_fn=lambda: transformers.AlbertModel(config),
model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_pretraining',
model_fn=lambda: transformers.AlbertForPreTraining(config),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
data_gen_fn=data_gen_for_pretrain,
output_transform_fn=lambda x: dict(loss=x.loss),
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_masked_lm',
model_fn=lambda: transformers.AlbertForMaskedLM(config),

View File

@ -113,6 +113,7 @@ def data_gen_for_qa():
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
))
loss_fn = lambda x: x.loss
@ -126,7 +127,7 @@ config = transformers.BertConfig(hidden_size=128,
# register the BERT variants
model_zoo.register(name='transformers_bert',
model_fn=lambda: transformers.BertModel(config),
model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_bert_model,

View File

@ -57,6 +57,12 @@ def data_gen_for_sequence_classification():
return data
def date_gen_for_double_heads():
    """Return the LM inputs extended with ``mc_labels``.

    ``mc_labels`` is an int64 zero tensor with one entry per row of
    ``input_ids``.
    """
    batch = data_gen_for_lm()
    num_rows = batch['input_ids'].shape[0]
    batch['mc_labels'] = torch.zeros(num_rows, dtype=torch.int64)
    return batch
# define output transform function
output_transform_fn = lambda x: x
@ -94,8 +100,8 @@ model_zoo.register(name='transformers_gpt_lm',
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_double_heads',
model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
data_gen_fn=date_gen_for_double_heads,
output_transform_fn=lambda x: dict(loss=x.loss + x.mc_loss),
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_question_answering',

View File

@ -12,19 +12,16 @@ from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero import ColoInitContext
from tests.kit.model_zoo import model_zoo
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
try:
if init_method == 'colo':
ctx = ColoInitContext()
elif init_method == 'lazy':
if init_method == 'lazy':
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
@ -50,6 +47,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
optimizer.step()
except Exception as e:
# raise e
return repr(e)
@ -57,8 +55,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
# @parameterize('init_method', ['lazy', 'none', 'colo'])
@parameterize('subset', ['torchvision', 'transformers', 'diffusers'])
@parameterize('init_method', ['none'])
def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool = True):
"""check gemini plugin over model zoo
Args:
@ -71,29 +70,23 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
passed_models = []
failed_info = {} # (model_name, error) pair
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
# These models lead to CUDA error
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
'torchvision_convnext_base'):
continue
# These models are not compatible with gemini
if name in [
'diffusers_clip_vision_model', 'timm_resnet', 'timm_beit', 'timm_beitv2', 'timm_eca_nfnet',
'timm_efficientformer', 'timm_hrnet_w18_small', 'timm_nf_ecaresnet101', 'timm_nf_regnet_b0',
'timm_skresnet18', 'timm_wide_resnet50_2', 'timm_convit', 'timm_dm_nfnet', 'timm_swin_transformer',
'torchaudio_conformer', 'torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron',
'deepfm_interactionarch', 'deepfm_simpledeepfmnn', 'dlrm', 'dlrm_interactionarch',
'torchvision_googlenet', 'torchvision_inception_v3', 'torchvision_mobilenet_v3_small',
'torchvision_resnet18', 'torchvision_resnext50_32x4d', 'torchvision_wide_resnet50_2',
'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert',
'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining',
'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base',
'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model',
'transformers_vit', 'transformers_vit_for_masked_image_modeling',
'transformers_vit_for_image_classification', 'transformers_chatglm',
'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
'timm_convit',
'timm_dm_nfnet',
'torchvision_vit_b_16',
'transformers_t5',
'transformers_t5_for_conditional_generation',
'transformers_t5_encoder_model', # does not support apex rmsnorm
'transformers_chatglm',
'transformers_sam',
'transformers_vit'
]:
continue
@ -105,7 +98,6 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
]:
continue
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)

View File

@ -18,12 +18,45 @@ from colossalai.testing import (
)
from tests.kit.model_zoo import model_zoo
# Static-placement plugin kwargs used for the model checkpoint test.
# shard_param_frac values, in order: 0.0 (zero2), 1.0 (zero3), 0.5 (zero3-half).
MODEL_PLACEMENT_CONFIGS = [{
    'placement_policy': 'static',
    'shard_param_frac': shard_frac
} for shard_frac in (0.0, 1.0, 0.5)]
# Static-placement plugin kwargs used for the optimizer checkpoint test.
# offload_optim_frac values, in order: 0.0 (zero2), 1.0 (zero2-offload),
# 0.5 (zero2-offload-half); shard_param_frac stays at 0.0 throughout.
OPTIM_PLACEMENT_CONFIGS = [{
    'placement_policy': 'static',
    'shard_param_frac': 0.0,
    'offload_optim_frac': offload_frac
} for offload_frac in (0.0, 1.0, 0.5)]
@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('placement_config', MODEL_PLACEMENT_CONFIGS)
@parameterize('model_name', ['transformers_bert_for_sequence_classification'])
@parameterize('use_safetensors', [False, True])
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
from transformers import BertForSequenceClassification
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
bert_model = model_fn()
@ -32,7 +65,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
pretrained_path = os.path.join(tempdir, 'pretrained')
bert_model.config.save_pretrained(save_directory=pretrained_path)
plugin = GeminiPlugin(placement_policy=placement_policy)
plugin = GeminiPlugin(**placement_config)
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@ -46,19 +79,19 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32),
check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
new_bert_model.state_dict(), False)
@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('placement_config', OPTIM_PLACEMENT_CONFIGS)
@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14))
booster = Booster(plugin=plugin)
model = model_fn()
@ -87,12 +120,11 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_sha
dist.barrier()
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
new_model.unwrap().state_dict(only_rank_0=False), False)
check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
new_optimizer.unwrap().state_dict(only_rank_0=False), False)
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False),
False)
# Check the new model/optimizer can successfully run.
data = data_gen_fn()

View File

@ -60,12 +60,11 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
new_booster.load_model(new_model, model_ckpt_path, strict=True)
# Add prefix to get aligned with pytorch parameter names.
check_state_dict_equal(
model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
new_model.state_dict(), False)
check_state_dict_equal(model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
new_model.state_dict(), False)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
# Check the new model/optimizer can successfully run.
data = data_gen_fn()
@ -124,13 +123,12 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
new_booster.load_model(new_model, model_ckpt_path, strict=True)
# Add prefix to get aligned with pytorch parameter names.
check_state_dict_equal(
new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
model.state_dict(), False)
check_state_dict_equal(new_model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
model.state_dict(), False)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
old_state_dict = optimizer.state_dict()
new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
new_state_dict = new_optimizer.state_dict(only_rank_0=False)
# Comparison of param_groups needs special care here,
# since not all hyperparameters in Adam are used by HybridAdam
@ -138,7 +136,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
for k in hyperparameters_to_examine:
assert k in old_group and k in new_group, \
f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
assert old_group[k] == new_group[k]
check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)

View File

@ -1,104 +0,0 @@
import os
from pathlib import Path
import pytest
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10
import colossalai
from colossalai.amp import AMP_TYPE
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.trainer import Trainer, hooks
from colossalai.utils import get_dataloader
disable_existing_loggers()
# Batch size handed to get_dataloader in run_trainer.
BATCH_SIZE = 4
# Used as `epochs` for trainer.fit and as `total_steps` for the LR scheduler.
NUM_EPOCHS = 10
# Used as `warmup_steps` for CosineAnnealingWarmupLR.
WARMUP_EPOCHS = 5
# Launch configuration passed to colossalai.launch. NUM_MICRO_BATCHES is read
# back through gpc.config when the pipeline schedule is created in run_trainer.
CONFIG = dict(NUM_MICRO_BATCHES=2,
parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')),
fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2)
def run_trainer(rank, world_size, port):
    """Run a short pipeline-parallel ViT training job on CIFAR10 in one worker.

    Launches ColossalAI with ``CONFIG``, partitions the model across pipeline
    stages, builds the data loader, loss, optimizer and LR scheduler, and then
    calls ``Trainer.fit`` limited to two steps. Returns early, after logging
    two warnings, when ``titans`` cannot be imported.

    Args:
        rank: Rank of this process.
        world_size: Total number of processes.
        port: Port used for the localhost rendezvous.

    Raises:
        KeyError: If the ``DATA`` environment variable is not set.
    """
    disable_existing_loggers()
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    disable_existing_loggers()

    logger = get_dist_logger()
    pipelinable = PipelinableContext()

    # titans is optional; without it there is no model to build.
    try:
        from titans.model.vit import vit_tiny_patch4_32
    except ImportError:
        logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
        logger.warning('please install titan from https://github.com/hpcaitech/Titans')
        return

    # Build the model under the pipelinable context, then partition it.
    with pipelinable:
        model = vit_tiny_patch4_32()
    pipelinable.to_layer_list()
    pipelinable.policy = "uniform"
    local_pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    model = pipelinable.partition(1, gpc.pipeline_parallel_size, local_pipeline_rank)

    # Data: the dataset root comes from the DATA environment variable.
    data_root = Path(os.environ['DATA'])
    augmentations = [
        transforms.RandomCrop(32, padding=4, pad_if_needed=True),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    transform_train = transforms.Compose(augmentations)
    train_dataset = CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True)

    # Loss, optimizer and learning-rate schedule.
    criterion = CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0)
    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=NUM_EPOCHS,
                                           warmup_steps=WARMUP_EPOCHS)

    engine, train_dataloader, *_ = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader,
    )

    # Replace the engine's schedule with the v2 pipeline schedule.
    engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES)

    logger = get_dist_logger()
    trainer = Trainer(engine=engine, logger=logger)
    hook_list = [hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)]
    trainer.fit(
        train_dataloader=train_dataloader,
        max_steps=2,
        epochs=NUM_EPOCHS,
        hooks=hook_list,
        display_progress=True,
    )
# Distributed pytest entry point: spawns run_trainer on 2 processes.
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_hybrid_parallel():
spawn(run_trainer, 2)
# NOTE(review): indentation is lost in this view, so it is unclear whether the
# call below belongs to the test body or runs at module level — confirm.
disable_existing_loggers()
if __name__ == '__main__':
test_hybrid_parallel()

View File

@ -1,92 +0,0 @@
import os
import random
from typing import Callable, Type
import numpy as np
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.nn.parallel import ColoDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
def set_seed(seed):
    """Seed every random source this test relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports the
    seed as ``PYTHONHASHSEED`` and turns on deterministic cuDNN.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
def init_ddp(module: torch.nn.Module) -> ColoDDP:
    """Wrap ``module`` in ``ColoDDP`` with a default-constructed ``ProcessGroup``."""
    return ColoDDP(module, process_group=ProcessGroup())
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
    """Wrap ``module`` in ``ZeroDDP`` backed by a CUDA-placement Gemini manager.

    The chunk configuration is taken from the first value returned by
    ``search_chunk_configuration(module, 4, 1024)``.
    """
    chunk_config, *_ = search_chunk_configuration(module, 4, 1024)
    gemini_manager = GeminiManager('cuda', ChunkManager(chunk_config))
    return ZeroDDP(module, gemini_manager)
class Net(torch.nn.Module):
    """Two stacked bias-free linear layers mapping 3 input features to 1 output."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(3, 3, bias=False)
        self.fc2 = torch.nn.Linear(3, 1, bias=False)

    def forward(self, x):
        """Apply ``fc1`` and then ``fc2`` to ``x``."""
        hidden = self.fc1(x)
        return self.fc2(hidden)
def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
    """Check that an ignored parameter is left out of gradient synchronization.

    ``fc2.weight`` is registered as ignored on ``ddp_cls`` before wrapping.
    After one forward/backward pass, the gradient tensors gathered for
    ``fc1.weight`` must match on ranks 0 and 1, while those gathered for
    ``fc2.weight`` must differ.

    Args:
        ddp_cls: DDP wrapper class on which ``set_params_to_ignore`` is called.
        init_ddp_func: Factory that wraps a module in that DDP class.
    """

    def gather_across_ranks(tensor, like):
        # One receive buffer per rank, shaped like the reference parameter.
        buffers = [torch.empty_like(like) for _ in range(dist.get_world_size())]
        dist.all_gather(buffers, tensor)
        return buffers

    with ColoInitContext(device=get_current_device()):
        model = Net().cuda()
    w1 = model.fc1.weight
    w2 = model.fc2.weight
    ddp_cls.set_params_to_ignore([w2])
    model = init_ddp_func(model)

    x = torch.rand(2, 3, device=get_current_device())
    loss = torch.sum(model(x))
    model.backward(loss)

    # For ZeroDDP the parameter tensor itself is gathered instead of ``.grad``.
    # NOTE(review): presumably ZeroDDP keeps the reduced gradient in the
    # parameter's storage — confirm.
    w1s_grad = w1 if ddp_cls is ZeroDDP else w1.grad

    w1_grads = gather_across_ranks(w1s_grad, w1)
    assert torch.equal(w1_grads[0], w1_grads[1])

    w2_grads = gather_across_ranks(w2.grad, w2)
    assert not torch.equal(w2_grads[0], w2_grads[1])
def run_dist(rank, world_size, port):
    """Worker entry: launch ColossalAI, seed by rank, and test both DDP wrappers.

    Each rank seeds its generators with its own rank before the checks run.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    set_seed(dist.get_rank())
    for ddp_cls, builder in ((ColoDDP, init_ddp), (ZeroDDP, init_ddpv2)):
        run_fwd_bwd(ddp_cls, builder)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_ddp_ignore_params(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
if __name__ == '__main__':
test_ddp_ignore_params(2)

View File

@ -1,67 +0,0 @@
from collections import OrderedDict
import pytest
import torch
import colossalai
from colossalai.nn.parallel import ColoDDP
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
    """Assert that two state dicts hold the same keys, in order, with equal tensors.

    Tensors on different devices are compared after moving the second one to
    the device of the first.

    Args:
        state_dict: Reference state dict.
        other_state_dict: State dict compared against the reference.

    Raises:
        AssertionError: If the dicts differ in size, key order, or tensor values.
    """
    # zip() stops at the shorter input, so without this size check any extra
    # trailing entries in either dict would go unnoticed.
    assert len(state_dict) == len(other_state_dict), \
        f"state dict sizes differ: {len(state_dict)} vs {len(other_state_dict)}"
    for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
        assert k1 == k2, f"key mismatch: {k1} vs {k2}"
        # .to() returns the tensor unchanged when it is already on that device.
        t2_on_t1_device = t2.to(t1.device)
        assert torch.equal(t1, t2_on_t1_device), "\t{}\n\t{}".format(t1, t2_on_t1_device)
def init_ddp(module: torch.nn.Module) -> ColoDDP:
    """Return ``module`` wrapped in ``ColoDDP`` over a default ``ProcessGroup``."""
    return ColoDDP(module, process_group=ProcessGroup())
def run_ddp_state_dict():
    """Round-trip a plain torch state dict through a ``ColoDDP``-wrapped GPT-2.

    Loads the torch model's state dict into the wrapped model, checks that every
    ``ColoParameter`` has a process group both before and after loading, and
    compares the wrapped model's state dict with the torch one.
    """
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, _train_dataloader, _test_dataloader, _optimizer_class, _criterion = get_components_func()

    torch_model = model_builder().cuda()
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = init_ddp(model)
    torch_state_dict = torch_model.state_dict()

    def assert_params_have_process_group():
        for param in model.parameters():
            if isinstance(param, ColoParameter):
                assert param.get_process_group() is not None

    assert_params_have_process_group()
    model.load_state_dict(torch_state_dict)
    assert_params_have_process_group()

    state_dict = model.state_dict()
    check_state_dict_equal(torch_state_dict, state_dict)
def run_dist(rank, world_size, port):
    """Worker entry: launch ColossalAI on localhost/NCCL and run the state-dict check."""
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    run_ddp_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_state_dict(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
if __name__ == '__main__':
test_state_dict(2)

View File

@ -1,47 +0,0 @@
from functools import partial
import pytest
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
import colossalai
from colossalai.nn.parallel.reducer import Reducer
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
REDUCE_CNT = 0
def check_eq(grad, grad_clone):
    """Reducer callback asserting that two gradient tensors are numerically close.

    Prints the current rank and the running callback count, then increments the
    module-level ``REDUCE_CNT``.

    Args:
        grad: Reference tensor (bound ahead of time via ``functools.partial``
            in ``run_reducer``).
        grad_clone: Tensor supplied when the callback is invoked.
    """
    global REDUCE_CNT
    rank = dist.get_rank()
    print(f'Rank{rank} check {REDUCE_CNT}')
    REDUCE_CNT = REDUCE_CNT + 1
    assert torch.allclose(grad, grad_clone)
def run_reducer():
    """Compare ``Reducer.all_reduce_async`` against plain ``dist.all_reduce``.

    Ten random tensors of shape ``(64, 1)`` through ``(64, 10)`` are all-reduced
    directly to give the expected values. Detached clones go through the
    bucketed reducer, whose callback (``check_eq``) compares each result with
    its expected tensor.
    """
    expected = [torch.rand(64, width, device=get_current_device()) for width in range(1, 11)]
    pending = [tensor.clone().detach() for tensor in expected]
    for tensor in expected:
        dist.all_reduce(tensor)

    reducer = Reducer(bucket_size_mb=1)
    for reference, clone in zip(expected, pending):
        on_reduced = partial(check_eq, reference)
        reducer.all_reduce_async(clone, _get_default_group(), on_reduced)
    reducer.flush()
def run_dist(rank, world_size, port):
    """Worker entry: launch ColossalAI on localhost/NCCL and run the reducer check."""
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    run_reducer()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_reducer(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
if __name__ == '__main__':
test_reducer(2)

View File

@ -1,73 +0,0 @@
import pytest
import torch
import torch.nn as nn
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (also used in GPT-2).

    Acts like a linear layer whose weight is stored transposed, i.e. with shape
    ``(nx, nf)``.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        # Weight drawn from N(0, 0.02^2); bias starts at ones.
        self.weight = nn.Parameter(nn.init.normal_(torch.empty(nx, nf), std=0.02))
        self.bias = nn.Parameter(torch.ones(nf))

    def forward(self, x):
        """Return ``x @ weight + bias`` over the last dimension of ``x``."""
        leading_shape = x.size()[:-1]
        flat_input = x.view(-1, x.size(-1))
        flat_output = torch.addmm(self.bias, flat_input, self.weight)
        return flat_output.view(leading_shape + (self.nf,))
def run_with_spec(spec_init_func, split_bias):
    """Compare ``torch.addmm`` on ColoTensors against a reference ``Conv1D``.

    Args:
        spec_init_func: Callable ``(tensor, pg)`` that applies a sharding spec
            to the given ColoTensor.
        split_bias: When true, the bias is sharded with ``spec_init_func`` too.

    Raises:
        AssertionError: If the forward outputs or the gradient shards differ.
    """
    model = Conv1D(4, 16).cuda()
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
    spec_init_func(weight, pg)
    if split_bias:
        spec_init_func(bias, pg)

    x = torch.rand(2, 16).cuda()
    out = model(x)
    colo_out = torch.addmm(bias, x, weight)
    colo_out = colo_out.to_replicate()
    assert tensor_equal(out, colo_out)

    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)

    # The comparison results were previously discarded, so a gradient mismatch
    # could never fail the test; assert them explicitly.
    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())
def run_dist(rank, world_size, port):
    """Worker entry: launch ColossalAI and run the addmm check for both layouts.

    Row sharding is exercised without a split bias, column sharding with one.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    for spec_init_func, split_bias in ((split_param_row_tp1d, False), (split_param_col_tp1d, True)):
        run_with_spec(spec_init_func=spec_init_func, split_bias=split_bias)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_addmm_1d(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
if __name__ == '__main__':
test_addmm_1d(4)

View File

@ -1,43 +0,0 @@
import pytest
import torch
from torch.nn import functional as F
import colossalai
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func):
    """Compare ``F.embedding_bag`` on a sharded ColoParameter with ``nn.EmbeddingBag``.

    Args:
        spec_init_func: Callable ``(tensor, pg)`` that applies a sharding spec
            to the weight.
    """
    tp_degree = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=tp_degree)

    model = torch.nn.EmbeddingBag(10, 4).cuda()
    weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg))
    spec_init_func(weight, pg)

    inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
    offsets = torch.tensor([0, 4]).cuda()

    # Forward: reference module versus functional call on the ColoParameter.
    out = model(inputs, offsets=offsets)
    colo_out = F.embedding_bag(inputs, weight, offsets=offsets)
    assert tensor_equal(out, colo_out)

    # Backward with a shared upstream gradient, then compare the local shard.
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    shards_match = tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    assert shards_match
def run_dist(rank, world_size, port):
    """Worker entry: launch with a 1D tensor-parallel config and run the column-sharded check."""
    config = {'parallel': {'tensor': {'mode': "1d", 'size': world_size}}}
    colossalai.launch(
        config=config,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    run_with_spec(split_param_col_tp1d)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_bag_1d(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
if __name__ == '__main__':
test_embedding_bag_1d(4)

View File

@ -1,44 +0,0 @@
import pytest
import torch
from torch.nn import functional as F
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func, pg: ProcessGroup):
    """Compare ``F.embedding`` on a sharded ColoTensor weight with ``nn.Embedding``.

    Args:
        spec_init_func: Callable ``(tensor, pg)`` that applies a sharding spec
            to the weight.
        pg: Process group used for the ColoTensor spec and the shard comparison.
    """
    model = torch.nn.Embedding(12, 32).cuda()
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
    spec_init_func(weight, pg)

    x = torch.tensor((0, 3, 6, 9)).cuda()
    out = model(x)
    colo_out = F.embedding(x, weight)
    assert tensor_equal(out, colo_out)

    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)

    # Gradients are compared shard-wise inside the tensor-parallel group.
    shards_match = tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    assert shards_match
def run_dist(rank, world_size, port):
    """Worker entry: launch ColossalAI and run the embedding check for row and column sharding.

    A single process group with ``tp_degree=world_size`` is shared by both runs.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    pg = ProcessGroup(tp_degree=world_size)
    for spec_init_func in (split_param_row_tp1d, split_param_col_tp1d):
        run_with_spec(spec_init_func, pg)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_1d(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
if __name__ == '__main__':
test_embedding_1d(4)

View File

@ -1,48 +0,0 @@
import pytest
import torch
import torch.nn.functional as F
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func, split_bias):
    """Compare ``F.linear`` on ColoTensor weight/bias with a reference ``nn.Linear``.

    Args:
        spec_init_func: Callable ``(tensor, pg)`` that applies a sharding spec.
        split_bias: When true, the bias is sharded with ``spec_init_func`` too.
    """
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    model = torch.nn.Linear(4, 8).cuda()

    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
    spec_init_func(weight, pg)
    if split_bias:
        spec_init_func(bias, pg)

    x = torch.rand(2, 4).cuda()
    out = model(x)
    # Replicate the ColoTensor result before comparing it with the reference.
    colo_out = F.linear(x, weight, bias).to_replicate()
    assert tensor_equal(out, colo_out)

    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)

    tp_rank, tp_size = pg.tp_local_rank(), pg.tp_world_size()
    assert tensor_shard_equal(model.weight.grad, weight.grad, tp_rank, tp_size)
    tp_rank, tp_size = pg.tp_local_rank(), pg.tp_world_size()
    assert tensor_shard_equal(model.bias.grad, bias.grad, tp_rank, tp_size)
def run_dist(rank, world_size, port):
    """Worker entry: launch with a 1D tensor-parallel config and run the linear check.

    Column sharding is exercised without a split bias, row sharding with one.
    """
    config = {'parallel': {'tensor': {'mode': "1d", 'size': world_size}}}
    colossalai.launch(
        config=config,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    for spec_init_func, split_bias in ((split_param_col_tp1d, False), (split_param_row_tp1d, True)):
        run_with_spec(spec_init_func=spec_init_func, split_bias=split_bias)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_linear_1d(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
if __name__ == '__main__':
test_linear_1d(4)

View File

@ -1,48 +0,0 @@
import pytest
import torch
import torch.nn.functional as F
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
def check_cross_entropy():
    """Compare ``F.cross_entropy`` on a sharded ColoTensor with the plain torch result.

    Two leaf tensors with identical values are created. One is used directly,
    the other is wrapped in a ColoTensor, sharded along its last dimension and
    marked with a TP1D compute spec. Both the loss and the input gradients
    must agree.
    """
    device = get_current_device()
    input_t = torch.randn(4, 4, device=device, requires_grad=True)
    input_ct = torch.randn(4, 4, device=device, requires_grad=True)
    # Copy without recording autograd history so both remain leaf tensors.
    with torch.no_grad():
        input_ct.copy_(input_t)
    target = torch.randint(4, (4,), dtype=torch.int64, device=device)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    input_shard = input_t_colo.redistribute(shard_spec)
    input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))

    output = F.cross_entropy(input_t, target)
    output_colo = F.cross_entropy(input_shard, target)
    assert torch.allclose(output_colo, output)

    output.backward()
    output_colo.backward()
    assert torch.allclose(input_t.grad, input_ct.grad)
def run_dist(rank, world_size, port):
    """Worker entry: launch ColossalAI on localhost/NCCL and run the cross-entropy check."""
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    check_cross_entropy()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_loss_func(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
if __name__ == '__main__':
test_loss_func(1)

View File

@ -1,87 +0,0 @@
import pytest
import torch
import torch.nn.functional as F
from torch.nn import Parameter
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
def _run_layer_norm():
    """Compare ``F.layer_norm`` on ColoTensors with a reference ``nn.LayerNorm``.

    Checks that the forward outputs match and that the weight gradients agree
    after back-propagating the mean of each output.
    """
    # NOTE(review): the second positional argument of torch.nn.LayerNorm is
    # ``eps``, so this builds a layer with eps=3 — confirm that is intended.
    ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
    input_t = torch.randn(3, 2, device=get_current_device())

    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    input_t_colo = ColoTensor.from_torch_tensor(input_t.clone().detach(), ColoTensorSpec(pg))

    # ColoTensor copies of the layer's affine parameters.
    weight = ColoTensor(Parameter(ln_op.weight.detach()), ColoTensorSpec(pg))
    bias = ColoTensor(Parameter(ln_op.bias.detach()), ColoTensorSpec(pg))

    output = ln_op(input_t)
    output_colo = F.layer_norm(input_t_colo, ln_op.normalized_shape, weight, bias, ln_op.eps)
    assert torch.allclose(output_colo, output)

    for result in (output, output_colo):
        torch.mean(result).backward()
    assert torch.allclose(ln_op.weight.grad, weight.grad)
def check_spec_eq(tensor, other):
    """Assert that two ColoTensors carry equal ``dist_spec`` attributes.

    Every attribute of ``tensor.dist_spec`` whose name does not start with a
    double underscore must exist on ``other.dist_spec`` and compare equal.
    """
    assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
    for k in dir(tensor.dist_spec):
        if k.startswith('__'):
            continue
        assert hasattr(other.dist_spec, k), f"{k}"
        assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k)
def check_element_wise_ops():
    """Check that ``.cuda()``, ``torch.abs`` and ``F.sigmoid`` keep the dist spec and values.

    The operations run on a ColoTensor sharded along dimension 0 and are
    compared with the same operations on the underlying torch tensor.
    """
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    t = torch.rand(2, 2)
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    x = ColoTensor(t, spec=ColoTensorSpec(pg, shard_spec))

    check_spec_eq(x, x.cuda())
    assert torch.equal(x.cuda(), t.cuda())

    for op in (torch.abs, F.sigmoid):
        check_spec_eq(x, op(x))
        assert torch.equal(op(x), op(t))
def run_dist(rank, world_size, port):
    """Worker entry: launch ColossalAI, then run the element-wise and layer-norm checks."""
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    for check in (check_element_wise_ops, _run_layer_norm):
        check()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_element_wise_ops(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
def run_dist2(rank, world_size, port):
    """Worker entry: launch ColossalAI on localhost/NCCL and run only the layer-norm check."""
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    _run_layer_norm()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1])
@rerun_if_address_is_in_use()
def test_ln(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist2``."""
    spawn(run_dist2, world_size)
def check_all():
    """Script entry helper: run the element-wise test with two processes."""
    test_element_wise_ops(2)
if __name__ == '__main__':
check_all()

View File

@ -1,97 +0,0 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.tensor.distspec import DistPlacementPattern
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d
def exam_view_core(pg):
    """Check ``ColoTensor.view`` for replicated, column-sharded and row-sharded tensors.

    The rank-dependent slices below take the first half of the sharded
    dimension on rank 0 and the second half otherwise, which matches a
    two-process layout.

    Args:
        pg: Process group used for the ColoTensor spec and the split helpers.
    """
    on_rank_zero = dist.get_rank() == 0

    # Replicated ColoTensor: the view matches torch and stays replicated.
    x = torch.randn(4, 4).cuda()
    x_colo = ColoTensor(x, ColoTensorSpec(pg))
    y = x.view(2, -1, 2)
    y_colo = x_colo.view(2, -1, 2)
    assert torch.all(y == y_colo)
    assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE

    # Column-sliced ColoTensor whose view keeps the sharding (the "perfect" case).
    split_param_col_tp1d(x_colo, pg)
    col_shape = torch.Size((2, 1, 2, -1))
    z = x.view(col_shape)
    z_colo = x_colo.view(col_shape)
    z = z[:, :, :, 0:2] if on_rank_zero else z[:, :, :, 2:]
    assert torch.all(z == z_colo)
    assert z_colo.dist_spec == x_colo.dist_spec

    # Row-sliced ColoTensor whose view keeps the sharding (the "perfect" case).
    split_param_row_tp1d(x_colo, pg)
    row_shape = torch.Size((-1, 2, 2))
    z = x.view(row_shape)
    z_colo = x_colo.view(row_shape)
    z = z[0:2, :, :] if on_rank_zero else z[2:, :, :]
    assert torch.all(z == z_colo)
    assert z_colo.dist_spec == x_colo.dist_spec

    # Row-sliced ColoTensor, the "normal" case: the view equals the full torch view.
    z = x.view(-1, 2, 2, 2)
    z_colo = x_colo.view(-1, 2, 2, 2)
    assert torch.all(z == z_colo)
    # NOTE(review): this re-checks y_colo from the first section; z_colo may
    # have been intended — confirm.
    assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
def exam_view_autograd(pg):
    """Check that gradients flow through ``view`` on a sharded ColoTensor.

    A torch tensor and a ColoTensor with identical values are each viewed as
    ``(2, 2, -1)`` and back-propagated with the same upstream gradient. The
    resulting input gradients must be equal.

    Args:
        pg: Process group used for the ColoTensor spec and the shard spec.
    """
    device = get_current_device()
    x = torch.randn(8, 2, device=device, requires_grad=True)
    y = torch.randn(8, 2, device=device, requires_grad=True)
    # Copy values without recording the copy in autograd.
    with torch.no_grad():
        y.copy_(x)

    y = ColoTensor(y, ColoTensorSpec(pg))
    y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))

    xx = x.view(2, 2, -1)
    yy = y_slice.view(2, 2, -1).to_replicate()

    grad = torch.randn(2, 2, 4, device=device)
    for viewed in (xx, yy):
        viewed.backward(grad)
    assert torch.all(x.grad == y.grad)
def exam_view_errors(pg):
    """Call ``view`` on a row-sharded ColoTensor with a series of unusual shape arguments.

    No exception is caught or asserted here.
    NOTE(review): presumably each call is expected to raise; the only visible
    call site is commented out in ``run_dist`` — confirm.

    Args:
        pg: Process group used for the ColoTensor spec and the row split.
    """
    x = ColoTensor(torch.randn(8, 2, device=get_current_device()), ColoTensorSpec(pg))
    split_param_row_tp1d(x, pg)

    shape_arguments = (
        ('a', 'b', 'c'),
        (8, -1),
        ([-2, -2, -2],),
        ((-1, -1, -1),),
    )
    for args in shape_arguments:
        x.view(*args)
def run_dist(rank, world_size, port):
    """Worker entry: launch ColossalAI and run the core and autograd view checks.

    ``exam_view_errors`` is deliberately not invoked here.
    """
    colossalai.launch(
        config=dict(),
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    for exam in (exam_view_core, exam_view_autograd):
        exam(pg)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_view(world_size):
    """Spawn ``world_size`` workers that each execute ``run_dist``."""
    spawn(run_dist, world_size)
if __name__ == '__main__':
test_view(2)

View File

@ -1,3 +1,4 @@
import pytest
import torch
from colossalai.pipeline.pipelinable import PipelinableContext
@ -48,6 +49,7 @@ def run_pipelinable(rank, world_size, port):
assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count
@pytest.mark.skip(reason="this is useless")
@rerun_if_address_is_in_use()
def test_pipelinable():
spawn(run_pipelinable, 1)

View File

@ -127,6 +127,10 @@ def check_gpt2(rank, world_size, port):
run_gpt2_test()
# TODO(ver217): fix this
@pytest.mark.skip("this will stuck in CI")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

View File

@ -1,153 +0,0 @@
import pytest
import torch
from numpy import allclose
import colossalai
from colossalai.core import global_context as gpc
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec
from colossalai.testing import rerun_if_address_is_in_use, spawn
def _run_tensor_indexing():
    """Check that slicing a ColoTensor yields the same values as slicing the torch tensor."""
    torch_t = torch.randn(2, 3)
    colo_t = ColoTensor(torch_t, ColoTensorSpec(ProcessGroup()))
    expected_column = torch_t[:, 1]
    actual_column = colo_t[:, 1]
    assert allclose(expected_column, actual_column)
def _run_wrapped_tensor_func():
    """Check attribute access and method results on a ColoTensor wrapper.

    Covers a plain attribute, a method returning one tensor, a method returning
    a non-tensor, and a method returning several tensors.
    """
    pg = ProcessGroup()
    t_ref = torch.randn(4, 5)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))

    # Plain (non-callable) attribute.
    assert t.is_cuda == t_ref.is_cuda

    # Method returning a single tensor: the result is wrapped again.
    t_abs = t.abs()
    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())

    # Method returning a non-tensor value.
    assert t.dim() == t_ref.dim()

    # Method returning several tensors: every piece is wrapped.
    assert isinstance(t, ColoTensor)
    t_split1, t_split2 = t.split(2)
    both_wrapped = isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)
    assert both_wrapped, f"{type(t_split1)} {type(t_split2)}"
def _run_operand(world_size):
    """Check ``+`` on a ColoTensor and that ``torch.zeros_like`` keeps the sharding.

    Args:
        world_size: Tensor-parallel degree and shard count for the second part.
    """
    t_ref = torch.randn(4, 5)

    # Addition returns a ColoTensor with the same values as plain torch.
    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(ProcessGroup()))
    t_ref_res = t_ref + t_ref
    t_res = t + t
    assert isinstance(t_res, ColoTensor)
    assert torch.allclose(t_ref_res, t_res)

    # zeros_like on a tensor sharded along dim 0 yields a sharded ColoTensor.
    pg = ProcessGroup(tp_degree=world_size)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
    t.set_dist_spec(ShardSpec([0], [world_size]))
    t_new = torch.zeros_like(t)
    assert isinstance(t_new, ColoTensor)
    assert t_new.is_sharded()
# Test distributed initialization of a ColoTensor
def _run_view(world_size):
    """Check global-size queries and flattening of a ColoTensor sharded along dim 0.

    Each rank contributes a local ``(4, 5)`` tensor, so the global first
    dimension is expected to be ``4 * world_size``.

    Args:
        world_size: Number of ranks, also used as the tensor-parallel degree.
    """
    t_ref = torch.randn(4, 5)
    rank = gpc.get_global_rank()
    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
    shard_spec = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
    t = ColoTensor.from_torch_tensor(t_ref, ColoTensorSpec(pg, dist_attr=shard_spec))

    global_rows = 4 * world_size
    assert t.size_global()[0] == global_rows
    assert t.size_global(1) == 5
    assert t.size_global() == torch.Size([global_rows, 5])

    t = t.view(4 * 5 * world_size)
    assert t.shape == torch.Size([4 * 5 * world_size])
def _run_tensor_shard_init(world_size):
    """Create a ColoTensor as a dim-0 shard, replicate it, and check the resulting shape.

    Args:
        world_size: Tensor-parallel degree. The replicated tensor is expected
            to have ``4 * world_size`` rows.
    """
    t_ref = torch.randn(4, 5)
    pg = ProcessGroup(tp_degree=world_size)
    shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg, dist_attr=shard_attr))
    t.set_dist_spec(ReplicaSpec())
    expected_shape = torch.Size((4 * world_size, 5))
    assert t.shape == expected_shape, f"{t.shape} vs ({4 * world_size, 5})"
def _run_tensor_replicated_init(world_size):
    """Create a ColoTensor with the default spec and check that its shape is unchanged.

    Args:
        world_size: Scales the first dimension of the test tensor.
    """
    rows = 4 * world_size
    t_ref = torch.randn(rows, 5)
    spec = ColoTensorSpec(ProcessGroup())
    t = ColoTensor.from_torch_tensor(t_ref.clone(), spec)
    assert t.shape == torch.Size((rows, 5)), f"{t.shape}"
def _run_process_group(world_size):
    """Check that two default-constructed ``ProcessGroup`` objects compare equal.

    Args:
        world_size: Unused. Kept so the signature matches the sibling checks.
    """
    first, second = ProcessGroup(), ProcessGroup()
    assert first == second
def _run_redistributed(world_size):
    """Redistribute a ColoTensor across process groups with different TP/DP layouts.

    Does nothing unless ``world_size`` is exactly 4.

    Args:
        world_size: Number of ranks in the current launch.
    """
    if world_size != 4:
        return

    pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
    pg2 = ProcessGroup(tp_degree=4, dp_degree=1)

    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), ColoTensorSpec(pg1))

    # Shard along dim 0 inside the original group.
    t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
    assert t1.is_sharded()

    # Re-shard along the last dim while moving to the second group.
    t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
    assert t1.is_sharded()

    # Replicate in a third group.
    pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
    t1 = t1.redistribute(ReplicaSpec(), pg3)
    assert t1.is_replicate()
def _run_set_tensor_spec(world_size):
    """Check that ``set_dist_spec`` turns a replicated tensor into a last-dim shard; only runs with 4 processes."""
    if world_size != 4:
        return

    pg = ProcessGroup(tp_degree=2, dp_degree=2)
    tensor = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), ColoTensorSpec(pg))
    last_dim_shard = ShardSpec([-1], [pg.tp_world_size()])

    assert tensor.is_replicate()
    tensor.set_dist_spec(last_dim_shard)
    assert tensor.is_shard_1dcol()
def run_dist_tests(rank, world_size, port):
    """Per-process entry point: launch colossalai, then run every tensor check in order."""
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    for check in (_run_tensor_shard_init, _run_tensor_replicated_init, _run_view, _run_process_group):
        check(world_size)

    _run_tensor_indexing()
    _run_operand(world_size)
    _run_wrapped_tensor_func()

    for check in (_run_redistributed, _run_set_tensor_spec):
        check(world_size)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_dist_cases(world_size):
    """Spawn ``world_size`` processes, each running ``run_dist_tests``."""
    spawn(run_dist_tests, world_size)
# Manual entry point: run the distributed tensor checks with 4 processes.
if __name__ == '__main__':
    test_dist_cases(4)

View File

@ -1,148 +0,0 @@
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import (
debug_print,
set_seed,
split_param_col_tp1d,
split_param_row_tp1d,
tensor_equal,
tensor_shard_equal,
)
def init_1d_row_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter; shard those named like a weight (and not 'ln') along dim 0."""
    dist_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        if 'ln' in name or 'weight' not in name:
            continue
        param.set_tensor_spec(dist_spec, compute_spec)
def init_1d_col_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter; shard non-'ln' weights and biases along the last dim."""
    dist_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        is_weight_or_bias = 'weight' in name or 'bias' in name
        if is_weight_or_bias and 'ln' not in name:
            param.set_tensor_spec(dist_spec, compute_spec)
def init_megatron_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and split it by the name of its owning module.

    Modules are matched by substring of their qualified name, first match wins:
    'mlp.c_fc' -> column split with output replication disabled, 'mlp.c_proj' ->
    row split of the weight, 'wte'/'wpe' -> column split, 'c_attn'/'c_proj' ->
    column split. Parameters of other modules only get the process group.

    Raises:
        RuntimeError: if an 'mlp.c_fc' module owns a parameter that is neither weight nor bias.
    """
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            param.set_process_group(pg)
            is_weight = 'weight' in param_name

            if 'mlp.c_fc' in module_name:
                if not (is_weight or 'bias' in param_name):
                    raise RuntimeError
                split_param_col_tp1d(param, pg)
                param.compute_spec.set_output_replicate(False)
            elif 'mlp.c_proj' in module_name:
                if is_weight:
                    split_param_row_tp1d(param, pg)
                else:
                    # The only other parameter expected here is the bias, which is left unsplit.
                    assert 'bias' in param_name
            elif 'wte' in module_name or 'wpe' in module_name:
                assert 'weight' in param_name
                split_param_col_tp1d(param, pg)
            elif 'c_attn' in module_name or 'c_proj' in module_name:
                split_param_col_tp1d(param, pg)
def check_param_equal(model, torch_model, pg: ProcessGroup):
    """Assert each parameter of ``model`` matches the corresponding shard of ``torch_model``'s parameter."""
    for colo_param, ref_param in zip(model.parameters(), torch_model.parameters()):
        # Sanity-check the process group before using its TP rank/size for the comparison.
        assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
        assert pg.tp_world_size() is not None
        matches = tensor_shard_equal(ref_param, colo_param, pg.tp_local_rank(), pg.tp_world_size())
        assert matches
def check_grad_equal(model, torch_model, pg: ProcessGroup):
    """Assert each gradient of ``model`` matches the corresponding shard of ``torch_model``'s gradient."""
    for colo_param, ref_param in zip(model.parameters(), torch_model.parameters()):
        matches = tensor_shard_equal(ref_param.grad, colo_param.grad, pg.tp_local_rank(), pg.tp_world_size())
        assert matches
def run_gpt(init_spec_func, use_ddp):
    """Compare a spec-initialized GPT-2 against a plain torch copy on logits and gradients.

    Args:
        init_spec_func: callable ``(model, pg)`` that assigns process group and specs to the parameters.
        use_ddp: wrap both models in their data-parallel wrappers (DDP / ColoDDP) when True.
    """
    world_size = torch.distributed.get_world_size()

    # Hybrid TP + DP group: dp_degree is 2 only for DDP runs with at least two processes.
    dp_degree = 2 if (use_ddp and world_size >= 2) else 1
    pg = ProcessGroup(dp_degree=dp_degree)

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    if use_ddp:
        torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
        model = ColoDDP(model, process_group=pg)

    # Copy the parameter values so both models start identical.
    for ref_param, colo_param in zip(torch_model.parameters(), model.parameters()):
        ref_param.data.copy_(colo_param)

    init_spec_func(model, pg)
    check_param_equal(model, torch_model, pg)

    # Eval mode switches dropout off, so the two forwards are comparable.
    model.eval()
    torch_model.eval()
    set_seed(pg.dp_local_rank())
    torch.distributed.barrier()

    for step, (input_ids, label) in enumerate(train_dataloader):
        colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        logits = model(colo_input)
        torch_logits = torch_model(input_ids)
        assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"

        loss = criterion(logits, input_ids)
        torch_loss = criterion(torch_logits, input_ids)
        # The ColoDDP wrapper drives backward itself; the bare model uses autograd directly.
        if use_ddp:
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model, pg)

        # Two iterations are enough.
        if step > 0:
            break

    set_seed(313)
def run_dist(rank, world_size, port, use_ddp):
    """Per-process entry point for the GPT-2 test; DDP runs are skipped for a single process."""
    single_process_ddp = use_ddp and world_size == 1
    if single_process_ddp:
        return

    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # The two 1D specs below are left disabled to keep the test fast.
    # run_gpt(init_1d_row_spec, use_ddp)
    # run_gpt(init_1d_col_spec, use_ddp)
    run_gpt(init_megatron_spec, use_ddp)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
    """Spawn ``world_size`` processes running the GPT-2 comparison, with or without DDP."""
    spawn(run_dist, world_size, use_ddp=use_ddp)
# Manual entry point: 4 processes, no DDP.
if __name__ == '__main__':
    test_gpt(4, use_ddp=False)

View File

@ -1,334 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import (
check_equal,
set_seed,
split_param_col_tp1d,
split_param_row_tp1d,
tensor_shard_equal,
)
def run_1d_hybrid_tp(model_name):
    """Train a 1D tensor-parallel model for a few SGD steps against an unsharded copy on rank 0.

    Loss is compared before every backward pass and parameters after every
    optimizer step; both comparisons happen on rank 0 only.

    Args:
        model_name: registry key of the model; sharding rules exist for 'bert' and 'simple_net'.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    # Only rank 0 keeps a reference model and optimizer.
    model_torch = None
    optimizer_torch = None
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()
        optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1))
        # Start both models from the same parameter values.
        for colo_param, ref_param in zip(model.parameters(), model_torch.parameters()):
            ref_param.data.copy_(colo_param.data)

    pg = ProcessGroup(tp_degree=world_size)

    if 'bert' == model_name:
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            has_weight = 'weight' in name
            # num_class = type_vocab_size = 2 | (8, 2)
            if 'classifier' in name and has_weight:
                split_param_col_tp1d(p, pg)
            # num_class = vocab_size = 30524 | (30524, 8)
            elif 'word_embeddings' in name and has_weight:
                split_param_row_tp1d(p, pg)
            # num_class = seq_len = 512 | (512, 8)
            elif 'position_embeddings' in name and has_weight:
                split_param_row_tp1d(p, pg)
            # num_class = type_vocab_size = 2 | (2, 8)
            elif 'token_type_embeddings' in name and has_weight:
                split_param_col_tp1d(p, pg)
    elif "simple_net" == model_name:
        # The rules below are independent checks, not an elif chain.
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            has_weight = 'weight' in name
            has_weight_or_bias = has_weight or 'bias' in name
            if 'embed' in name and has_weight:
                split_param_col_tp1d(p, pg)
            if 'proj1' in name and has_weight_or_bias:
                split_param_row_tp1d(p, pg)
            if 'proj2' in name and has_weight:
                split_param_col_tp1d(p, pg)
            if 'classifier' in name and has_weight_or_bias:
                split_param_row_tp1d(p, pg)

    model = model.cuda()
    model.eval()
    if rank == 0:
        model_torch.eval()

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

    def _loss_of(net, data, label):
        # Without a criterion the model takes the label and its output is used as the loss.
        if criterion:
            return criterion(net(data), label)
        return net(data, label)

    for step, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        if rank == 0:
            optimizer_torch.zero_grad()
        torch.distributed.barrier()

        # Broadcast rank 0's batch to every process in the TP group.
        data = data.to(get_current_device())
        label = label.to(get_current_device())
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        loss = _loss_of(model, data, label)

        if rank == 0:
            loss_torch = _loss_of(model_torch, data, label)
            assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"

        torch.distributed.barrier()
        loss.backward()
        colo_optimizer.step()

        if rank == 0:
            loss_torch.backward()
            optimizer_torch.step()

            with torch.no_grad():
                # Each updated parameter must equal its shard of the reference parameter.
                for p, torch_p in zip(model.parameters(), model_torch.parameters()):
                    assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())

        torch.distributed.barrier()
        if step > 5:
            break
# Test the overridden parameters() and named_parameters() member functions
def test_model_parameters():
    """Check parameter counts reported by ``parameters()`` / ``named_parameters()`` under ColoInitContext."""
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # Two Linear layers (weight + bias each) plus one extra parameter: 5 in total.
    class Net(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
            self.extra_param = torch.nn.Parameter(torch.randn(2))

    with ColoInitContext(device=get_current_device()):
        model = Net()

    assert len(list(model.named_parameters())) == 5

    for _, colo_p in model.named_parameters():
        assert colo_p.is_model_data()

    # Without recursion only the parameter owned directly by Net is visible.
    assert len(list(model.named_parameters(recurse=False))) == 1

    # The first Linear owns exactly a weight and a bias.
    assert len(list(model.fcs[0].parameters(recurse=False))) == 2
def test_colo_optimizer():
    """Run a few SGD steps on 'simple_net' through ColossalaiOptimizer; passes if nothing raises."""
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
    for step, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Without a criterion the model takes the label and its output is used as the loss.
        if criterion:
            loss = criterion(model(data), label)
        else:
            loss = model(data, label)

        loss.backward()
        colo_optimizer.step()

        if step > 5:
            break
def run_1d_row_tp(model_name: str):
    """Compare the loss of a 1D tensor-parallel model with an unsharded copy kept on rank 0.

    Args:
        model_name: registry key of the model to build.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = torch.distributed.get_rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    # Re-seed before building the reference model on rank 0.
    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

    # Only parameters named like a weight are split; the rule depends on the owning module's name.
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            if not isinstance(param, ColoTensor):
                continue
            if 'weight' not in param_name:
                continue
            no_layer_norm = 'LayerNorm' not in module_name
            if 'embed' in module_name and 'token' not in module_name and no_layer_norm:
                split_param_row_tp1d(param, pg)
            elif no_layer_norm and 'ln' not in module_name:
                split_param_col_tp1d(param, pg)

    model = model.cuda()

    def _loss_of(net, data, label):
        # Without a criterion the model takes the label and its output is used as the loss.
        if criterion:
            return criterion(net(data), label)
        return net(data, label)

    for step, (data, label) in enumerate(train_dataloader):
        # Broadcast rank 0's batch to every process in the TP group.
        data = data.to(get_current_device())
        label = label.to(get_current_device())
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        loss = _loss_of(model, data, label)

        # Reference forward on rank 0 only.
        if rank == 0:
            loss_torch = _loss_of(model_torch, data, label)
            assert torch.allclose(loss, loss_torch, rtol=1e-2)

        torch.distributed.barrier()
        loss.backward()

        if rank == 0:
            loss_torch.backward()
        torch.distributed.barrier()

        if step > 5:
            break
def _run_pretrain_load():
    """Load 'bert-base-uncased' with and without ColoInitContext and compare the parameters."""
    from transformers import BertForMaskedLM

    set_seed(1)
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    dict_pretrained = {}
    for name, param in model_pretrained.named_parameters():
        dict_pretrained[name] = param
    reference_count = len(dict_pretrained)

    dict_col = {}
    colo_count = 0
    other_count = 0
    for name, param in model.named_parameters():
        if isinstance(param, ColoParameter):
            colo_count += 1
        else:
            other_count += 1
        dict_col[name] = param

    # Every parameter must have become a ColoParameter, with none lost or added.
    assert reference_count == colo_count
    assert other_count == 0

    # If the reference model ties the decoder bias to the predictions bias, the tie must survive.
    if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
        assert model.cls.predictions.decoder.bias is model.cls.predictions.bias

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])
def run_model_dist(rank, world_size, port):
    """Per-process entry point: launch colossalai and run the hybrid-TP check on each model."""
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # The row-TP variant is left disabled to keep the test fast.
    # for name in ['bert', 'simple_net']:
    #     run_1d_row_tp(name)
    model_names = ['bert', 'simple_net']
    for name in model_names:
        run_1d_hybrid_tp(name)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
    """Spawn ``world_size`` processes, each running ``run_model_dist``."""
    spawn(run_model_dist, world_size)
def run_pretrain_load_dist(rank, world_size, port):
    """Per-process entry point: launch colossalai, then run the pretrained-load comparison."""
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_pretrain_load()
# The test case has to download huggingface pretrained models from the internet
# So we manually trigger the test.
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_pretrain_load(world_size):
    """Spawn ``world_size`` processes running the pretrained-load check (skipped by default)."""
    spawn(run_pretrain_load_dist, world_size)
# Manual entry point: only the model test is enabled; the others can be switched on by hand.
if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
    test_model(4)
    # test_pretrain_load(4)

View File

@ -1,227 +0,0 @@
from copy import deepcopy
import pytest
import torch
import colossalai
from colossalai.nn.parallel.layers import check_colo_module, init_colo_module
from colossalai.tensor import (
ColoTensor,
ColoTensorSpec,
ComputePattern,
ComputeSpec,
ProcessGroup,
ReplicaSpec,
ShardSpec,
distspec,
)
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
def run_model_with_spec(mode, model_name):
    """Compare a model sharded through ``init_colo_module`` with an unsharded copy on rank 0.

    The loss is compared after every forward pass; after backward each
    parameter is compared with the first chunk of the reference parameter
    along the dimension that was sharded.

    Args:
        mode: 'col' or 'row', the sharding mode passed to ``init_colo_module``.
        model_name: registry key of the model; 'bert' and 'simple_net' are handled.

    Raises:
        AssertionError: on any mismatch, including a parameter whose shape
            differs from the reference in neither the first nor the last dim.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    rank = pg.rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=False)

    if rank == 0:
        model_seq = model_builder(checkpoint=False)
        model_seq = model_seq.cuda()

        # Make two models have the same init params
        for p1, p2 in zip(model.parameters(), model_seq.parameters()):
            p2.data.copy_(p1.data)

    compute_spec = ComputeSpec(ComputePattern.TP1D)
    # Not all layers in Bert can be mod by 4.
    # e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
    if 'bert' == model_name:
        if 'col' == mode:
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row')
        elif 'row' == mode:
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col')
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode)
    elif 'simple_net' == model_name:
        init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)

    model = model.cuda()
    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        # Without a criterion the model takes the label and its output is used as the loss.
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_seq = model_seq(data)
                loss_seq = criterion(output_seq, label)
            else:
                output_seq = model_seq(data, label)
                loss_seq = output_seq

            with torch.no_grad():
                assert torch.allclose(loss, loss_seq, rtol=1e-2)

        loss.backward()

        if rank == 0:
            loss_seq.backward()

            with torch.no_grad():
                # check param
                for p1, p2 in zip(model.parameters(), model_seq.parameters()):
                    if p1.size() == p2.size():
                        assert torch.allclose(p1, p2)
                        continue

                    # Use a dedicated name for the shard count so the outer
                    # ``world_size`` is not overwritten inside the loop.
                    if p1.size(-1) < p2.size(-1):    # col
                        num_shards = p2.size(-1) // p1.size(-1)
                        split_p2 = torch.chunk(p2, num_shards, dim=-1)[0]
                    elif p1.size(0) < p2.size(0):    # row
                        num_shards = p2.size(0) // p1.size(0)
                        split_p2 = torch.chunk(p2, num_shards, dim=0)[0]
                    else:
                        # Previously this fell through and compared against a
                        # stale (or unbound) ``split_p2`` from an earlier parameter.
                        raise AssertionError(
                            f"unexpected shard shape {tuple(p1.size())} for reference shape {tuple(p2.size())}")
                    assert torch.allclose(p1, split_p2)

        if i > 3:
            break
def run_linear_with_spec(mode):
    """Compare a ``Linear`` sharded via ``init_colo_module`` with a deep copy taken before sharding.

    Args:
        mode: 'col' or 'row', the sharding mode passed to ``init_colo_module``.
    """
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(4, 8)

    # Snapshot the layer before any spec is applied to ``model``.
    model_handy = deepcopy(model)

    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    init_colo_module(model, ComputeSpec(ComputePattern.TP1D), pg=pg, recursive=True, mode=mode)

    x = torch.rand(2, 4).cuda()
    colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))

    out = model(x)
    colo_out = model_handy(colo_x)
    assert tensor_equal(out, colo_out)

    # Drive both backward passes with the same upstream gradient.
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)

    tp_rank, tp_size = pg.tp_local_rank(), pg.tp_world_size()
    assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, tp_rank, tp_size)
    tp_rank, tp_size = pg.tp_local_rank(), pg.tp_world_size()
    assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, tp_rank, tp_size)
def run_check_shared_param():
    """Check that ``check_colo_module`` rejects an inconsistent spec on a shared bias.

    A small ``BertForMaskedLM`` is built, where ``cls.predictions.decoder`` and
    ``cls.predictions`` share one bias. The whole model is first initialized in
    'row' mode, then the shared bias is given a conflicting shard spec, and the
    module check is expected to fail with an 'incorrectly sharded' message.

    Raises:
        AssertionError: if the check raises with a different message, or does
            not raise at all.
    """
    from transformers import BertConfig, BertForMaskedLM
    hidden_dim = 8
    num_head = 4
    sequence_length = 12
    num_layer = 2
    vocab_size = 24

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    config = BertConfig(vocab_size=vocab_size,
                        hidden_size=hidden_dim,
                        intermediate_size=hidden_dim * 4,
                        num_attention_heads=num_head,
                        max_position_embeddings=sequence_length,
                        num_hidden_layers=num_layer,
                        hidden_dropout_prob=0.,
                        attention_probs_dropout_prob=0.)
    with ColoInitContext(device=get_current_device()):
        model = BertForMaskedLM(config)

    model = model.cuda()
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
    assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
    # They are all Linear, so both row is allowed. This should pass check.
    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
    # This should be detected by check because you can not set weight as row while set bias as col.
    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))

    # TODO(jiaruifang) optimize this line
    if not model.cls.predictions.bias.has_initialized:
        model.cls.predictions.bias.pg = pg
        model.cls.predictions.bias.dist_spec = ReplicaSpec()
        model.cls.predictions.bias.has_initialized = True
    model.cls.predictions.bias.set_tensor_spec(*col_spec)

    try:
        check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False)
    except Exception as e:
        assert 'incorrectly sharded' in str(e)
    else:
        # Without this branch the test passed silently when nothing was detected.
        raise AssertionError("check_colo_module did not report the inconsistent spec on the shared bias")
def run_dist(rank, world_size, port):
    """Per-process entry point: launch with a 1D tensor-parallel config and test Linear in both modes."""
    tensor_cfg = dict(mode="1d", size=world_size)
    config = dict(parallel=dict(tensor=tensor_cfg))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    for mode in ('col', 'row'):
        run_linear_with_spec(mode)
def run_dist_model(rank, world_size, port):
    """Per-process entry point: launch with a 1D tensor-parallel config and test each model in both modes."""
    tensor_cfg = dict(mode="1d", size=world_size)
    config = dict(parallel=dict(tensor=tensor_cfg))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    for model_name in ['simple_net', 'bert']:
        for mode in ('col', 'row'):
            run_model_with_spec(mode, model_name)
def run_dist_check(rank, world_size, port):
    """Per-process entry point: launch with a 1D tensor-parallel config and run the shared-parameter check."""
    tensor_cfg = dict(mode="1d", size=world_size)
    config = dict(parallel=dict(tensor=tensor_cfg))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_check_shared_param()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_linear_1d(world_size):
    """Spawn ``world_size`` processes running the Linear spec test (skipped by default)."""
    spawn(run_dist, world_size)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_model(world_size):
    """Spawn ``world_size`` processes running the whole-model spec test (skipped by default)."""
    spawn(run_dist_model, world_size)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_check(world_size):
    """Spawn ``world_size`` processes running the shared-parameter check (skipped by default)."""
    spawn(run_dist_check, world_size)
# Manual entry point: run the Linear spec test with 4 processes.
if __name__ == '__main__':
    test_module_linear_1d(4)

View File

@ -1,41 +0,0 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
from tests.test_tensor.common_utils import tensor_shard_equal
def run_dist(rank, world_size, port, dp_degree, tp_degree):
    """Round-trip a last-dim sharded parameter through ``gather_tensor`` and ``scatter_tensor``.

    Args:
        rank, world_size, port: launch parameters for this process.
        dp_degree, tp_degree: data- and tensor-parallel degrees of the process group.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)

    x = torch.randn(4, 4)
    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
    dist_spec = ShardSpec([-1], [pg.tp_world_size()])
    param.set_tensor_spec(dist_spec, ComputeSpec(ComputePattern.TP1D))

    # After gathering, global rank 0 must see the full tensor; other ranks are checked shard-wise.
    gather_tensor(param)
    if dist.get_rank() == 0:
        assert torch.all(x == param)
    else:
        assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    dist.barrier()

    # Scattering back must leave every rank with its shard and keep requires_grad set.
    scatter_tensor(param, dist_spec)
    assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    assert param.requires_grad is True
    dist.barrier()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size):
    """Spawn ``world_size`` processes with dp_degree=2 and the remaining factor as tp_degree."""
    tp_degree = world_size // 2
    spawn(run_dist, world_size, dp_degree=2, tp_degree=tp_degree)
# Manual entry point: run the gather/scatter test with 4 processes.
if __name__ == '__main__':
    test_checkpoint(world_size=4)

View File

@ -1,64 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.tensor import (
ColoParameter,
ColoTensorSpec,
ComputePattern,
ComputeSpec,
ProcessGroup,
ReplicaSpec,
ShardSpec,
)
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed
def run_colo_init_context(rank: int, world_size: int, port: int):
    """Build GPT-2 twice under different ColoInitContext settings and compare after resharding.

    The first model uses the context defaults; the second is given an explicit
    default process group and dist spec. Both are then resharded along the last
    dim and compared parameter by parameter.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # Use the same seed on every process so the parameters are consistent across
    # processes and exactly replicated.
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # keep parameters replicated during init
    with ColoInitContext(device=get_current_device()):
        model1 = model_builder()

    # Re-seed before building the second model.
    set_seed(42)
    shard_spec = ReplicaSpec()
    # With a ShardSpec here the assertions below would fail. That is not a bug:
    # the initialized values are then not consistent with the original ones.
    # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
    default_pg = ProcessGroup(tp_degree=world_size)
    with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec):
        model2 = model_builder()

    # reshard both models
    new_shard = ShardSpec(dims=[-1], num_partitions=[world_size])
    for first, second in zip(model1.parameters(), model2.parameters()):
        # model1's parameters were built without an explicit process group, so attach one first.
        first.set_process_group(ProcessGroup(tp_degree=world_size))
        first.set_dist_spec(new_shard)
        second.set_dist_spec(new_shard)

    for first, second in zip(model1.parameters(), model2.parameters()):
        assert (torch.allclose(first, second))
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_colo_init_context(world_size):
    """Spawn ``world_size`` processes, each running ``run_colo_init_context``."""
    spawn(run_colo_init_context, world_size)
# Manual entry point: run the init-context test with 2 processes.
if __name__ == '__main__':
    test_colo_init_context(2)

View File

@ -1,232 +0,0 @@
import pytest
import torch
import torch.nn.functional as F
import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.nn._ops._utils import gather_forward_split_backward
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
def run_dist(rank, world_size, port):
# Per-process body of the sharded-linear test: computes a reference
# F.linear(x, w, b), then repeats the call on manually chunked operands
# annotated with ShardingSpec for six sharding layouts, comparing each
# result with the matching slice of the reference output.
# NOTE(review): leading indentation is absent in this listing; the `for`
# loop below presumably spans down to `col_process_group = col_pg`.
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# create mlp vars
x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()
# run normal forward; `out` is the reference every case below slices from
out = F.linear(x, w, b)
# create mesh meta
# the mesh is in the following topo
# [[0, 1],
# [2, 3]]
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
# position of this rank in the 2 x 2 mesh
row_id = rank // 2
column_id = rank % 2
# create pg
row_process_group = None
col_process_group = None
row_to_ranks = {0: [0, 1], 1: [2, 3]}
col_to_ranks = {0: [0, 2], 1: [1, 3]}
# build a process group for each mesh row and column, and remember the
# ones whose rank list contains this rank
for idx in range(2):
# row ranks
row_ranks = row_to_ranks[idx]
row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2)
# col ranks
col_ranks = col_to_ranks[idx]
col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2)
if rank in row_ranks:
row_process_group = row_pg
if rank in col_ranks:
col_process_group = col_pg
########################
# RRR x RS0 -> RRS0 #
########################
# w will be transposed in F.linear
x_replica = x.detach().clone()
w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id]
b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id]
# adding sharding spec
x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={})
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]})
b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]})
# check sharding spec
assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]"
assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]"
assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]"
# attach the per-axis process groups to the weight shard
w_shard.pg_axis0 = col_process_group
w_shard.pg_axis1 = row_process_group
out_shard = F.linear(x_replica, w_shard, b_shard)
assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
# expected: the row_id-th half of the reference output along dim 2
expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id]
assert torch.allclose(out_shard, expected_out_shard)
########################
# S0RR x RS1 -> S0RS1 #
########################
# w will be transposed in F.linear
x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id]
w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id]
b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id]
# adding sharding spec
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]})
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]})
b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})
# check sharding spec
assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]"
assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]"
assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"
w_shard.pg_axis0 = col_process_group
w_shard.pg_axis1 = row_process_group
out_shard = F.linear(x_shard, w_shard, b_shard)
# expected: the row_id-th half along dim 0, then the column_id-th half along dim 2
expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id]
assert torch.allclose(out_shard, expected_out_shard)
########################
# S0RS1 x S1R -> S0RR #
########################
# w will be transposed in F.linear
x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id]
x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
b_replica = b.clone()
# adding sharding spec
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]})
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
# check sharding spec
assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]"
assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
w_shard.pg_axis0 = col_process_group
w_shard.pg_axis1 = row_process_group
out_shard = F.linear(x_shard, w_shard, b_replica)
# expected: the row_id-th half of the reference output along dim 0
expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
assert torch.allclose(out_shard, expected_out_shard)
########################
# RRS0 x S0R -> RRR #
########################
# w will be transposed in F.linear
x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
b_replica = b.clone()
# adding sharding spec
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]})
b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
# check sharding spec
assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]"
assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
w_shard.pg_axis0 = col_process_group
w_shard.pg_axis1 = row_process_group
out_shard = F.linear(x_shard, w_shard, b_replica)
# expected: the full reference output
expected_out_shard = out
assert torch.allclose(out_shard, expected_out_shard)
########################
# RS0S1 x S1R -> RS0R #
########################
# w will be transposed in F.linear
x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id]
x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
b_replica = b.clone()
# adding sharding spec
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]})
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
# check sharding spec
assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]"
assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
w_shard.pg_axis0 = col_process_group
w_shard.pg_axis1 = row_process_group
out_shard = F.linear(x_shard, w_shard, b_replica)
# expected: the row_id-th half of the reference output along dim 1
expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id]
assert torch.allclose(out_shard, expected_out_shard)
########################
# RRS0 x S0S1 -> RRS1 #
########################
# w will be transposed in F.linear
x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id]
b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id]
# adding sharding spec
x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]})
b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})
# check sharding spec
assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]"
assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"
w_shard.pg_axis0 = col_process_group
w_shard.pg_axis1 = row_process_group
out_shard = F.linear(x_shard, w_shard, b_shard)
# expected: the column_id-th half of the reference output along dim 2
expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id]
assert torch.allclose(out_shard, expected_out_shard)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_sharded_mlp(world_size):
    """Pytest entry point: hand ``run_dist`` and ``world_size`` to ``spawn``.

    Args:
        world_size: value taken from the ``parametrize`` list above (only 4).
    """
    spawn(run_dist, world_size)
# Allow running this test module directly as a script with a world size of 4.
if __name__ == '__main__':
    test_sharded_mlp(4)

View File

@ -1,143 +0,0 @@
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP
from colossalai.zero.gemini import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
    """Assert that each entry of ``torch_model``'s state dict matches ``model``'s.

    Args:
        model: wrapped model whose ``state_dict(only_rank_0=False)`` is checked.
        torch_model: reference module providing the expected values.
        pg: process group supplying the tensor-parallel rank and world size
            forwarded to ``tensor_shard_equal``.

    Raises:
        AssertionError: if a key is missing from ``model``'s state dict or a
            value does not compare equal under ``tensor_shard_equal``.
    """
    gemini_state = model.state_dict(only_rank_0=False)
    reference_state = torch_model.state_dict()

    for prefixed_name, expected in reference_state.items():
        # Reference keys look like 'module.model.PARAMETER'; drop the first
        # 7 characters so the name lines up with the keys of `gemini_state`.
        name = prefixed_name[7:]
        assert name in gemini_state, "{} not in ZeRO dictionary.".format(name)

        # Compare on the reference tensor's device and dtype.
        actual = gemini_state[name].to(device=expected.device, dtype=expected.dtype)
        assert tensor_shard_equal(
            expected, actual, pg.tp_local_rank(), pg.tp_world_size()
        ), "parameter '{}' has problem.".format(name)
def run_fwd_bwd(model, criterion, optimizer, input_ids):
    """Run one forward/backward pass and return the logits converted with ``.float()``.

    Args:
        model: callable producing logits from ``input_ids``.
        criterion: callable ``criterion(logits, input_ids)`` returning the loss.
        optimizer: object exposing ``zero_grad()`` and ``backward(loss)``.
        input_ids: model input, also used as the target of ``criterion``.

    Returns:
        The result of calling ``.float()`` on the model output.
    """
    optimizer.zero_grad()
    float_logits = model(input_ids).float()
    # The backward pass is driven through the optimizer, not the loss itself.
    optimizer.backward(criterion(float_logits, input_ids))
    return float_logits
def init_1d_row_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and shard non-'ln' weights along dim 0.

    Args:
        model: module whose ``named_parameters()`` are visited.
        pg: process group set on each parameter; its ``tp_world_size()`` is
            the number of shards.
    """
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        # Only parameters named '*weight*' outside of 'ln' entries are sharded.
        if 'weight' not in name or 'ln' in name:
            continue
        param.set_tensor_spec(shard_spec, compute_spec)
def init_1d_col_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and shard non-'ln' weights/biases along the last dim.

    Args:
        model: module whose ``named_parameters()`` are visited.
        pg: process group set on each parameter; its ``tp_world_size()`` is
            the number of shards.
    """
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        if 'ln' in name:
            continue
        if 'weight' in name or 'bias' in name:
            param.set_tensor_spec(shard_spec, compute_spec)
@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(placement_policy, tp_init_spec_func=None):
    """Train 'gpt2' for a few steps under GeminiDDP and compare with an Apex-AMP DDP copy.

    Args:
        placement_policy: placement policy forwarded to ``GeminiDDP``; the
            ``parameterize`` decorator lists 'cuda' and 'cpu'.
        tp_init_spec_func: optional callable ``f(model, pg)`` applied to the
            model before it is wrapped by ``GeminiDDP``.

    Raises:
        AssertionError: if logits or parameters of the two models diverge.
    """
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    # Start both models from identical weights.
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()

    # world size, dp = 2, tp =2, construct a hybrid parallelism.
    if world_size == 4:
        pg = ProcessGroup(tp_degree=2)
    else:
        pg = ProcessGroup(tp_degree=world_size)

    if tp_init_spec_func:
        tp_init_spec_func(model, pg)

    dp_world_size = pg.dp_world_size()
    # NOTE(review): config_dict is computed and adjusted here but is not passed
    # to GeminiDDP below - confirm whether that is intentional.
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[dp_world_size]['chunk_size'] = 5000
    config_dict[dp_world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None

    model = GeminiDDP(model, init_device, placement_policy, True, False)
    # The same as the following 3 lines
    # chunk_manager = ChunkManager(config_dict, init_device=init_device)
    # gemini_manager = GeminiManager(placement_policy, chunk_manager)
    # model = ZeroDDP(model, gemini_manager, pin_memory=True)

    zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
    # The same as the following 2 lines
    # optimizer = HybridAdam(model.parameters(), lr=1e-3)
    # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)

    # Reference pipeline: Adam + Apex AMP (O2) wrapped in torch DDP.
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    check_param(model, torch_model, pg)

    model.eval()
    torch_model.eval()

    # Re-seed with the data-parallel local rank before iterating the data.
    set_seed(pg.dp_local_rank())
    for i, (input_ids, label) in enumerate(train_dataloader):
        # Only the first three batches are exercised.
        if i > 2:
            break
        input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)

        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model, pg)
def run_dist(rank, world_size, port):
    """Per-process entry: launch colossalai, then run ``run_gpt`` with the TP spec(s).

    With a world size of 4 the megatron spec is used; otherwise the 1D column
    spec and then the 1D row spec are exercised in turn.

    Args:
        rank: rank of this process.
        world_size: total number of processes.
        port: port used for the 'localhost' rendezvous.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if world_size == 4:
        spec_funcs = (init_megatron_spec,)
    else:
        spec_funcs = (init_1d_col_spec, init_1d_row_spec)
    for spec_func in spec_funcs:
        run_gpt(tp_init_spec_func=spec_func)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    """Pytest entry point: hand ``run_dist`` and ``world_size`` to ``spawn``.

    Args:
        world_size: value taken from the ``parametrize`` list above (1 or 4).
    """
    spawn(run_dist, world_size)
# Allow running this test module directly as a script with a world size of 4.
if __name__ == '__main__':
    test_gpt(4)

View File

@ -1,206 +0,0 @@
import os
import shutil
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR
import colossalai
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    """Attach ``pg`` to ``weight`` and shard it along its last dimension (TP1D).

    Args:
        weight: tensor receiving the process group and tensor spec.
        pg: process group; ``tp_world_size()`` gives the shard count.
    """
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
def init_1d_col_linear(weight, pg):
    """Attach ``pg`` to ``weight`` and shard it along dimension 0 (TP1D).

    Args:
        weight: tensor receiving the process group and tensor spec.
        pg: process group; ``tp_world_size()`` gives the shard count.
    """
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
def init_1d_row_embedding(weight, pg):
    """Attach ``pg`` to an embedding ``weight`` and shard it along dimension 0 (TP1D).

    Args:
        weight: tensor receiving the process group and tensor spec.
        pg: process group; ``tp_world_size()`` gives the shard count.
    """
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
def init_1d_col_embedding(weight, pg):
    """Attach ``pg`` to an embedding ``weight`` and shard it along its last dimension (TP1D).

    Args:
        weight: tensor receiving the process group and tensor spec.
        pg: process group; ``tp_world_size()`` gives the shard count.
    """
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    """Apply 1D tensor-parallel specs to a model's parameters based on their names.

    Parameters that are not ``ColoTensor`` instances are left untouched. The
    checks are independent, so a name matching several patterns is initialised
    by each matching rule in order.

    Args:
        model: module whose ``named_parameters()`` are visited.
        pg: process group forwarded to the per-parameter init helpers.
    """
    # The original built a `spec` tuple here that was never used; each helper
    # below constructs its own spec, so the dead local has been dropped.
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        has_weight = 'weight' in name
        has_weight_or_bias = has_weight or 'bias' in name
        if 'embed' in name and has_weight:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and has_weight_or_bias:
            init_1d_col_linear(p, pg)
        if 'proj2' in name and has_weight:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and has_weight_or_bias:
            init_1d_col_linear(p, pg)
def check_param_equal(model, torch_model):
    """Assert element-wise equality of the two models' parameters, paired in order.

    Args:
        model: first module; its parameter names appear in failure messages.
        torch_model: second module, iterated in lockstep with ``model``.

    Raises:
        AssertionError: if any paired parameters differ in any element.
    """
    paired = zip(model.named_parameters(), torch_model.named_parameters())
    for (name, param), (_, other_param) in paired:
        assert torch.all(param.data == other_param.data), \
            "{} went wrong.\n {} vs {}\n{}".format(name, param, other_param, param.shape)
def remove(path):
    """Delete the file, symlink or directory tree at ``path``.

    ``path`` may be relative or absolute.

    Raises:
        ValueError: if ``path`` is neither a file, a link nor a directory.
    """
    # Files and links are checked first so a link is unlinked, not traversed.
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
        return
    if not os.path.isdir(path):
        raise ValueError("file {} is not a file or dir.".format(path))
    shutil.rmtree(path)
def compare_optims(optim1, optim2):
    """Assert that ``ColoTensor`` state shared by both optimizers is exactly equal.

    Only keys present in both ``state_dict()['state']`` mappings, and entries
    present in both per-key dictionaries, are compared; everything else is
    skipped.

    Raises:
        AssertionError: if a shared entry is a ``ColoTensor`` in ``optim1`` but
            not in ``optim2``, or the two tensors differ.
    """
    first_state = optim1.state_dict()['state']
    second_state = optim2.state_dict()['state']
    for key, first_entries in first_state.items():
        if key in second_state:
            second_entries = second_state[key]
            for entry_name, first_tensor in first_entries.items():
                if entry_name in second_entries:
                    second_tensor = second_entries[entry_name]
                    if isinstance(first_tensor, ColoTensor):
                        assert isinstance(second_tensor, ColoTensor)
                        # Zero tolerances: the values must match exactly.
                        assert torch.allclose(first_tensor, second_tensor, rtol=0, atol=0)
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    """Train a model and a deep copy in lockstep, then round-trip a checkpoint between them.

    After a few optimisation steps the first model/optimizer pair is saved to
    './checkpoint', loaded into the copy, and parameters and optimizer state
    are compared.

    Args:
        model_name: key passed to ``non_distributed_component_funcs.get_callable``.
        init_spec_func: callable ``f(model, pg)`` used for the 'simple_net' model
            when ``use_mp_reload`` is set.
        use_ddp: not read in this function body.
        use_mp_reload: when true, tensor-parallel specs are applied before copying.
        test_scheduler: not read in this function body.
        pg: process group used for TP specs and for broadcasting the batch.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    rank = torch.distributed.get_rank()
    # NOTE(review): world_size is assigned but not used below.
    world_size = torch.distributed.get_world_size()

    # set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if use_mp_reload:
        if 'bert' == model_name:
            for name, p in model.named_parameters():
                if not isinstance(p, ColoTensor):
                    continue
                # num_class = type_vocab_size = 2 | (8, 2)
                if 'classifier' in name and 'weight' in name:
                    init_1d_row_linear(p, pg)
                # num_class = vocab_size = 30524 | (30524, 8)
                elif 'word_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = seq_len = 512 | (512, 8)
                elif 'position_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = type_vocab_size = 2 | (2, 8)
                elif 'token_type_embeddings' in name and 'weight' in name:
                    init_1d_col_embedding(p, pg)
                elif p.process_group.tp_world_size() == 1:
                    p.set_process_group(pg)
        elif "simple_net" == model_name:
            init_spec_func(model, pg)

    # The copy is taken after the specs are applied so both models share them.
    model_reload = deepcopy(model)
    model = model.cuda()
    model.eval()

    model_reload = model_reload.cuda()
    model_reload.eval()

    opt_class = torch.optim.Adam
    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):

        # Zero grad
        colo_optimizer.zero_grad()
        colo_optimizer_reload.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Broadcast the batch from the first rank of the TP rank list to the TP group.
        dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
        dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())

        # Bcast rank0 data to all processes
        if criterion:
            output = model(data)
            output_reload = model_reload(data)
            loss = criterion(output, label)
            loss_reload = criterion(output_reload, label)
        else:
            # Without a criterion the model itself returns the loss.
            loss = model(data, label)
            loss_reload = model_reload(data, label)

        loss.backward()
        loss_reload.backward()

        colo_optimizer.step()
        colo_optimizer_reload.step()

        # The check sits after the step, so batches 0..3 are processed.
        if i > 2:
            break

    # Only rank 0 creates the directory; the barrier makes the others wait for it.
    if not os.path.isdir('./checkpoint') and rank == 0:
        os.mkdir('./checkpoint')
    dist.barrier()

    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)

    check_param_equal(model, model_reload)
    compare_optims(colo_optimizer, colo_optimizer_reload)

    # Clean up on rank 0 and synchronise before returning.
    if rank == 0:
        remove('./checkpoint')
    dist.barrier()
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
    """Per-process entry: launch colossalai and run the checkpoint round-trip for 'bert'.

    Args:
        rank: rank of this process.
        world_size: total number of processes; also used as the TP degree.
        port: port used for the 'localhost' rendezvous.
        use_ddp: forwarded to ``_run_checkpoint``.
        use_mp_reload: forwarded to ``_run_checkpoint``.
        test_scheduler: forwarded to ``_run_checkpoint``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    tp_group = ProcessGroup(tp_degree=world_size)

    # The BERT data loader runs in DDP mode, so its input data is not
    # replicated in the TP context.
    model_names = ['bert']
    for model_name in model_names:
        _run_checkpoint(model_name,
                        init_1d_row_for_linear_weight_spec,
                        use_ddp,
                        use_mp_reload,
                        test_scheduler=test_scheduler,
                        pg=tp_group)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    """Pytest entry point: pass ``run_dist`` and the test options to ``spawn``.

    Args:
        world_size: value from the ``parametrize`` list above (1 or 2).
        use_ddp: value from the ``parametrize`` list above (only False).
        use_mp_reload: value from the ``parametrize`` list above.
        test_scheduler: defaults to None because its parametrization is
            commented out above.
    """
    spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler)
# Allow running this test module directly as a script with fixed options.
if __name__ == '__main__':
    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")

View File

@ -66,6 +66,7 @@ def run_dist(rank, world_size, port):
run_grad_clip_norm(world_size=world_size)
@pytest.mark.skip("this need to be updated")
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()

View File

@ -1,8 +1,9 @@
import pytest
import torch
from torch.distributed.distributed_c10d import _get_default_group
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.tensor import ColoTensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.gemini.chunk import ChunkManager
from tests.test_tensor.common_utils import debug_print
@ -15,19 +16,18 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_memory(keep_gathered, pin_memory):
pg = ProcessGroup()
debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)]
config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
chunk_manager = ChunkManager(config)
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0
process_group = _get_default_group()
for p in params:
chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory)
chunk_manager.register_tensor(p, 'param', 2, process_group, pin_memory=pin_memory)
chunk_manager.close_all_groups()
assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]

View File

@ -1,10 +1,10 @@
import pytest
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
import colossalai
from colossalai.tensor import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero.gemini import TensorState
@ -36,7 +36,7 @@ def check_equal(param, param_cp):
@parameterize('pin_memory', [True, False])
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
world_size = torch.distributed.get_world_size()
pg = ColoProcessGroup()
pg = _get_default_group()
my_chunk = Chunk(chunk_size=1024,
process_group=pg,
dtype=torch.float32,

View File

@ -1,23 +1,40 @@
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test import run_fwd, run_fwd_bwd
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed
# Keyword-argument sets fed to the test through `parameterize('placement_config', ...)`
# and expanded with `**placement_config` into the GeminiDDP constructor.
PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 1.0
    },    # zero3
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.5
    },    # zero3-half
    {
        'placement_policy': 'auto'
    }
]
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
@ -28,12 +45,12 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('use_grad_checkpoint', [False, True])
def exam_gpt_fwd_bwd(
placement_policy,
placement_config,
keep_gather,
model_name: str,
use_grad_checkpoint: bool = False,
@ -43,8 +60,7 @@ def exam_gpt_fwd_bwd(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
set_seed(42)
with ColoInitContext(device=init_device):
model = model_builder(use_grad_checkpoint)
model = model_builder(use_grad_checkpoint)
set_seed(42)
torch_model = model_builder(use_grad_checkpoint).cuda()
@ -55,19 +71,17 @@ def exam_gpt_fwd_bwd(
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
pg = ProcessGroup()
rank = dist.get_rank()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
torch_model = DDP(torch_model, device_ids=[rank])
set_seed(pg.dp_local_rank())
set_seed(rank)
for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
@ -89,65 +103,10 @@ def exam_gpt_fwd_bwd(
check_grad(model, torch_model)
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('scatter_after_inference', [False, True])
def exam_gpt_inference(
    placement_policy,
    keep_gather,
    model_name: str,
    scatter_after_inference: bool = False,
):
    """Check that a ZeroDDP-wrapped model and an Apex-AMP DDP copy give the same eval loss.

    Args:
        placement_policy: policy passed to ``GeminiManager``.
        keep_gather: value written to the chunk config's 'keep_gathered' entry.
        model_name: key passed to ``non_distributed_component_funcs.get_callable``.
        scatter_after_inference: forwarded to ``ZeroDDP``.

    Raises:
        AssertionError: if the two losses are not exactly equal.
    """
    init_device = get_current_device()
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # The same seed is set before building each model.
    set_seed(42)
    with ColoInitContext(device=init_device):
        model = model_builder()

    set_seed(42)
    torch_model = model_builder().cuda()
    # Copy weights so both models start identical.
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True, scatter_after_inference=scatter_after_inference)

    pg = ProcessGroup()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    set_seed(pg.dp_local_rank())
    model.eval()
    torch_model.eval()
    for i, (input_ids, label) in enumerate(train_dataloader):
        # you can only test a single fwd + bwd.
        # after bwd param is grad for Gemini, due to the chunk reuse optimization.
        if i > 0:
            break
        with torch.no_grad():
            input_ids, label = input_ids.cuda(), label.cuda()

            torch_loss = run_fwd(torch_model, input_ids, label, criterion)
            loss = run_fwd(model, input_ids, label, criterion)

        # Exact equality is required, not a tolerance-based comparison.
        assert torch.equal(torch_loss, loss)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
exam_gpt_inference()
@pytest.mark.dist

View File

@ -1,12 +1,11 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
@ -24,8 +23,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device='cpu'):
model = model_builder(use_grad_checkpoint)
model = model_builder(use_grad_checkpoint).cuda()
print(f'model_name {model_name}')
runtime_mem_tracer = RuntimeMemTracer(model)
@ -59,12 +57,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model = GeminiDDP(model,
chunk_config_dict=config_dict,
placement_policy=placement_policy,
pin_memory=True,
memstats=memstats)
pg = ProcessGroup()
set_seed(pg.dp_local_rank())
set_seed(dist.get_rank())
for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
@ -76,7 +75,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
set_seed(42)
loss = run_fwd_bwd(model, input_ids, label, criterion, model)
gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
# print('gemini non model data:', gemini_non_model_data)
@ -90,6 +89,7 @@ def run_dist(rank, world_size, port):
run_gemini_use_rmt()
@pytest.mark.skip("this is not used")
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()

View File

@ -1,52 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.tensor import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, GeminiDDP
from colossalai.zero.gemini.utils import get_static_torch_model
from tests.components_to_test.registry import non_distributed_component_funcs
@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2'])
def run_convert_torch_module(model_name: str):
    """Check that ``get_static_torch_model`` yields an independent plain-torch copy.

    The converted model must hold only ``torch.nn.Parameter`` objects, the
    GeminiDDP model must still hold ``ColoParameter`` objects, and the module
    trees must match by name while sharing no module or parameter objects.

    Args:
        model_name: key passed to ``non_distributed_component_funcs.get_callable``.
    """
    components = non_distributed_component_funcs.get_callable(model_name)()
    model_builder, _, _, _, _ = components

    with ColoInitContext(device=torch.device("cpu")):
        model = model_builder(checkpoint=False)
    model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
    static_model = get_static_torch_model(model, only_rank_0=False)

    # The converted parameters must be exactly torch.nn.Parameter, not a subclass.
    for param_name, param in static_model.named_parameters():
        assert type(param) == torch.nn.Parameter, f"type error: {param_name} is a {type(param)}"

    # Building the static model should not change the original model.
    for _, param in model.named_parameters():
        assert isinstance(param, ColoParameter)

    module_pairs = zip(static_model.named_modules(), model.named_modules())
    for (static_name, static_mod), (colo_name, colo_mod) in module_pairs:
        assert static_name == colo_name
        assert id(static_mod) != id(colo_mod)
        own_params = zip(static_mod.parameters(recurse=False), colo_mod.parameters(recurse=False))
        for static_param, colo_param in own_params:
            assert id(static_param) != id(colo_param)
            assert static_param.shape == colo_param.shape
def run_dist(rank, world_size, port):
    """Per-process entry: launch colossalai, then run the conversion check.

    Args:
        rank: rank of this process.
        world_size: total number of processes.
        port: port used for the 'localhost' rendezvous.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_convert_torch_module()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_convert_torch_module(world_size):
    """Pytest entry point: hand ``run_dist`` and ``world_size`` to ``spawn``.

    Args:
        world_size: value taken from the ``parametrize`` list above (1 or 4).
    """
    spawn(run_dist, world_size)
# Allow running this test module directly as a script with a world size of 2.
if __name__ == '__main__':
    test_convert_torch_module(2)

View File

@ -8,16 +8,38 @@ import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed
# Keyword-argument sets fed to the test through `parameterize('placement_config', ...)`
# and expanded with `**placement_config` into the GeminiDDP constructor.
# All static entries keep shard_param_frac and offload_param_frac at 0.0 and
# vary only offload_optim_frac.
PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.0,
        'offload_param_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 1.0,
        'offload_param_frac': 0.0
    },    # zero2-offload
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.5,
        'offload_param_frac': 0.0
    },    # zero2-offload-half
    {
        'placement_policy': 'auto'
    }
]
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
@ -30,9 +52,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', ['gpt2'])
def exam_grad_clipping(placement_policy, model_name: str):
def exam_grad_clipping(placement_config, model_name: str):
set_seed(1912)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -43,9 +65,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device()
with ColoInitContext(device=init_dev):
model = model_builder()
model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data)
@ -54,16 +74,19 @@ def exam_grad_clipping(placement_policy, model_name: str):
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
if placement_config['placement_policy'] != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model = GeminiDDP(model,
chunk_config_dict=config_dict,
chunk_init_device=init_device,
pin_memory=True,
**placement_config)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
model.train()
torch_model.train()

View File

@ -11,15 +11,32 @@ from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper
from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
from tests.test_tensor.common_utils import set_seed
# Keyword-argument sets fed to the test through `parameterize('placement_config', ...)`
# and expanded with `**placement_config` into the GeminiDDP constructor.
PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 1.0
    },    # zero3
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.5
    },    # zero3-half
    {
        'placement_policy': 'auto'
    }
]
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
@ -32,35 +49,24 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
def multi_chunk_init(model: torch.nn.Module, placement_config: dict):
world_size = dist.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model = GeminiDDP(model, config_dict, pin_memory=True, **placement_config)
return model
def single_chunk_init(model: torch.nn.Module, placement_policy: str):
gemini_config = dict(
device=get_current_device(),
placement_policy=placement_policy,
pin_memory=True,
)
model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config)
def single_chunk_init(model: torch.nn.Module, placement_config: dict):
model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config)
return model
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', ['gpt2'])
@parameterize('model_init_func', [single_chunk_init, multi_chunk_init])
def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable):
def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
set_seed(19360226)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -70,17 +76,15 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device()
with ColoInitContext(device=init_dev):
model = model_builder()
model = model_builder().to(init_dev)
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data)
model = model_init_func(model, placement_policy)
model = model_init_func(model, placement_config)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
model.eval()
torch_model.eval()
@ -95,7 +99,7 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss)
assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)

View File

@ -9,12 +9,46 @@ from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx
from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
from tests.test_tensor.common_utils import set_seed
# Gemini placement settings swept by the parameterized tests in this file;
# each entry is unpacked as keyword arguments into ``GeminiDDP``.
PLACEMENT_CONFIGS = [
    dict(placement_policy='static', shard_param_frac=0.0, offload_optim_frac=0.0),    # zero2
    dict(placement_policy='static', shard_param_frac=0.0, offload_optim_frac=1.0),    # zero2-offload
    dict(placement_policy='static', shard_param_frac=0.0, offload_optim_frac=0.5),    # zero2-offload-half
    dict(placement_policy='static', shard_param_frac=1.0),    # zero3
    dict(placement_policy='static', shard_param_frac=0.5),    # zero3-half
    dict(
        placement_policy='static',
        shard_param_frac=1.0,
        offload_optim_frac=1.0,
        offload_param_frac=1.0,
    ),    # zero3-offload-all
    dict(placement_policy='auto'),
]
# this model is large enough to slice to chunks
TEST_MODELS = ['gpt2']
@ -29,7 +63,7 @@ BF16_IGNORED_KEYS = [
]
def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
torch_dict = torch_model.state_dict()
@ -51,10 +85,10 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype
msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', TEST_MODELS)
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -65,9 +99,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device()
with ColoInitContext(device=init_dev):
model = model_builder()
model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data)
@ -76,16 +108,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
model.eval()
torch_model.eval()
@ -109,10 +135,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
check_param(model, torch_model, mixed_precision)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', EXAMPLE_MODELS)
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
set_seed(2008)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -123,18 +149,19 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device()
with ColoInitContext(device=init_dev):
model = model_builder()
model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data)
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
model = GeminiDDP(model,
chunk_init_device=get_current_device(),
search_range_m=1,
pin_memory=True,
mixed_precision=mixed_precision,
**placement_config)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2)
model.eval()
torch_model.eval()

View File

@ -1,15 +1,16 @@
from copy import deepcopy
import numpy as np
import pytest
import torch
from colossalai.testing import clear_cache_before_run
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
@pytest.mark.skip("this is not used")
@clear_cache_before_run()
def test_runtime_mem_tracer():
test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert']
@ -18,8 +19,7 @@ def test_runtime_mem_tracer():
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, _, criterion = get_components_func()
with ColoInitContext(device='cpu'):
model = model_builder(checkpoint=False)
model = model_builder(checkpoint=False).cuda()
model_bk = deepcopy(model)
runtime_mem_tracer = RuntimeMemTracer(model)

View File

@ -2,33 +2,20 @@ import pytest
import torch
import colossalai
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_spec(model, pg: ProcessGroup):
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
p.set_process_group(pg)
p.set_tensor_spec(*tensor_spec)
def exam_search_chunk_size():
world_size = torch.distributed.get_world_size()
pg_tp = ProcessGroup(tp_degree=world_size)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
# make sure torch_model and model has the same parameter values
with ColoInitContext(device=get_current_device()):
model = model_builder()
init_1d_row_spec(model, pg_tp)
model = model_builder()
config_dict, *_ = search_chunk_configuration(model,
search_range_m=1,
search_interval=16,
@ -37,57 +24,19 @@ def exam_search_chunk_size():
for key in config_dict:
chunk_size = config_dict[key]['chunk_size']
if world_size == 1:
if world_size == 1 or True:
assert chunk_size == 31616
else:
assert chunk_size == 1024
def exam_search_strict_ddp():
world_size = torch.distributed.get_world_size()
default_shard_pg = ProcessGroup(tp_degree=world_size)
default_shard_spec = ShardSpec([-1], [world_size])
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
# get the chunk configuration over replicated models
with ColoInitContext(device=get_current_device()):
ddp_model = model_builder()
re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
search_range_m=1,
search_interval=16,
min_chunk_size_m=0,
filter_exlarge_params=True,
strict_ddp_flag=False)
# get the chunk configuration over sharded ddp models
with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
default_dist_spec=default_shard_spec):
sharded_ddp_model = model_builder()
sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
search_range_m=1,
search_interval=16,
min_chunk_size_m=0,
filter_exlarge_params=True,
strict_ddp_flag=True)
assert re_dict == sh_dict
for key in re_dict:
assert re_dict[key] == sh_dict[key]
assert re_total == sh_total
assert re_wasted == sh_wasted
def exam_chunk_manager():
world_size = torch.distributed.get_world_size()
default_shard_pg = ProcessGroup(tp_degree=world_size)
default_shard_spec = ShardSpec([-1], [world_size])
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
default_dist_spec=default_shard_spec):
sharded_ddp_model = model_builder()
sharded_ddp_model = model_builder()
chunk_manager = init_chunk_manager(sharded_ddp_model,
get_current_device(),
hidden_dim=16,
@ -103,7 +52,6 @@ def exam_chunk_manager():
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_search_chunk_size()
exam_search_strict_ddp()
exam_chunk_manager()

View File

@ -4,31 +4,46 @@ from torch.testing import assert_close
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
from tests.test_tensor.common_utils import set_seed
# Gemini placement settings swept by the parameterized state-dict tests in
# this file; each entry is unpacked as keyword arguments into ``GeminiDDP``.
PLACEMENT_CONFIGS = [
    dict(placement_policy='static', shard_param_frac=0.0),    # zero2
    dict(placement_policy='static', shard_param_frac=1.0),    # zero3
    dict(placement_policy='static', shard_param_frac=0.5),    # zero3-half
    dict(placement_policy='auto'),
]
def ignore_the_first_parameter(model: torch.nn.Module):
for name, param in model.named_parameters():
print(f"parameter `{name}` is set ignored")
ZeroDDP.set_params_to_ignore([param])
GeminiDDP.set_params_to_ignore([param])
return
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gathered', [True, False])
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict(placement_policy, keep_gathered, model_name: str):
def exam_state_dict(placement_config, keep_gathered, model_name: str):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
model = model_builder()
torch_model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
@ -38,9 +53,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
@ -52,16 +65,15 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gathered', [True, False])
@parameterize('model_name', ['gpt2', 'bert'])
def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
@ -71,13 +83,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
torch_dict = torch_model.state_dict()
model.load_state_dict(torch_dict, strict=False)
@ -89,11 +95,37 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict_shard(placement_config, model_name: str):
    """Check ``state_dict_shard`` against the full ``state_dict`` of a GeminiDDP model.

    The model registered under ``model_name`` is wrapped in ``GeminiDDP`` using
    ``placement_config``. Every key yielded by ``state_dict_shard`` must appear
    exactly once across all shards, be present in the full state dict, and carry
    an equal tensor.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    # only the model builder is needed; the other four components are unused here
    model_builder, _, _, _, _ = get_components_func()

    model = model_builder()
    # parameter footprint: bytes divided by 1024**2
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    model_size = param_bytes / 1024**2

    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    model = GeminiDDP(model, config_dict, **placement_config)
    model.train()

    reference = model.state_dict(only_rank_0=False)
    seen_keys = set()
    # ensure number of shards > 1
    shards = model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False)
    for shard, _ in shards:
        for key, value in shard.items():
            assert key not in seen_keys, f"key `{key}` is duplicated."
            seen_keys.add(key)
            assert key in reference, f"{key} not in ZeRO dictionary."
            assert torch.equal(value, reference[key]), f"{key} not equal."
def run_dist(rank, world_size, port):
    """Launch ColossalAI for this rank, then run each state-dict check in order."""
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    exam_state_dict()
    exam_load_state_dict()
    exam_state_dict_shard()
@pytest.mark.dist

View File

@ -1,56 +0,0 @@
import pytest
import torch
from torch.testing import assert_close
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict(placement_policy, model_name: str):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
accumulated_keys = set()
# ensure number of shards > 1
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
for key, value in shard.items():
assert key not in accumulated_keys, f"key `{key}` is duplicated."
accumulated_keys.add(key)
assert key in zero_dict, f"{key} not in ZeRO dictionary."
assert torch.equal(value, zero_dict[key]), f"{key} not equal."
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp_state_dict_shard(world_size):
spawn(run_dist, world_size)
if __name__ == '__main__':
test_zero_ddp_state_dict_shard(1)

View File

@ -5,42 +5,53 @@ import torch.distributed as dist
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
from tests.test_tensor.common_utils import set_seed
# Gemini placement settings swept by the parameterized optimizer state-dict
# test in this file; each entry is unpacked as keyword arguments into
# ``GeminiDDP``.
PLACEMENT_CONFIGS = [
    dict(placement_policy='static', shard_param_frac=0.0, offload_optim_frac=0.0),    # zero2
    dict(placement_policy='static', shard_param_frac=0.0, offload_optim_frac=1.0),    # zero2-offload
    dict(placement_policy='static', shard_param_frac=0.0, offload_optim_frac=0.5),    # zero2-offload-half
    dict(placement_policy='auto'),
]
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gathered', [True, False])
def exam_zero_optim_state_dict(placement_policy, keep_gathered):
def exam_zero_optim_state_dict(placement_config, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
optimizer = HybridAdam(model.parameters())
optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
set_seed(dist.get_rank() * 3 + 128)
model.train()

View File

@ -1,55 +0,0 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def exam_zero_init():
dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2)
model1 = MlpModel().cuda()
with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg):
model2 = MlpModel()
optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1))
optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1))
assert optimizer1._local_rank == optimizer2._local_rank
assert optimizer1._world_size == optimizer2._world_size
mp_group1 = optimizer1.tp_pg
mp_group2 = optimizer2.tp_pg
assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
def run_dist(rank, world_size, port):
config_dict = dict(parallel=dict(data=2, tensor=dict(size=2, mode='1d')))
colossalai.launch(config=config_dict, rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_init()
@pytest.mark.dist
def test_zero_init():
spawn(run_dist, 4)
if __name__ == '__main__':
test_zero_init()

View File

@ -85,6 +85,7 @@ def run_dist(rank, world_size, port):
exam_zero_with_tp()
@pytest.mark.skip('this will be rewritten by shardformer')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_with_tp():