diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 52a166d89..ff65b3191 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -90,9 +90,11 @@ class ZeroContextConfig(object):
 
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel group.
+        replicated (bool, optional): Whether the param is replicated across the data parallel (DP) group.
+            We do not need to synchronize (reduce) the grads of replicated params among the DP group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+            The process group among which tensors are sharded is assigned as a runtime arg.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing model.
             But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -110,6 +112,9 @@ class ZeroContextConfig(object):
         self.target_device = target_device
         self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
+
+        if self.is_replicated is False:
+            assert self.shard_param is False, "ZeroContextConfig shard_param must be False when is_replicated is False"
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
 
 
@@ -117,8 +122,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.
 
     1. Convert the model to fp16.
-    2. The paramaters of the module are adapted to type ShardedParameter.
-    3. Shard the param and grad according to flags.
+    2. The parameters of the module are adapted to type `ShardedParameter`.
+    3. Shard the param and grad according to the flag `shard_param`.
 
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
@@ -144,7 +149,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         super().__init__()
         self.shard_strategy = shard_strategy
-        self.initialized_param_list = []
+        # a list of params that could be sharded.
+        self.shardable_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
@@ -181,21 +187,17 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """The callback function when exiting context.
         """
         if not self.rm_torch_payload_on_the_fly:
-            for param in self.initialized_param_list:
+            for param in self.shardable_param_list:
                 assert hasattr(param, 'colo_attr')
                 param.colo_attr.remove_torch_payload()
 
-        del self.initialized_param_list
+        del self.shardable_param_list
 
     def _post_init_method(self, module: torch.nn.Module):
         """
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
""" - - def half_fn(t: torch.Tensor): - return t.half() if t.is_floating_point() else t - for param in module.parameters(recurse=False): # avoid adapting a param to ShardedParam twice if hasattr(param, 'colo_attr'): @@ -207,10 +209,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): param.is_replicated = self.is_replicated # convert parameters to half - param_half = half_fn(param) + param_half = cast_tensor_to_fp16(param.data) param.data = param_half if param.grad is not None: - grad_half = half_fn(param.grad) + grad_half = cast_tensor_to_fp16(param.grad) param.grad.data = grad_half # move torch parameters to the target device @@ -223,7 +225,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): if self.shard_param: self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group) - self.initialized_param_list.append(param) + self.shardable_param_list.append(param) # We must cast buffers # If we use BN, buffers may be on CPU and Float @@ -255,6 +257,16 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager: def no_shard_zero_decrator(is_replicated: bool = True): + """ + A decorator used to wrap an __init__ function of Module. + The parameters initialized by the model will not sharded. + is_replicated indicates the grad of the param won't be reduced among the data parallel process group. + + >>> def MyModule(torch.nn.Module): + >>> @no_shard_zero_decrator(is_replicated = False) + >>> def __init__(self, ...) + >>> .... + """ def _wrapper(init_func):