|
|
|
@ -1,46 +1,45 @@
|
|
|
|
|
import contextlib
|
|
|
|
|
import functools
|
|
|
|
|
from typing import Optional
|
|
|
|
|
from contextlib import AbstractContextManager
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
from colossalai.context.singleton_meta import SingletonMeta
|
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
|
|
|
|
from colossalai.zero.shard_utils import BaseShardStrategy
|
|
|
|
|
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
|
|
|
|
|
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
|
|
|
|
from colossalai.zero.sharded_param import ShardedParamV2
|
|
|
|
|
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ZeroContextConfig(object):
|
|
|
|
|
@dataclass
|
|
|
|
|
class ZeroContextConfig:
|
|
|
|
|
"""The configuration used to control zero context initialization.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
target_device (torch.device): The device where param data are after exiting the context.
|
|
|
|
|
replicated (bool, optional): Whether the param is replicated across data parallel group.
|
|
|
|
|
is_replicated (bool, optional): Whether the param is replicated across data parallel group.
|
|
|
|
|
Some parameters are not replicated, e.g. parameters in MOE experts.
|
|
|
|
|
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
|
|
|
|
|
super().__init__()
|
|
|
|
|
target_device: torch.device
|
|
|
|
|
is_replicated: bool = True
|
|
|
|
|
shard_param: bool = False
|
|
|
|
|
|
|
|
|
|
if shard_param:
|
|
|
|
|
assert replicated, "Non-replicated parameters can't be sharded."
|
|
|
|
|
def __post_init__(self):
|
|
|
|
|
if self.shard_param:
|
|
|
|
|
assert self.is_replicated, "Non-replicated parameters can't be sharded."
|
|
|
|
|
|
|
|
|
|
# replicated no-shard parameters should locate in cuda, since we will broadcast them soon
|
|
|
|
|
if replicated and not shard_param:
|
|
|
|
|
assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
|
|
|
|
|
|
|
|
|
|
self.target_device = target_device
|
|
|
|
|
self.is_replicated: bool = replicated
|
|
|
|
|
self.shard_param: bool = shard_param
|
|
|
|
|
if self.is_replicated and not self.shard_param:
|
|
|
|
|
assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|
|
|
@ -74,7 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|
|
|
|
self.seed = seed
|
|
|
|
|
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
|
|
|
|
|
|
|
|
|
|
self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
|
|
|
|
|
self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
|
|
|
|
|
|
|
|
|
|
ZeroContextMgr().current_context = self
|
|
|
|
|
|
|
|
|
@ -124,7 +123,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|
|
|
|
return fan_in, fan_out
|
|
|
|
|
|
|
|
|
|
def _pre_context_exec(self):
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
The Callback function when entering the context
|
|
|
|
|
"""
|
|
|
|
|
self.logger = get_dist_logger("ZeroInitContext")
|
|
|
|
@ -248,7 +247,7 @@ class ZeroContextMgr(metaclass=SingletonMeta):
|
|
|
|
|
|
|
|
|
|
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
|
|
|
|
|
return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
|
|
|
|
|
replicated=is_replicated,
|
|
|
|
|
is_replicated=is_replicated,
|
|
|
|
|
shard_param=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|