[zero] polish init context (#645)

pull/657/head
Jiarui Fang 2022-04-02 15:52:04 +08:00 committed by GitHub
parent f5d3a9c2b0
commit 67b4928244
1 changed file with 25 additions and 13 deletions

@@ -90,9 +90,11 @@ class ZeroContextConfig(object):
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel group.
+        replicated (bool, optional): Whether the param is replicated across data parallel (DP) group.
+            We do not need to synchronize (reduce) the grads of the non-replicated params among the DP group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+            The process group among which tensors are sharded is assigned as a runtime arg.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing model.
             But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
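
For reference, a minimal sketch of how the flags documented above combine into a config object. It assumes the constructor signature shown in this hunk; the import path is a guess for illustration and is not part of this commit.

    import torch
    # Assumed import path; ZeroContextConfig is defined in the file this commit touches.
    from colossalai.zero.init_ctx.init_context import ZeroContextConfig

    config = ZeroContextConfig(
        target_device=torch.device('cuda:0'),   # where param data live after the context exits
        replicated=True,                        # grads of these params are reduced across the DP group
        shard_param=True,                       # shard param data over the runtime-assigned process group
        rm_torch_payload_on_the_fly=False,      # keep the torch payload until module init finishes
    )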
@@ -110,6 +112,9 @@ class ZeroContextConfig(object):
         self.target_device = target_device
         self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
+        if self.is_replicated is False:
+            assert self.shard_param is False, "ZeroContextConfig shard_param must be False when is_replicated is False"
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
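
The new check encodes one rule: a non-replicated parameter (e.g. an MoE expert weight) must not be ZeRO-sharded across the DP group. A hedged sketch of the combinations this allows, reusing the assumed import from the sketch above:

    import torch
    from colossalai.zero.init_ctx.init_context import ZeroContextConfig   # assumed import path

    device = torch.device('cuda:0')

    # Replicated params may be kept whole or sharded across the DP group.
    ZeroContextConfig(target_device=device, replicated=True, shard_param=False)
    ZeroContextConfig(target_device=device, replicated=True, shard_param=True)

    # Non-replicated params (e.g. MoE experts) must keep shard_param=False;
    # setting shard_param=True here would trip the assertion added above.
    ZeroContextConfig(target_device=device, replicated=False, shard_param=False)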
@@ -117,8 +122,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.

     1. Convert the model to fp16.
-    2. The paramaters of the module are adapted to type ShardedParameter.
-    3. Shard the param and grad according to flags.
+    2. The parameters of the module are adapted to type `ShardedParameter`.
+    3. Shard the param and grad according to flag `shard_param`.

     Args:
         target_device (torch.device): The device where param data are after exiting the context.
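
A hedged end-to-end sketch of how this context is typically entered. The shard strategy class name and both import paths are assumptions for illustration, not part of this diff.

    import torch
    import torch.nn as nn
    from colossalai.zero.init_ctx import ZeroInitContext          # assumed import path
    from colossalai.zero.shard_utils import TensorShardStrategy   # assumed class and path

    with ZeroInitContext(target_device=torch.device('cuda:0'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        # Params created here leave the context as fp16 ShardedParameter tensors,
        # sharded across the data parallel group.
        model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))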
@@ -144,7 +149,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self.shard_strategy = shard_strategy
-        self.initialized_param_list = []
+        # a list containing params that could be sharded
+        self.shardable_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@@ -181,21 +187,17 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """The callback function when exiting context.
         """
         if not self.rm_torch_payload_on_the_fly:
-            for param in self.initialized_param_list:
+            for param in self.shardable_param_list:
                 assert hasattr(param, 'colo_attr')
                 param.colo_attr.remove_torch_payload()
-            del self.initialized_param_list
+            del self.shardable_param_list

     def _post_init_method(self, module: torch.nn.Module):
         """
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
-        def half_fn(t: torch.Tensor):
-            return t.half() if t.is_floating_point() else t
-
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'colo_attr'):
@@ -207,10 +209,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             param.is_replicated = self.is_replicated

             # convert parameters to half
-            param_half = half_fn(param)
+            param_half = cast_tensor_to_fp16(param.data)
             param.data = param_half
             if param.grad is not None:
-                grad_half = half_fn(param.grad)
+                grad_half = cast_tensor_to_fp16(param.grad)
                 param.grad.data = grad_half

             # move torch parameters to the target device
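
The commit swaps the removed local half_fn helper for cast_tensor_to_fp16. The casting rule implied by the removed helper is roughly the following; this is a self-contained sketch, not the library's implementation.

    import torch

    def cast_to_fp16_sketch(t: torch.Tensor) -> torch.Tensor:
        # Cast only floating-point tensors to half precision; integer/bool tensors pass through.
        return t.half() if t.is_floating_point() else t

    assert cast_to_fp16_sketch(torch.zeros(4, dtype=torch.float32)).dtype is torch.float16
    assert cast_to_fp16_sketch(torch.zeros(4, dtype=torch.int64)).dtype is torch.int64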
@@ -223,7 +225,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-            self.initialized_param_list.append(param)
+            self.shardable_param_list.append(param)

         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
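
The comments above note that buffers (e.g. BatchNorm running stats) may still sit on the CPU in fp32 and need the same treatment as params. A hedged sketch of what such a buffer pass typically looks like; it is not the literal code from this file.

    import torch

    def cast_buffers_sketch(module: torch.nn.Module, target_device: torch.device) -> None:
        # Move each buffer to the target device; cast floating-point buffers to half while
        # integer buffers such as BatchNorm's num_batches_tracked keep their dtype.
        for name, buf in module.named_buffers(recurse=False):
            buf = buf.to(target_device)
            module._buffers[name] = buf.half() if buf.is_floating_point() else buf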
@@ -255,6 +257,16 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:

 def no_shard_zero_decrator(is_replicated: bool = True):
"""
A decorator used to wrap an __init__ function of Module.
The parameters initialized by the model will not sharded.
is_replicated indicates the grad of the param won't be reduced among the data parallel process group.
>>> def MyModule(torch.nn.Module):
>>> @no_shard_zero_decrator(is_replicated = False)
>>> def __init__(self, ...)
>>> ....
"""
     def _wrapper(init_func):
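
The diff cuts off at _wrapper. For orientation, a hedged sketch of how such a decorator typically wraps __init__ with the no_shard_zero_context helper named in the hunk header; the body below is an illustration, not the committed code.

    from functools import wraps

    def no_shard_zero_decrator_sketch(is_replicated: bool = True):
        def _wrapper(init_func):
            @wraps(init_func)
            def _no_shard_init(*args, **kwargs):
                # Run the wrapped __init__ under a no-shard zero context, so the params it
                # creates are left unsharded (and optionally marked as non-replicated).
                # no_shard_zero_context is the context-manager factory defined just above in this file.
                with no_shard_zero_context(is_replicated):
                    init_func(*args, **kwargs)
            return _no_shard_init
        return _wrapper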