mirror of https://github.com/hpcaitech/ColossalAI
[zero] zero init ctx enable rm_torch_payload_on_the_fly (#512)
* enable rm_torch_payload_on_the_fly
* polish docstr

pull/519/head
parent 81145208d1
commit a2e61d61d4
@@ -83,22 +83,26 @@ class InsertPostInitMethodToModuleSubClasses(object):

 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    r"""
-    A context to initialize model.
+    """A context to initialize model.
+
     1. Convert the model to fp16.
     2. The paramaters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.

-    target_device: the device where param data after exiting the context
-    shard_strategy: shard strategy instance
-    shard_param: is param sharded after exiting the context
-    shard_grad: is param sharded after exiting the context
-    rm_torch_payload_on_the_fly:
-        True: remove tensor payload on param.data after module init finished.
-        False: remove tensor payload on param.data afther the context exist.
-            This is used when you add some logic to operate tensors in __init__ of module.
-            See torchvision resnet18.
+    Args:
+        convert_fp16 (bool): Whether to convert params to fp16.
+        target_device (torch.device): The device where param data after exiting the context.
+        shard_strategy (BaseShardStrategy): Shard strategy instance.
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
+            This will reduce memory usage when initializing model.
+            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on param.data afther the context exist.
+            This is used when you add some logic to operate tensors in __init__ of module.
+            See torchvision resnet18. Defaults to False.
+        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
+        dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
     """

     def __init__(self,
@@ -110,14 +114,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  rm_torch_payload_on_the_fly: bool = False,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
                  dp_process_group: Optional[ProcessGroup] = None):

         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
-        # FIXME(jiaruifang) now setting it to True is invalid.
-        self.rm_torch_payload_on_the_fly = False
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
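Not part of the commit: a minimal usage sketch of the flag this change enables. The constructor arguments are taken from the diff above; the import paths (colossalai.zero.init_ctx, colossalai.zero.shard_utils.TensorShardStrategy) and the need to launch ColossalAI first are assumptions about the library layout of this era and may differ in your version.

    import torch
    # Assumed import paths; check your ColossalAI version.
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    # colossalai.launch(...) (or launch_from_torch) must have been called first,
    # so that the default data-parallel process group used by the context exists.
    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device('cuda', torch.cuda.current_device()),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True,
                         rm_torch_payload_on_the_fly=True):
        # Parameters created inside the context are converted to fp16 and
        # adapted to ShardedParameter; with rm_torch_payload_on_the_fly=True
        # the original torch payload on param.data is released as soon as each
        # module finishes __init__, lowering peak memory during construction.
        model = torch.nn.Sequential(
            torch.nn.Linear(1024, 4096),
            torch.nn.GELU(),
            torch.nn.Linear(4096, 1024),
        )

As the docstring notes, keeping the flag False is safer for modules that touch their tensors inside __init__ (e.g. torchvision's resnet18 weight init), since the payload is then only removed when the context exits.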