|
|
|
@ -83,22 +83,26 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): |
|
|
|
|
r""" |
|
|
|
|
A context to initialize model. |
|
|
|
|
"""A context to initialize model. |
|
|
|
|
|
|
|
|
|
1. Convert the model to fp16. |
|
|
|
|
2. The paramaters of the module are adapted to type ShardedParameter. |
|
|
|
|
3. Shard the param and grad according to flags. |
|
|
|
|
|
|
|
|
|
target_device: the device where param data after exiting the context |
|
|
|
|
shard_strategy: shard strategy instance |
|
|
|
|
shard_param: is param sharded after exiting the context |
|
|
|
|
shard_grad: is param sharded after exiting the context |
|
|
|
|
|
|
|
|
|
rm_torch_payload_on_the_fly: |
|
|
|
|
True: remove tensor payload on param.data after module init finished. |
|
|
|
|
False: remove tensor payload on param.data afther the context exist. |
|
|
|
|
Args: |
|
|
|
|
convert_fp16 (bool): Whether to convert params to fp16. |
|
|
|
|
target_device (torch.device): The device where param data after exiting the context. |
|
|
|
|
shard_strategy (BaseShardStrategy): Shard strategy instance. |
|
|
|
|
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. |
|
|
|
|
shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False. |
|
|
|
|
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished. |
|
|
|
|
This will reduce memory usage when initializing model. |
|
|
|
|
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`. |
|
|
|
|
If set to `False`, remove tensor payload on param.data afther the context exist. |
|
|
|
|
This is used when you add some logic to operate tensors in __init__ of module. |
|
|
|
|
See torchvision resnet18. |
|
|
|
|
See torchvision resnet18. Defaults to False. |
|
|
|
|
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int). |
|
|
|
|
dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
@ -110,14 +114,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|
|
|
|
rm_torch_payload_on_the_fly: bool = False, |
|
|
|
|
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int), |
|
|
|
|
dp_process_group: Optional[ProcessGroup] = None): |
|
|
|
|
|
|
|
|
|
super().__init__() |
|
|
|
|
self.convert_fp16 = convert_fp16 |
|
|
|
|
self.target_device = target_device |
|
|
|
|
self.shard_param = shard_param |
|
|
|
|
self.shard_grad = shard_grad |
|
|
|
|
self.shard_strategy = shard_strategy |
|
|
|
|
# FIXME(jiaruifang) now setting it to True is invalid. |
|
|
|
|
self.rm_torch_payload_on_the_fly = False |
|
|
|
|
self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly |
|
|
|
|
self.initialized_param_list = [] |
|
|
|
|
self.model_numel_tensor = model_numel_tensor |
|
|
|
|
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) |
|
|
|
|