[zero] Refactor ZeroContextConfig class using dataclass (#3186)

pull/3191/head
YH authored 2 years ago, committed by GitHub
parent 9d644ff09f
commit 80aed29cd3

@@ -1,46 +1,45 @@
 import contextlib
 import functools
-from typing import Optional
 from contextlib import AbstractContextManager
+from dataclasses import dataclass
+from typing import Optional
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 
 
-class ZeroContextConfig(object):
+@dataclass
+class ZeroContextConfig:
     """The configuration used to control zero context initialization.
 
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel group.
+        is_replicated (bool, optional): Whether the param is replicated across data parallel group.
            Some parameters are not replicated, e.g. parameters in MOE experts.
        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
    """
-
-    def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
-        super().__init__()
+    target_device: torch.device
+    is_replicated: bool = True
+    shard_param: bool = False
 
-        if shard_param:
-            assert replicated, "Non-replicated parameters can't be sharded."
+    def __post_init__(self):
+        if self.shard_param:
+            assert self.is_replicated, "Non-replicated parameters can't be sharded."
 
         # replicated no-shard parameters should locate in cuda, since we will broadcast them soon
-        if replicated and not shard_param:
-            assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
-
-        self.target_device = target_device
-        self.is_replicated: bool = replicated
-        self.shard_param: bool = shard_param
+        if self.is_replicated and not self.shard_param:
+            assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
 
 
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
@@ -74,7 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.seed = seed
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)
 
-        self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
+        self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
 
         ZeroContextMgr().current_context = self
@@ -124,7 +123,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         return fan_in, fan_out
 
     def _pre_context_exec(self):
-        """ 
+        """
         The Callback function when entering the context
         """
         self.logger = get_dist_logger("ZeroInitContext")
@@ -248,7 +247,7 @@ class ZeroContextMgr(metaclass=SingletonMeta):
 
 
 def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
     return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
-                                                  replicated=is_replicated,
+                                                  is_replicated=is_replicated,
                                                   shard_param=False)
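
For reference, a minimal sketch (not part of the commit) of how the refactored ZeroContextConfig behaves after this change. The dataclass fields and the __post_init__ checks come straight from the diff above; the import path colossalai.zero.init_ctx.init_context and the CPU-only settings are assumptions made only for illustration. Call sites change solely in the keyword name, replicated= becoming is_replicated=, as in the ZeroInitContext and no_shard_zero_context hunks.

    from dataclasses import asdict

    import torch

    # Assumed import path; the class is defined in colossalai/zero/init_ctx/init_context.py.
    from colossalai.zero.init_ctx.init_context import ZeroContextConfig

    # Sharded, replicated parameters are valid on any target device.
    cfg = ZeroContextConfig(target_device=torch.device('cpu'), is_replicated=True, shard_param=True)
    print(asdict(cfg))  # the dataclass provides repr/asdict for free

    # __post_init__ enforces the same invariants the old __init__ asserted.
    try:
        ZeroContextConfig(target_device=torch.device('cpu'), is_replicated=False, shard_param=True)
    except AssertionError as err:
        print(err)  # Non-replicated parameters can't be sharded.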
