
[zero] zero init ctx enable rm_torch_payload_on_the_fly (#512)

* enable rm_torch_payload_on_the_fly

* polish docstr
Committed by ver217 via GitHub, 3 years ago · commit a2e61d61d4
1 changed file: colossalai/zero/init_ctx/init_context.py (30 lines)

@@ -83,22 +83,26 @@ class InsertPostInitMethodToModuleSubClasses(object):
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    r"""
-    A context to initialize model.
+    """A context to initialize model.
     1. Convert the model to fp16.
     2. The parameters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
-    target_device: the device where param data is placed after exiting the context
-    shard_strategy: shard strategy instance
-    shard_param: is param sharded after exiting the context
-    shard_grad: is grad sharded after exiting the context
-    rm_torch_payload_on_the_fly:
-        True: remove tensor payload on param.data after module init finished.
-        False: remove tensor payload on param.data after the context exits.
+    Args:
+        convert_fp16 (bool): Whether to convert params to fp16.
+        target_device (torch.device): The device where param data is placed after exiting the context.
+        shard_strategy (BaseShardStrategy): Shard strategy instance.
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        shard_grad (bool, optional): Is grad sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finishes.
+            This reduces memory usage when initializing the model,
+            but it is not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on param.data after the context exits.
             This is used when you add some logic to operate on tensors in the `__init__` of a module.
-        See torchvision resnet18.
+            See torchvision resnet18. Defaults to False.
+        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of the model. Defaults to torch.zeros(1, dtype=torch.int).
+        dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
     """
     def __init__(self,
@@ -110,14 +114,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  rm_torch_payload_on_the_fly: bool = False,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
                  dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
-        # FIXME(jiaruifang) now setting it to True is invalid.
-        self.rm_torch_payload_on_the_fly = False
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
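
After this change, the flag passed to the constructor is actually honored instead of being hard-coded to False. A minimal usage sketch; the import paths, TensorShardStrategy, and launch_from_torch call are assumptions about the surrounding ColossalAI API of this era, not shown in this diff, and the context requires distributed groups to be initialized first:

import torch
import torch.nn as nn
import colossalai
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

colossalai.launch_from_torch(config={})  # sets up the default DATA process group

# Build the model inside the context: params are converted to fp16, sharded,
# and (with the flag on) their torch payload is freed as soon as each
# submodule finishes __init__, keeping peak init memory low.
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     rm_torch_payload_on_the_fly=True):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))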

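The docstring's caveat is why the flag still defaults to False: if a module's __init__ goes on to touch param.data after constructing a submodule, the payload must still exist at that point. A hypothetical module (not from the repo) illustrating the pattern; torchvision's resnet18 does the same kind of in-place weight init in its constructor:

import torch.nn as nn

class InitInCtor(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3)
        # This in-place init reads and writes self.conv.weight.data. With
        # rm_torch_payload_on_the_fly=True the payload may already have been
        # removed by the time this line runs, so such models need the
        # default (False), which defers payload removal to context exit.
        nn.init.kaiming_normal_(self.conv.weight)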