mirror of https://github.com/hpcaitech/ColossalAI
parent b31daed4cf, commit 036404ca8a
@@ -90,11 +90,9 @@ class ZeroContextConfig(object):

    Args:
        target_device (torch.device): The device where param data are after exiting the context.
        replicated (bool, optional): Whether the param is replicated across the data parallel (DP) group.
            We do not need to synchronize (reduce) the grads of the replicated params among the DP group.
        replicated (bool, optional): Whether the param is replicated across the data parallel group.
            Some parameters are not replicated, e.g. parameters in MOE experts.
        shard_param (bool, optional): Whether the param is sharded after exiting the context. Defaults to False.
            The process group among which tensors are sharded is assigned as a runtime arg.
        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove the tensor payload on `param.data` after module init is finished.
            This will reduce memory usage when initializing the model.
            But it is not suitable for all models, especially when there are `weight init` operations in `__init__`.
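For orientation, here is a minimal construction sketch using the arguments documented above. The import path and the concrete values are assumptions for illustration, not taken from this diff:

    import torch
    # import path assumed: ZeroContextConfig lives in the same file as ZeroInitContext
    from colossalai.zero.init_ctx.init_context import ZeroContextConfig

    # replicated params may be sharded across the DP group once the context exits
    config = ZeroContextConfig(target_device=torch.device('cuda:0'),
                               replicated=True,
                               shard_param=True,
                               rm_torch_payload_on_the_fly=False)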
@@ -112,9 +110,6 @@ class ZeroContextConfig(object):
        self.target_device = target_device
        self.is_replicated: bool = replicated
        self.shard_param: bool = shard_param

        if self.is_replicated is False:
            assert self.shard_param is False, "ZeroContextConfig shard_param must be False when is_replicated is False"
        self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
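A standalone sketch of the consistency rule enforced by the assertion above: non-replicated params (e.g. MOE experts) must not be sharded across the DP group. The function name is illustrative:

    def check_zero_ctx_flags(replicated: bool, shard_param: bool) -> None:
        # mirrors the assertion above: shard_param must be False when replicated is False
        if not replicated:
            assert shard_param is False, "shard_param must be False when is_replicated is False"

    check_zero_ctx_flags(replicated=True, shard_param=True)    # ok: replicated params can be sharded
    check_zero_ctx_flags(replicated=False, shard_param=False)  # ok: MOE expert params stay unsharded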
@@ -122,8 +117,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
    """A context to initialize the model.

    1. Convert the model to fp16.
    2. The parameters of the module are adapted to type `ShardedParameter`.
    3. Shard the param and grad according to flag `shard_param`.
    2. The parameters of the module are adapted to type ShardedParameter.
    3. Shard the param and grad according to flags.

    Args:
        target_device (torch.device): The device where param data are after exiting the context.
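A minimal usage sketch of this context, assuming the ColossalAI runtime has already been launched so that gpc and the default data parallel group exist; the import paths and the strategy class are taken from the same era of the code base and should be treated as assumptions:

    import torch
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    with ZeroInitContext(target_device=torch.device('cuda:0'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = torch.nn.Linear(1024, 1024)
    # on exit, the params are fp16, adapted to sharded params, and sharded over the DP group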
@@ -149,8 +144,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

        super().__init__()
        self.shard_strategy = shard_strategy
        # a list that contains the params that could be sharded
        self.shardable_param_list = []
        self.initialized_param_list = []
        self.model_numel_tensor = model_numel_tensor
        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
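model_numel_tensor serves as a running counter of parameter elements; the sketch below only illustrates that bookkeeping pattern (how the counter is actually updated lies outside this hunk):

    import torch
    import torch.nn as nn

    model_numel_tensor = torch.zeros(1, dtype=torch.int)
    module = nn.Linear(16, 4)
    for param in module.parameters(recurse=False):
        model_numel_tensor += param.numel()    # 16 * 4 weights + 4 biases
    print(model_numel_tensor)                  # tensor([68], dtype=torch.int32)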
@@ -187,17 +181,21 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        """The callback function when exiting context.
        """
        if not self.rm_torch_payload_on_the_fly:
            for param in self.shardable_param_list:
            for param in self.initialized_param_list:
                assert hasattr(param, 'colo_attr')
                param.colo_attr.remove_torch_payload()

        del self.shardable_param_list
        del self.initialized_param_list

    def _post_init_method(self, module: torch.nn.Module):
        """
        The function to call at the end of the constructor of each module.
        NOTE() The module may be passed to this function multiple times.
        """

        def half_fn(t: torch.Tensor):
            return t.half() if t.is_floating_point() else t

        for param in module.parameters(recurse=False):
            # avoid adapting a param to ShardedParam twice
            if hasattr(param, 'colo_attr'):
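Taking the half_fn helper above in isolation: floating-point tensors are cast to fp16, while non-floating tensors pass through unchanged, so integer tensors survive the conversion untouched:

    import torch

    def half_fn(t: torch.Tensor):
        return t.half() if t.is_floating_point() else t

    print(half_fn(torch.randn(2, 2)).dtype)    # torch.float16
    print(half_fn(torch.arange(4)).dtype)      # torch.int64, unchanged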
@@ -209,10 +207,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            param.is_replicated = self.is_replicated

            # convert parameters to half
            param_half = cast_tensor_to_fp16(param.data)
            param_half = half_fn(param)
            param.data = param_half
            if param.grad is not None:
                grad_half = cast_tensor_to_fp16(param.grad)
                grad_half = half_fn(param.grad)
                param.grad.data = grad_half

            # move torch parameters to the target device
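The reassignments through .data and .grad.data above swap the payload to fp16 without replacing the Parameter object itself, so references held elsewhere (e.g. by an optimizer) stay valid. A small toy illustration of that pattern, not the context's actual code:

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)
    layer(torch.randn(1, 4)).sum().backward()            # populate .grad

    for p in layer.parameters(recurse=False):
        p.data = p.data.half() if p.data.is_floating_point() else p.data
        if p.grad is not None:
            p.grad.data = p.grad.data.half()

    print(layer.weight.dtype, layer.weight.grad.dtype)   # torch.float16 torch.float16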
@@ -225,7 +223,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

            if self.shard_param:
                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
            self.shardable_param_list.append(param)
            self.initialized_param_list.append(param)

        # We must cast buffers
        # If we use BN, buffers may be on CPU and Float
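The buffer comments above refer to module buffers such as BatchNorm running statistics, which PyTorch creates as fp32 CPU tensors; a small sketch of casting only the floating-point buffers (the cast the context actually applies may differ):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(8)
    for name, buf in bn.named_buffers():
        # running_mean / running_var are fp32; num_batches_tracked is int64 and must stay integral
        if buf.is_floating_point():
            setattr(bn, name, buf.half())
    print(bn.running_mean.dtype)    # torch.float16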
@@ -257,16 +255,6 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:


def no_shard_zero_decrator(is_replicated: bool = True):
    """
    A decorator used to wrap an __init__ function of a Module.
    The parameters initialized by the model will not be sharded.
    is_replicated indicates that the grad of the param won't be reduced among the data parallel process group.

    >>> class MyModule(torch.nn.Module):
    >>>     @no_shard_zero_decrator(is_replicated=False)
    >>>     def __init__(self, ...):
    >>>         ...
    """

    def _wrapper(init_func):
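The body of _wrapper is cut off by this hunk. Based on the docstring and the no_shard_zero_context helper defined just above, a plausible sketch of the decorator is given below; treat the exact implementation as an assumption:

    import functools

    def no_shard_zero_decrator(is_replicated: bool = True):

        def _wrapper(init_func):

            @functools.wraps(init_func)
            def _no_shard_init(*args, **kwargs):
                # run the wrapped __init__ inside a non-sharding ZeRO context
                with no_shard_zero_context(is_replicated):
                    init_func(*args, **kwargs)

            return _no_shard_init

        return _wrapper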