diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 53889db6d..321548f98 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -83,22 +83,26 @@ class InsertPostInitMethodToModuleSubClasses(object):
 
 
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    r"""
-    A context to initialize model.
+    """A context to initialize model.
+
     1. Convert the model to fp16.
     2. The paramaters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
 
-    target_device: the device where param data after exiting the context
-    shard_strategy: shard strategy instance
-    shard_param: is param sharded after exiting the context
-    shard_grad: is param sharded after exiting the context
-
-    rm_torch_payload_on_the_fly:
-    True: remove tensor payload on param.data after module init finished.
-    False: remove tensor payload on param.data afther the context exist.
-    This is used when you add some logic to operate tensors in __init__ of module.
-    See torchvision resnet18.
+    Args:
+        convert_fp16 (bool): Whether to convert params to fp16.
+        target_device (torch.device): The device where param data resides after exiting the context.
+        shard_strategy (BaseShardStrategy): Shard strategy instance.
+        shard_param (bool, optional): Whether the param is sharded after exiting the context. Defaults to False.
+        shard_grad (bool, optional): Whether the grad is sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finishes.
+            This reduces memory usage when initializing the model,
+            but it is not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on `param.data` after the context exits.
+            This is used when you add some logic to operate on tensors in `__init__` of a module.
+            See torchvision resnet18. Defaults to False.
+        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of the model. Defaults to torch.zeros(1, dtype=torch.int).
+        dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
     """
 
     def __init__(self,
@@ -110,14 +114,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  rm_torch_payload_on_the_fly: bool = False,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
                  dp_process_group: Optional[ProcessGroup] = None):
+        super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
-        # FIXME(jiaruifang) now setting it to True is invalid.
-        self.rm_torch_payload_on_the_fly = False
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
        self.initialized_param_list = []
        self.model_numel_tensor = model_numel_tensor
        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
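
For reference, a minimal usage sketch of the constructor documented above (not part of this change). It assumes the distributed environment has already been initialized, e.g. via `colossalai.launch`, so that `gpc.get_group(ParallelMode.DATA)` resolves, and that a shard strategy such as `TensorShardStrategy` is importable from `colossalai.zero.shard_utils`; both are assumptions outside this diff.

```python
import torch
import torchvision.models as models

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy  # assumed import path

# Tensor that will accumulate the total number of model elements.
model_numel = torch.zeros(1, dtype=torch.int)

# Build the model inside the context: params are converted to fp16 and
# sharded by the given strategy when the context exits.
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cuda', torch.cuda.current_device()),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     rm_torch_payload_on_the_fly=False,
                     model_numel_tensor=model_numel):
    model = models.resnet18()

print(f'number of model elements: {model_numel.item()}')
```

`rm_torch_payload_on_the_fly` is left `False` here because torchvision's resnet18 runs weight-init logic inside `__init__`, which is exactly the case the docstring calls out.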