mirror of https://github.com/hpcaitech/ColossalAI
[zero] zero init ctx enable rm_torch_payload_on_the_fly (#512)
* enable rm_torch_payload_on_the_fly * polish docstrpull/519/head
parent
81145208d1
commit
a2e61d61d4
|
@ -83,22 +83,26 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
|||
|
||||
|
||||
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
r"""
|
||||
A context to initialize model.
|
||||
"""A context to initialize model.
|
||||
|
||||
1. Convert the model to fp16.
|
||||
2. The paramaters of the module are adapted to type ShardedParameter.
|
||||
3. Shard the param and grad according to flags.
|
||||
|
||||
target_device: the device where param data after exiting the context
|
||||
shard_strategy: shard strategy instance
|
||||
shard_param: is param sharded after exiting the context
|
||||
shard_grad: is param sharded after exiting the context
|
||||
|
||||
rm_torch_payload_on_the_fly:
|
||||
True: remove tensor payload on param.data after module init finished.
|
||||
False: remove tensor payload on param.data afther the context exist.
|
||||
Args:
|
||||
convert_fp16 (bool): Whether to convert params to fp16.
|
||||
target_device (torch.device): The device where param data after exiting the context.
|
||||
shard_strategy (BaseShardStrategy): Shard strategy instance.
|
||||
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||
shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
|
||||
This will reduce memory usage when initializing model.
|
||||
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
|
||||
If set to `False`, remove tensor payload on param.data afther the context exist.
|
||||
This is used when you add some logic to operate tensors in __init__ of module.
|
||||
See torchvision resnet18.
|
||||
See torchvision resnet18. Defaults to False.
|
||||
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
|
||||
dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -110,14 +114,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
rm_torch_payload_on_the_fly: bool = False,
|
||||
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
|
||||
dp_process_group: Optional[ProcessGroup] = None):
|
||||
|
||||
super().__init__()
|
||||
self.convert_fp16 = convert_fp16
|
||||
self.target_device = target_device
|
||||
self.shard_param = shard_param
|
||||
self.shard_grad = shard_grad
|
||||
self.shard_strategy = shard_strategy
|
||||
# FIXME(jiaruifang) now setting it to True is invalid.
|
||||
self.rm_torch_payload_on_the_fly = False
|
||||
self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
|
||||
self.initialized_param_list = []
|
||||
self.model_numel_tensor = model_numel_tensor
|
||||
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
|
||||
|
|
Loading…
Reference in New Issue