diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 572ddd9e4..b40b69962 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -1,46 +1,45 @@
 import contextlib
 import functools
-from typing import Optional
 from contextlib import AbstractContextManager
+from dataclasses import dataclass
+from typing import Optional
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 
 
-class ZeroContextConfig(object):
+@dataclass
+class ZeroContextConfig:
     """The configuration used to control zero context initialization.
 
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel group.
+        is_replicated (bool, optional): Whether the param is replicated across data parallel group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
     """
 
-    def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
-        super().__init__()
+    target_device: torch.device
+    is_replicated: bool = True
+    shard_param: bool = False
 
-        if shard_param:
-            assert replicated, "Non-replicated parameters can't be sharded."
+    def __post_init__(self):
+        if self.shard_param:
+            assert self.is_replicated, "Non-replicated parameters can't be sharded."
 
-        # replicated no-shard parameters should locate in cuda, since we will broadcast them soon
-        if replicated and not shard_param:
-            assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
-
-        self.target_device = target_device
-        self.is_replicated: bool = replicated
-        self.shard_param: bool = shard_param
+        if self.is_replicated and not self.shard_param:
+            assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
 
 
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
@@ -74,7 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.seed = seed
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)
 
-        self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
+        self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
 
         ZeroContextMgr().current_context = self
 
@@ -124,7 +123,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         return fan_in, fan_out
 
     def _pre_context_exec(self):
-        """ 
+        """
         The Callback function when entering the context
         """
         self.logger = get_dist_logger("ZeroInitContext")
@@ -248,7 +247,7 @@ class ZeroContextMgr(metaclass=SingletonMeta):
 
 def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
     return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
-                                                  replicated=is_replicated,
+                                                  is_replicated=is_replicated,
                                                   shard_param=False)
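
For reviewers, a minimal usage sketch of the refactored config. It assumes ZeroContextConfig is importable from colossalai.zero.init_ctx.init_context (the file touched above); the device index and flag values are illustrative only:

    import torch

    from colossalai.zero.init_ctx.init_context import ZeroContextConfig

    # Replicated, non-sharded parameters must target a CUDA device,
    # since they will be broadcast across the data parallel group.
    cfg = ZeroContextConfig(target_device=torch.device('cuda', 0), is_replicated=True, shard_param=False)

    # The keyword is now is_replicated; __post_init__ still enforces the old invariants.
    try:
        ZeroContextConfig(target_device=torch.device('cuda', 0), is_replicated=False, shard_param=True)
    except AssertionError as err:
        print(err)  # Non-replicated parameters can't be sharded.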