
[zero] polish init context (#645)

pull/657/head
Jiarui Fang, 3 years ago, committed by GitHub
commit 67b4928244
1 changed file: colossalai/zero/init_ctx/init_context.py (38 changed lines)
@@ -90,9 +90,11 @@ class ZeroContextConfig(object):
Args:
target_device (torch.device): The device where param data are after exiting the context.
- replicated (bool, optional): Whether the param is replicated across data parallel group.
+ replicated (bool, optional): Whether the param is replicated across the data parallel (DP) group.
+ We do not need to synchronize (reduce) the grads of the replicated params among the DP group.
Some parameters are not replicated, e.g. parameters in MOE experts.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+ The process group among which tensors are sharded is assigned as a runtime arg.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
This will reduce memory usage when initializing model.
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -110,6 +112,9 @@ class ZeroContextConfig(object):
self.target_device = target_device
self.is_replicated: bool = replicated
self.shard_param: bool = shard_param
+ if self.is_replicated is False:
+     assert self.shard_param is True, "ZeroContextConfig: shard_param must be True when is_replicated is False"
self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
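For context, a minimal sketch of what the new check enforces, assuming ZeroContextConfig is importable from the module shown in this diff and that its constructor takes the keyword arguments listed in the docstring above (target_device, replicated, shard_param); none of this code is part of the commit itself:

    import torch
    from colossalai.zero.init_ctx.init_context import ZeroContextConfig

    # Replicated, unsharded params: the plain (non-ZeRO) case is still allowed.
    plain_cfg = ZeroContextConfig(target_device=torch.device('cuda'), replicated=True, shard_param=False)

    # Non-replicated params (e.g. MoE experts) must now be sharded.
    moe_cfg = ZeroContextConfig(target_device=torch.device('cuda'), replicated=False, shard_param=True)

    # ZeroContextConfig(target_device=torch.device('cuda'), replicated=False, shard_param=False)
    # would trip the assertion added above.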
@@ -117,8 +122,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""A context to initialize model.
1. Convert the model to fp16.
- 2. The paramaters of the module are adapted to type ShardedParameter.
- 3. Shard the param and grad according to flags.
+ 2. The parameters of the module are adapted to type `ShardedParameter`.
+ 3. Shard the param and grad according to the flag `shard_param`.
Args:
target_device (torch.device): The device where param data are after exiting the context.
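As a reminder of how this context is used (not part of this diff), a hedged sketch in the style of ColossalAI examples of that era; the import paths are inferred from the file layout above, and TensorShardStrategy and MyModel are assumptions:

    import torch
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    with ZeroInitContext(target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        # params created inside the context are cast to fp16 and sharded
        # across the data parallel group
        model = MyModel()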
@@ -144,7 +149,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
super().__init__()
self.shard_strategy = shard_strategy
- self.initialized_param_list = []
+ # a list containing the params that could be sharded
+ self.shardable_param_list = []
self.model_numel_tensor = model_numel_tensor
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@@ -181,21 +187,17 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""The callback function when exiting context.
"""
if not self.rm_torch_payload_on_the_fly:
- for param in self.initialized_param_list:
+ for param in self.shardable_param_list:
assert hasattr(param, 'colo_attr')
param.colo_attr.remove_torch_payload()
- del self.initialized_param_list
+ del self.shardable_param_list
def _post_init_method(self, module: torch.nn.Module):
"""
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
"""
- def half_fn(t: torch.Tensor):
-     return t.half() if t.is_floating_point() else t
for param in module.parameters(recurse=False):
# avoid adapting a param to ShardedParam twice
if hasattr(param, 'colo_attr'):
@@ -207,10 +209,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
param.is_replicated = self.is_replicated
# convert parameters to half
- param_half = half_fn(param)
+ param_half = cast_tensor_to_fp16(param.data)
param.data = param_half
if param.grad is not None:
- grad_half = half_fn(param.grad)
+ grad_half = cast_tensor_to_fp16(param.grad)
param.grad.data = grad_half
# move torch parameters to the target device
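The local half_fn helper is replaced by cast_tensor_to_fp16 above. Both are expected to cast only floating-point tensors to half precision and pass integer tensors through untouched; here is a hedged stand-in (cast_tensor_to_fp16_sketch is an illustrative name, not the library function) showing that behavior:

    import torch

    def cast_tensor_to_fp16_sketch(t: torch.Tensor) -> torch.Tensor:
        # Cast floating-point tensors to fp16 and leave other dtypes as-is,
        # mirroring what the removed half_fn did.
        return t.half() if t.is_floating_point() else t

    assert cast_tensor_to_fp16_sketch(torch.randn(4)).dtype == torch.float16
    assert cast_tensor_to_fp16_sketch(torch.arange(4)).dtype == torch.int64  # ints untouched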
@@ -223,7 +225,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
- self.initialized_param_list.append(param)
+ self.shardable_param_list.append(param)
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
@@ -255,6 +257,16 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
def no_shard_zero_decrator(is_replicated: bool = True):
"""
A decorator used to wrap an __init__ function of Module.
The parameters initialized by the model will not sharded.
is_replicated indicates the grad of the param won't be reduced among the data parallel process group.
>>> def MyModule(torch.nn.Module):
>>> @no_shard_zero_decrator(is_replicated = False)
>>> def __init__(self, ...)
>>> ....
"""
def _wrapper(init_func):
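The body of _wrapper is cut off in this view. Purely as a hypothetical sketch (functools.wraps, the _no_shard_init name, and the exact body are assumptions, not the committed code), a decorator like this would typically run the wrapped __init__ inside no_shard_zero_context from the same file:

    import functools

    def no_shard_zero_decrator_sketch(is_replicated: bool = True):
        def _wrapper(init_func):
            @functools.wraps(init_func)
            def _no_shard_init(self, *args, **kwargs):
                # create this module's params without sharding them
                with no_shard_zero_context(is_replicated):
                    init_func(self, *args, **kwargs)
            return _no_shard_init
        return _wrapper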
