[zero] zero init ctx enable rm_torch_payload_on_the_fly (#512)

* enable rm_torch_payload_on_the_fly

* polish docstr
ver217 2022-03-24 23:44:00 +08:00 committed by GitHub
parent 81145208d1
commit a2e61d61d4
1 changed file with 17 additions and 13 deletions

@@ -83,22 +83,26 @@ class InsertPostInitMethodToModuleSubClasses(object):
 
 
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    r"""
-    A context to initialize model.
+    """A context to initialize model.
+
     1. Convert the model to fp16.
     2. The paramaters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
-    target_device: the device where param data after exiting the context
-    shard_strategy: shard strategy instance
-    shard_param: is param sharded after exiting the context
-    shard_grad: is param sharded after exiting the context
-    rm_torch_payload_on_the_fly:
-        True: remove tensor payload on param.data after module init finished.
-        False: remove tensor payload on param.data afther the context exist.
-        This is used when you add some logic to operate tensors in __init__ of module.
-        See torchvision resnet18.
+    Args:
+        convert_fp16 (bool): Whether to convert params to fp16.
+        target_device (torch.device): The device where param data after exiting the context.
+        shard_strategy (BaseShardStrategy): Shard strategy instance.
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
+            This will reduce memory usage when initializing model.
+            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on param.data afther the context exist.
+            This is used when you add some logic to operate tensors in __init__ of module.
+            See torchvision resnet18. Defaults to False.
+        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
+        dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
     """
 
     def __init__(self,
@@ -110,14 +114,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  rm_torch_payload_on_the_fly: bool = False,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
                  dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
-        # FIXME(jiaruifang) now setting it to True is invalid.
-        self.rm_torch_payload_on_the_fly = False
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
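After this commit, passing rm_torch_payload_on_the_fly=True is honored instead of being silently forced to False. Below is a minimal usage sketch, not an official example: the import paths and TensorShardStrategy reflect the colossalai.zero layout of this era and may differ between versions, the toy model is hypothetical, and a distributed context (e.g. via colossalai.launch) is assumed to be initialized so that the default data-parallel process group exists.

# Hypothetical usage sketch; import paths and the toy model are illustrative.
# Assumes the distributed context has already been set up (e.g. colossalai.launch).
import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

model_numel = torch.zeros(1, dtype=torch.int)

with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cuda:0'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     # With this change the flag takes effect: each submodule's torch
                     # payload is removed as soon as its __init__ finishes, lowering
                     # peak memory during model construction.
                     rm_torch_payload_on_the_fly=True,
                     model_numel_tensor=model_numel):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))

print(f'model has {model_numel.item()} elements')

As the updated docstring warns, the on-the-fly mode is not suitable for modules that operate on their weights inside __init__ (e.g. torchvision's resnet18), since the payload may already have been removed; keep the flag at its default False in that case.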