@@ -90,9 +90,11 @@ class ZeroContextConfig(object):
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel group.
+        replicated (bool, optional): Whether the param is replicated across data parallel (DP) group.
+            We do not need to synchronize (reduce) the grads of non-replicated params among the DP group.
+            Some parameters are not replicated, e.g. parameters in MoE experts.
         shard_param (bool, optional): Whether the param is sharded after exiting the context. Defaults to False.
             The process group among which tensors are sharded is assigned as a runtime arg.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove the tensor payload on `param.data` after module init finishes.
             This reduces memory usage when initializing the model.
             But it is not suitable for all models, especially when there are `weight init` operations in `__init__`.
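For orientation, a minimal sketch of how these arguments combine (not part of the diff; the import path and exact keyword signature are assumptions based on the class under change):

    import torch
    from colossalai.zero.init_ctx.init_context import ZeroContextConfig  # assumed path

    # Replicated params that get sharded across the DP group on exit.
    cfg = ZeroContextConfig(target_device=torch.device('cuda:0'),
                            replicated=True,
                            shard_param=True,
                            rm_torch_payload_on_the_fly=False)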
@@ -110,6 +112,9 @@ class ZeroContextConfig(object):
         self.target_device = target_device
+        self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
+        if self.is_replicated is False:
+            assert self.shard_param is False, "ZeroContextConfig shard_param must be False when is_replicated is False"
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
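The new invariant can be exercised directly; a hedged sketch of the rejected combination, continuing the names above:

    # Non-replicated params (e.g. MoE experts) must not be sharded, so the
    # added assert rejects replicated=False combined with shard_param=True.
    try:
        ZeroContextConfig(target_device=torch.device('cuda:0'),
                          replicated=False,
                          shard_param=True)
    except AssertionError as err:
        print(err)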
@@ -117,8 +122,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.

     1. Convert the model to fp16.
-    2. The paramaters of the module are adapted to type ShardedParameter.
-    3. Shard the param and grad according to flags.
+    2. The parameters of the module are adapted to type `ShardedParameter`.
+    3. Shard the param and grad according to the flag `shard_param`.

     Args:
         target_device (torch.device): The device where param data are after exiting the context.
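As a usage sketch of the context itself (assuming a launched distributed environment so the default DP group resolves; the `TensorShardStrategy` import path is assumed):

    import torch
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    # Params are created as fp16 sharded params and sharded on exit.
    with ZeroInitContext(target_device=torch.device('cuda', 0),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = torch.nn.Linear(1024, 1024)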
@@ -144,7 +149,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self.shard_strategy = shard_strategy
-        self.initialized_param_list = []
+        # a list of params that could be sharded
+        self.shardable_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@@ -181,21 +187,17 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """The callback function when exiting the context.
         """
         if not self.rm_torch_payload_on_the_fly:
-            for param in self.initialized_param_list:
+            for param in self.shardable_param_list:
                 assert hasattr(param, 'colo_attr')
                 param.colo_attr.remove_torch_payload()

-        del self.initialized_param_list
+        del self.shardable_param_list

     def _post_init_method(self, module: torch.nn.Module):
         """
         The function to call at the end of the constructor of each module.
         NOTE: The module may be passed to this function multiple times.
         """
-
-        def half_fn(t: torch.Tensor):
-            return t.half() if t.is_floating_point() else t
-
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'colo_attr'):
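`remove_torch_payload` is defined on the param's `colo_attr` elsewhere in the codebase; conceptually (a sketch, not the actual implementation) it frees the torch-side storage once the payload lives in the sharded tensor:

    import torch

    # Sketch: the real payload is held by param.colo_attr.sharded_data_tensor,
    # so param.data can be shrunk to an empty placeholder to save memory.
    def remove_torch_payload_sketch(param: torch.nn.Parameter) -> None:
        param.data = torch.empty(0, dtype=param.dtype, device=param.device)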
@@ -207,10 +209,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             param.is_replicated = self.is_replicated

             # convert parameters to half
-            param_half = half_fn(param)
+            param_half = cast_tensor_to_fp16(param.data)
             param.data = param_half
             if param.grad is not None:
-                grad_half = half_fn(param.grad)
+                grad_half = cast_tensor_to_fp16(param.grad)
                 param.grad.data = grad_half

             # move torch parameters to the target device
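`cast_tensor_to_fp16` is a shared helper replacing the local `half_fn`; a sketch of the expected behavior, mirroring the removed function (only floating-point tensors are halved):

    import torch

    # Sketch of the assumed semantics: fp32 tensors become fp16, everything
    # else (e.g. integer tensors) passes through untouched.
    def cast_tensor_to_fp16_sketch(tensor: torch.Tensor) -> torch.Tensor:
        if tensor.is_floating_point() and tensor.dtype is torch.float32:
            return tensor.half()
        return tensor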
@@ -223,7 +225,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-            self.initialized_param_list.append(param)
+                self.shardable_param_list.append(param)

         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
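What `shard` does to each listed tensor can be illustrated in plain torch terms (an illustration of the chunking idea only; the real strategy also pads the flattened payload so it divides evenly and keeps it inside a sharded tensor wrapper):

    import torch

    # Rank r keeps only chunk r of the flattened payload after sharding.
    def shard_sketch(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        chunks = t.flatten().chunk(world_size)
        return chunks[rank].clone()

    full = torch.randn(8)
    print(shard_sketch(full, rank=0, world_size=4))  # 2 of the 8 elements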
@@ -255,6 +257,16 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
 def no_shard_zero_decrator(is_replicated: bool = True):
+    """
+    A decorator used to wrap the __init__ function of a Module.
+    The parameters initialized by the model will not be sharded.
+    Passing is_replicated=False indicates that the grads of the params won't be reduced among the data parallel process group.
+
+    >>> class MyModule(torch.nn.Module):
+    >>>     @no_shard_zero_decrator(is_replicated=False)
+    >>>     def __init__(self, ...):
+    >>>         ...
+    """

     def _wrapper(init_func):
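The hunk is truncated at `_wrapper`; a plausible completion sketch (an assumption, not the PR's verbatim body) runs the wrapped `__init__` inside `no_shard_zero_context` from the same file:

    from functools import wraps

    # Hypothetical completion: params created in the wrapped __init__ stay
    # unsharded because construction happens inside the no-shard context.
    def no_shard_zero_decrator_sketch(is_replicated: bool = True):
        def _wrapper(init_func):
            @wraps(init_func)
            def _no_shard_init(*args, **kwargs):
                with no_shard_zero_context(is_replicated):
                    init_func(*args, **kwargs)
            return _no_shard_init
        return _wrapper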