Revert "[zero] polish init context (#645)" (#657)

Jiarui Fang 2022-04-02 18:30:06 +08:00 committed by GitHub
parent b31daed4cf
commit 036404ca8a
1 changed file with 13 additions and 25 deletions


@@ -90,11 +90,9 @@ class ZeroContextConfig(object):
Args:
target_device (torch.device): The device where param data are after exiting the context.
replicated (bool, optional): Whether the param is replicated across data parallel (DP) group.
We do not need to synchronize (reduce) the grads of the replicated params among DP group.
replicated (bool, optional): Whether the param is replicated across data parallel group.
Some parameters are not replicated, e.g. parameters in MOE experts.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
The process group among which tensors are sharded is assigned as a runtime arg.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init has finished.
This will reduce memory usage when initializing the model.
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
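For reference, a minimal sketch of constructing this config with the arguments documented above; the import path is an assumption and may vary between ColossalAI versions.

import torch
from colossalai.zero.init_ctx.init_context import ZeroContextConfig  # assumed path

cfg = ZeroContextConfig(
    target_device=torch.device('cuda:0'),   # params end up here when the context exits
    replicated=True,                        # keep the param identical across the DP group
    shard_param=False,                      # do not shard the param on exit
    rm_torch_payload_on_the_fly=False,      # keep the torch payload until module init finishes
)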
@@ -112,9 +110,6 @@ class ZeroContextConfig(object):
self.target_device = target_device
self.is_replicated: bool = replicated
self.shard_param: bool = shard_param
if self.is_replicated is False:
assert self.shard_param is True, f"ZeroContextConfig shard_param must be False when is_replicated is False"
self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
@@ -122,8 +117,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""A context to initialize model.
1. Convert the model to fp16.
2. The parameters of the module are adapted to type `ShardedParameter`.
3. Shard the param and grad according to flag `shard_param`.
2. The parameters of the module are adapted to type ShardedParameter.
3. Shard the param and grad according to flags.
Args:
target_device (torch.device): The device where param data are after exiting the context.
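A sketch of typical usage of this context, following the docstring above; the import paths and the TensorShardStrategy choice are assumptions, and the model is only a placeholder.

import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext          # assumed path
from colossalai.zero.shard_utils import TensorShardStrategy   # assumed path

with ZeroInitContext(target_device=torch.device('cuda:0'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    # params created here are cast to fp16 and sharded across the DP group on exit
    model = nn.Linear(1024, 1024)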
@@ -149,8 +144,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
super().__init__()
self.shard_strategy = shard_strategy
# a list that contains params that could be sharded.
self.shardable_param_list = []
self.initialized_param_list = []
self.model_numel_tensor = model_numel_tensor
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
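A sketch of the process-group fallback in the restored __init__ above; the gpc and ParallelMode import paths are assumptions based on common ColossalAI usage.

from colossalai.core import global_context as gpc   # assumed path
from colossalai.context import ParallelMode         # assumed path

def resolve_dp_group(dp_process_group=None):
    # use the given group, otherwise fall back to the global data-parallel group
    return dp_process_group or gpc.get_group(ParallelMode.DATA)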
@@ -187,17 +181,21 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""The callback function when exiting context.
"""
if not self.rm_torch_payload_on_the_fly:
for param in self.shardable_param_list:
for param in self.initialized_param_list:
assert hasattr(param, 'colo_attr')
param.colo_attr.remove_torch_payload()
del self.shardable_param_list
del self.initialized_param_list
def _post_init_method(self, module: torch.nn.Module):
"""
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
"""
def half_fn(t: torch.Tensor):
return t.half() if t.is_floating_point() else t
for param in module.parameters(recurse=False):
# avoid adapting a param to ShardedParam twice
if hasattr(param, 'colo_attr'):
@@ -209,10 +207,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
param.is_replicated = self.is_replicated
# convert parameters to half
param_half = cast_tensor_to_fp16(param.data)
param_half = half_fn(param)
param.data = param_half
if param.grad is not None:
grad_half = cast_tensor_to_fp16(param.grad)
grad_half = half_fn(param.grad)
param.grad.data = grad_half
# move torch parameters to the target device
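A standalone illustration of the selective fp16 cast restored above: half_fn converts floating-point tensors to half precision and leaves every other dtype untouched.

import torch

def half_fn(t: torch.Tensor) -> torch.Tensor:
    return t.half() if t.is_floating_point() else t

weight = torch.randn(4, 4)
indices = torch.arange(4)
assert half_fn(weight).dtype == torch.float16
assert half_fn(indices).dtype == torch.int64   # non-floating tensors pass through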
@@ -225,7 +223,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
self.shardable_param_list.append(param)
self.initialized_param_list.append(param)
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
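A sketch of the buffer handling the comment above refers to, using BatchNorm as an example; this is illustrative only and device movement is omitted.

import torch

bn = torch.nn.BatchNorm2d(8)
for name, buf in list(bn.named_buffers()):
    # cast fp32 running stats to half, keep integer buffers (num_batches_tracked) as-is
    setattr(bn, name, buf.half() if buf.is_floating_point() else buf)

assert bn.running_mean.dtype == torch.float16
assert bn.num_batches_tracked.dtype == torch.long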
@@ -257,16 +255,6 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
def no_shard_zero_decrator(is_replicated: bool = True):
"""
A decorator used to wrap the __init__ function of a Module.
The parameters initialized by the module will not be sharded.
Setting is_replicated to False indicates that the grads of the params won't be reduced among the data parallel process group.
>>> class MyModule(torch.nn.Module):
>>>     @no_shard_zero_decrator(is_replicated=False)
>>>     def __init__(self, ...):
>>>         ...
"""
def _wrapper(init_func):
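The body of _wrapper is truncated in this view; a hedged sketch of the pattern the docstring describes, assuming the wrapped __init__ simply runs inside no_shard_zero_context (functools.wraps is added here for illustration).

import functools

def no_shard_zero_decrator(is_replicated: bool = True):

    def _wrapper(init_func):

        @functools.wraps(init_func)
        def _no_shard_init(*args, **kwargs):
            # run the module's __init__ inside a no-shard zero context so its
            # params are left unsharded (and, with is_replicated=False, their
            # grads are not reduced across the data parallel group)
            with no_shard_zero_context(is_replicated):
                init_func(*args, **kwargs)

        return _no_shard_init

    return _wrapper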