mirror of https://github.com/hpcaitech/ColossalAI
[zero] Refactor ZeroContextConfig class using dataclass (#3186)
parent
9d644ff09f
commit
80aed29cd3
|
@ -1,46 +1,45 @@
|
|||
import contextlib
|
||||
import functools
|
||||
from typing import Optional
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
|
||||
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
from colossalai.zero.sharded_param import ShardedParamV2
|
||||
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||
|
||||
|
||||
class ZeroContextConfig(object):
|
||||
@dataclass
|
||||
class ZeroContextConfig:
|
||||
"""The configuration used to control zero context initialization.
|
||||
|
||||
Args:
|
||||
target_device (torch.device): The device where param data are after exiting the context.
|
||||
replicated (bool, optional): Whether the param is replicated across data parallel group.
|
||||
is_replicated (bool, optional): Whether the param is replicated across data parallel group.
|
||||
Some parameters are not replicated, e.g. parameters in MOE experts.
|
||||
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
|
||||
super().__init__()
|
||||
target_device: torch.device
|
||||
is_replicated: bool = True
|
||||
shard_param: bool = False
|
||||
|
||||
if shard_param:
|
||||
assert replicated, "Non-replicated parameters can't be sharded."
|
||||
def __post_init__(self):
|
||||
if self.shard_param:
|
||||
assert self.is_replicated, "Non-replicated parameters can't be sharded."
|
||||
|
||||
# replicated no-shard parameters should locate in cuda, since we will broadcast them soon
|
||||
if replicated and not shard_param:
|
||||
assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
|
||||
|
||||
self.target_device = target_device
|
||||
self.is_replicated: bool = replicated
|
||||
self.shard_param: bool = shard_param
|
||||
if self.is_replicated and not self.shard_param:
|
||||
assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
|
||||
|
||||
|
||||
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
|
@ -74,7 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
self.seed = seed
|
||||
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
|
||||
|
||||
self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
|
||||
self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
|
||||
|
||||
ZeroContextMgr().current_context = self
|
||||
|
||||
|
@ -248,7 +247,7 @@ class ZeroContextMgr(metaclass=SingletonMeta):
|
|||
|
||||
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
|
||||
return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
|
||||
replicated=is_replicated,
|
||||
is_replicated=is_replicated,
|
||||
shard_param=False)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue