mirror of https://github.com/hpcaitech/ColossalAI
[zero] Refactor ZeroContextConfig class using dataclass (#3186)
parent
9d644ff09f
commit
80aed29cd3
|
@ -1,46 +1,45 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
from typing import Optional
|
|
||||||
from contextlib import AbstractContextManager
|
from contextlib import AbstractContextManager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from colossalai.context.singleton_meta import SingletonMeta
|
from colossalai.context.singleton_meta import SingletonMeta
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
|
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
|
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
|
||||||
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||||
from colossalai.zero.sharded_param import ShardedParamV2
|
from colossalai.zero.sharded_param import ShardedParamV2
|
||||||
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
|
||||||
|
|
||||||
|
|
||||||
class ZeroContextConfig(object):
|
@dataclass
|
||||||
|
class ZeroContextConfig:
|
||||||
"""The configuration used to control zero context initialization.
|
"""The configuration used to control zero context initialization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_device (torch.device): The device where param data are after exiting the context.
|
target_device (torch.device): The device where param data are after exiting the context.
|
||||||
replicated (bool, optional): Whether the param is replicated across data parallel group.
|
is_replicated (bool, optional): Whether the param is replicated across data parallel group.
|
||||||
Some parameters are not replicated, e.g. parameters in MOE experts.
|
Some parameters are not replicated, e.g. parameters in MOE experts.
|
||||||
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
|
target_device: torch.device
|
||||||
super().__init__()
|
is_replicated: bool = True
|
||||||
|
shard_param: bool = False
|
||||||
|
|
||||||
if shard_param:
|
def __post_init__(self):
|
||||||
assert replicated, "Non-replicated parameters can't be sharded."
|
if self.shard_param:
|
||||||
|
assert self.is_replicated, "Non-replicated parameters can't be sharded."
|
||||||
|
|
||||||
# replicated no-shard parameters should locate in cuda, since we will broadcast them soon
|
if self.is_replicated and not self.shard_param:
|
||||||
if replicated and not shard_param:
|
assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
|
||||||
assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
|
|
||||||
|
|
||||||
self.target_device = target_device
|
|
||||||
self.is_replicated: bool = replicated
|
|
||||||
self.shard_param: bool = shard_param
|
|
||||||
|
|
||||||
|
|
||||||
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
|
@ -74,7 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
|
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
|
||||||
|
|
||||||
self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
|
self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
|
||||||
|
|
||||||
ZeroContextMgr().current_context = self
|
ZeroContextMgr().current_context = self
|
||||||
|
|
||||||
|
@ -248,7 +247,7 @@ class ZeroContextMgr(metaclass=SingletonMeta):
|
||||||
|
|
||||||
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
|
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
|
||||||
return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
|
return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
|
||||||
replicated=is_replicated,
|
is_replicated=is_replicated,
|
||||||
shard_param=False)
|
shard_param=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue