[zero] Refactor ZeroContextConfig class using dataclass (#3186)

Branch: pull/3191/head
Author: YH (committed by GitHub, 2 years ago)
Parent: 9d644ff09f
Commit: 80aed29cd3

@@ -1,46 +1,45 @@
 import contextlib
 import functools
-from typing import Optional
 from contextlib import AbstractContextManager
+from dataclasses import dataclass
+from typing import Optional
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 
 
-class ZeroContextConfig(object):
+@dataclass
+class ZeroContextConfig:
     """The configuration used to control zero context initialization.
 
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel group.
+        is_replicated (bool, optional): Whether the param is replicated across data parallel group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
     """
 
-    def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
-        super().__init__()
-
-        if shard_param:
-            assert replicated, "Non-replicated parameters can't be sharded."
-
-        # replicated no-shard parameters should locate in cuda, since we will broadcast them soon
-        if replicated and not shard_param:
-            assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
-
-        self.target_device = target_device
-        self.is_replicated: bool = replicated
-        self.shard_param: bool = shard_param
+    target_device: torch.device
+    is_replicated: bool = True
+    shard_param: bool = False
+
+    def __post_init__(self):
+        if self.shard_param:
+            assert self.is_replicated, "Non-replicated parameters can't be sharded."
+
+        if self.is_replicated and not self.shard_param:
+            assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
 
 
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
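A minimal standalone sketch of the pattern this hunk adopts: a dataclass whose invariants are checked in __post_init__ instead of a hand-written __init__. The ExampleConfig name, the plain string device field, and the demo construction at the bottom are illustrative assumptions only, not part of the ColossalAI API.

from dataclasses import dataclass


@dataclass
class ExampleConfig:
    # Mirrors the shape of ZeroContextConfig: one required field, two defaulted flags.
    target_device: str
    is_replicated: bool = True
    shard_param: bool = False

    def __post_init__(self):
        # Runs automatically after the generated __init__ has assigned the fields,
        # so invalid flag combinations are rejected at construction time.
        if self.shard_param:
            assert self.is_replicated, "Non-replicated parameters can't be sharded."
        if self.is_replicated and not self.shard_param:
            assert self.target_device.startswith("cuda"), \
                "Replicated no-shard parameters should be located in cuda."


# Fields become keyword arguments of the generated __init__;
# ExampleConfig(target_device="cpu") would instead trip the second assertion.
cfg = ExampleConfig(target_device="cuda:0", shard_param=False)
print(cfg)  # the generated __repr__ prints all three fields

Beyond removing the boilerplate __init__, the dataclass also provides __repr__ and __eq__ for free, which is the usual motivation for this kind of refactor.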
@@ -74,7 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.seed = seed
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)
 
-        self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
+        self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
 
         ZeroContextMgr().current_context = self
 
@@ -124,7 +123,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         return fan_in, fan_out
 
     def _pre_context_exec(self):
         """
         The Callback function when entering the context
         """
         self.logger = get_dist_logger("ZeroInitContext")
@@ -248,7 +247,7 @@ class ZeroContextMgr(metaclass=SingletonMeta):
 
 
 def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
     return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
-                                                  replicated=is_replicated,
+                                                  is_replicated=is_replicated,
                                                   shard_param=False)
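The two call-site hunks above follow from the field rename: with a dataclass, the field names are the only keyword arguments the generated __init__ accepts, so every caller that previously passed replicated= must now pass is_replicated=. A small self-contained illustration, using a hypothetical RenamedConfig stand-in rather than the real ZeroContextConfig:

from dataclasses import dataclass


@dataclass
class RenamedConfig:
    # Stand-in for ZeroContextConfig after the rename.
    is_replicated: bool = True


RenamedConfig(is_replicated=False)      # accepted: keyword matches the field name
try:
    RenamedConfig(replicated=False)     # rejected: the old keyword no longer exists
except TypeError as err:
    print(err)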
