mirror of https://github.com/hpcaitech/ColossalAI
[zero] polish init context (#645)
parent f5d3a9c2b0
commit 67b4928244
@@ -90,9 +90,11 @@ class ZeroContextConfig(object):
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel group.
+        replicated (bool, optional): Whether the param is replicated across data parallel (DP) group.
+            We do not need to synchronize (reduce) the grads of the replicated params among the DP group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+            The process group among which tensors are sharded is assigned as a runtime arg.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing model.
             But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -110,6 +112,9 @@ class ZeroContextConfig(object):
         self.target_device = target_device
         self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
+
+        if self.is_replicated is False:
+            assert self.shard_param is True, "ZeroContextConfig: shard_param must be True when is_replicated is False"
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
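For clarity, the new assert means a non-replicated config only makes sense when sharding is enabled. A minimal sketch of constructing the config, assuming the keyword names match the docstring and assignments shown above (target_device, replicated, shard_param), with other keywords left at their defaults:

    import torch

    # Replicated params are kept whole here; their grads are reduced across the DP group.
    replicated_cfg = ZeroContextConfig(target_device=torch.device('cuda:0'),
                                       replicated=True,
                                       shard_param=False)

    # Non-replicated params (e.g. MOE experts) must set shard_param=True,
    # otherwise the assert added in this hunk raises.
    moe_cfg = ZeroContextConfig(target_device=torch.device('cuda:0'),
                                replicated=False,
                                shard_param=True)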
@@ -117,8 +122,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.
 
     1. Convert the model to fp16.
-    2. The paramaters of the module are adapted to type ShardedParameter.
-    3. Shard the param and grad according to flags.
+    2. The parameters of the module are adapted to type `ShardedParameter`.
+    3. Shard the param and grad according to flag `shard_param`.
 
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
@@ -144,7 +149,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         super().__init__()
         self.shard_strategy = shard_strategy
-        self.initialized_param_list = []
+        # a list of params that could be sharded
+        self.shardable_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
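A rough usage sketch of the context, not taken from this diff: the `TensorShardStrategy` import path, the `shard_param` keyword on the constructor (only `self.shard_param` is visible in the hunks), and `MyModel` are assumptions, shown only to illustrate how the context is meant to wrap model construction.

    import torch
    from colossalai.zero.shard_utils import TensorShardStrategy  # assumed import path

    with ZeroInitContext(target_device=torch.device('cuda:0'),
                         shard_strategy=TensorShardStrategy(),   # assumed strategy class
                         shard_param=True):                      # assumed constructor keyword
        model = MyModel()  # hypothetical module; its params come out fp16 and sharded across the DP group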
@@ -181,21 +187,17 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """The callback function when exiting context.
         """
         if not self.rm_torch_payload_on_the_fly:
-            for param in self.initialized_param_list:
+            for param in self.shardable_param_list:
                 assert hasattr(param, 'colo_attr')
                 param.colo_attr.remove_torch_payload()
 
-        del self.initialized_param_list
+        del self.shardable_param_list
 
     def _post_init_method(self, module: torch.nn.Module):
         """
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
-
-        def half_fn(t: torch.Tensor):
-            return t.half() if t.is_floating_point() else t
-
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'colo_attr'):
@@ -207,10 +209,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             param.is_replicated = self.is_replicated
 
             # convert parameters to half
-            param_half = half_fn(param)
+            param_half = cast_tensor_to_fp16(param.data)
             param.data = param_half
             if param.grad is not None:
-                grad_half = half_fn(param.grad)
+                grad_half = cast_tensor_to_fp16(param.grad)
                 param.grad.data = grad_half
 
             # move torch parameters to the target device
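For reference, the removed `half_fn` closure only cast floating-point tensors and left others untouched. A standalone helper with that behavior would look roughly like the sketch below; the name `_cast_to_fp16` is hypothetical, and the real `cast_tensor_to_fp16` imported by this file may differ in detail.

    import torch

    def _cast_to_fp16(t: torch.Tensor) -> torch.Tensor:
        # Mirror the removed half_fn: cast only floating-point tensors to fp16,
        # return integer/bool tensors unchanged.
        return t.half() if t.is_floating_point() else t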
@@ -223,7 +225,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             if self.shard_param:
                 self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-            self.initialized_param_list.append(param)
+            self.shardable_param_list.append(param)
 
         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
@@ -255,6 +257,16 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
 
 
 def no_shard_zero_decrator(is_replicated: bool = True):
+    """
+    A decorator used to wrap an __init__ function of a Module.
+    The parameters initialized by the model will not be sharded.
+    is_replicated indicates that the grad of the param won't be reduced among the data parallel process group.
+
+    >>> class MyModule(torch.nn.Module):
+    >>>     @no_shard_zero_decrator(is_replicated=False)
+    >>>     def __init__(self, ...):
+    >>>         ....
+    """
 
     def _wrapper(init_func):
 
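The body of `_wrapper` is cut off by the hunk boundary. A plausible completion, assuming it simply runs the wrapped `__init__` inside `no_shard_zero_context` (defined earlier in this file), is sketched below:

    def no_shard_zero_decrator(is_replicated: bool = True):

        def _wrapper(init_func):

            def _no_shard(*args, **kwargs):
                # run the wrapped __init__ under the no-shard context,
                # so the params it creates are kept whole (not sharded)
                with no_shard_zero_context(is_replicated):
                    init_func(*args, **kwargs)

            return _no_shard

        return _wrapper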