mirror of https://github.com/hpcaitech/ColossalAI
272 lines
11 KiB
Python
272 lines
11 KiB
Python
import contextlib
|
|
import functools
|
|
from contextlib import AbstractContextManager
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
|
|
from colossalai.context.singleton_meta import SingletonMeta
|
|
from colossalai.legacy.context.parallel_mode import ParallelMode
|
|
from colossalai.legacy.core import global_context as gpc
|
|
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
|
|
from colossalai.legacy.zero.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
|
|
from colossalai.legacy.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
|
from colossalai.legacy.zero.sharded_param import ShardedParamV2
|
|
from colossalai.logging import get_dist_logger
|
|
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
|
|
|
|
|
@dataclass
class ZeroContextConfig:
    """Settings that steer how parameters are treated while a zero context is active.

    Args:
        target_device (torch.device): Device holding the param data once the context has exited.
        is_replicated (bool, optional): True when the param is replicated across the data parallel
            group. Some parameters are not replicated, e.g. parameters in MOE experts.
            Defaults to True.
        shard_param (bool, optional): True when the param is sharded once the context has exited.
            Defaults to False.

    Raises:
        AssertionError: If ``shard_param`` is set for a non-replicated parameter, or if a
            replicated, unsharded parameter targets a device whose type is not ``"cuda"``.
    """

    target_device: torch.device
    is_replicated: bool = True
    shard_param: bool = False

    def __post_init__(self):
        # Sharding is only allowed for parameters that are replicated.
        assert self.is_replicated or not self.shard_param, "Non-replicated parameters can't be sharded."

        # Replicated parameters that stay unsharded must target a CUDA device.
        replicated_unsharded = self.is_replicated and not self.shard_param
        if replicated_unsharded:
            assert self.target_device.type == "cuda", "Replicated no-shard parameters should be located in cuda."
|
|
|
|
|
|
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
    """A context to initialize model.

    1. Convert the model to half precision (fp16, or bf16 when ``bf16`` is True).
    2. The parameters of the module are adapted to type ShardedParameter.
    3. Shard the param and grad according to flags.

    Args:
        target_device (torch.device): The device where param data are after exiting the context.
        shard_strategy (BaseShardStrategy): Shard strategy instance.
        seed (int, optional): Random seed for weight initialization
        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
        default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as
            ``default_dtype`` then converted to half precision.
        bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``.
            Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements
            of model. Defaults to None, in which case a fresh ``torch.zeros(1, dtype=torch.long)``
            is created for this context.
    """

    def __init__(
        self,
        target_device: torch.device,
        shard_strategy: BaseShardStrategy,
        seed: int = 2**10 - 1,
        shard_param: bool = False,
        default_dtype: Optional[torch.dtype] = None,
        bf16: bool = False,
        model_numel_tensor: Optional[torch.Tensor] = None,
    ):
        super().__init__(default_dtype=default_dtype)
        self.shard_strategy = shard_strategy
        self.param_list = []
        # Build the default counter per instance: a tensor used as a default argument
        # value is created once and would be shared (and overwritten) by every context.
        if model_numel_tensor is None:
            model_numel_tensor = torch.zeros(1, dtype=torch.long)
        self.model_numel_tensor = model_numel_tensor
        self.seed = seed
        self.bf16 = bf16
        self.dp_process_group = gpc.get_group(ParallelMode.DATA)

        self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)

        # Register this context so ZeroContextMgr can temporarily override its config.
        ZeroContextMgr().current_context = self

        # Maps each adapted parameter to its element count (recorded before any sharding).
        self.param_numel = {}
        # Last module handed to _post_init_method; stays None if no module is built.
        self.top_module = None

    @property
    def target_device(self):
        """Target device taken from the currently active config."""
        return self.config.target_device

    @property
    def is_replicated(self):
        """Replication flag taken from the currently active config."""
        return self.config.is_replicated

    @property
    def shard_param(self):
        """Sharding flag taken from the currently active config."""
        return self.config.shard_param

    @staticmethod
    def calc_fanin_fanout(tensor: torch.Tensor):
        """We use this function to substitute fan-in and fan-out calculation in torch.nn.init.
        This can help us get correct fan-in and fan-out for sharded tensor.

        Args:
            tensor (torch.Tensor): The parameter being initialized. Must be an ``nn.Parameter``.

        Returns:
            tuple: ``(fan_in, fan_out)`` computed from the unsharded shape.

        Raises:
            AssertionError: If ``tensor`` is not an ``nn.Parameter``.
            ValueError: If the (original) shape has fewer than 2 dimensions.
        """
        assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters"

        # get correct shape of input tensor: a sharded param reports its pre-shard shape
        if not hasattr(tensor, "colo_attr") or not tensor.colo_attr.param_is_sharded:
            tensor_shape = tensor.shape
        else:
            tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape

        dimensions = len(tensor_shape)
        if dimensions < 2:
            raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

        num_input_fmaps = tensor_shape[1]
        num_output_fmaps = tensor_shape[0]
        receptive_field_size = 1
        if dimensions > 2:
            # math.prod is not always available, accumulate the product manually
            # we could use functools.reduce but that is not supported by TorchScript
            for s in tensor_shape[2:]:
                receptive_field_size *= s
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

        return fan_in, fan_out

    def _pre_context_exec(self):
        """
        The Callback function when entering the context.

        Patches ``nn.init._calculate_fan_in_and_fan_out``, ``nn.Module._load_from_state_dict``
        and ``nn.Module.state_dict`` (keeping the originals on ``self``), saves the CPU and
        CUDA RNG states, and reseeds torch with a rank-dependent seed.
        """
        self.logger = get_dist_logger("ZeroInitContext")

        # substitute fan-in and fan-out calculation
        self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
        nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout

        # substitute state-dict loading/saving with the sharded-model implementations
        self.module_load_from_state_dict = nn.Module._load_from_state_dict
        shard_strategy = self.shard_strategy if self.config.shard_param else None
        nn.Module._load_from_state_dict = functools.partialmethod(
            ShardedModelV2._colo_load_from_state_dict, shard_strategy=shard_strategy
        )
        self.module_state_dict = nn.Module.state_dict
        nn.Module.state_dict = functools.partialmethod(
            ShardedModelV2._colo_state_dict,
            shard_strategy=shard_strategy,
            state_dict_func=self.module_state_dict,
            process_group=self.dp_process_group,
        )

        # preserve rng states so they can be restored when the context exits
        self.cpu_rng_state = torch.get_rng_state()
        self.cuda_rng_state = torch.cuda.get_rng_state()

        # set new seed for initialization, since we initialize sharded tensor separately
        # we don't want all processes have the same seed
        # otherwise all sharded tensors are same after init
        offset = self.seed + 1  # we want to have more 1 in binary format seed
        torch.manual_seed(self.seed + offset * dist.get_rank())

    def _post_context_exec(self):
        """The callback function when exiting context.

        Broadcasts replicated, unsharded parameters from the first rank of the data parallel
        group, undoes every patch installed by ``_pre_context_exec``, restores the RNG
        states, and writes the model's element count into ``model_numel_tensor``.
        """
        # broadcast replicated no-shard parameters
        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
        for param in self.param_list:
            assert hasattr(param, "colo_attr")
            if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
            param.colo_attr.set_data_none()

        del self.param_list

        nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
        # Restore the attribute that was actually patched on entry. Assigning the saved
        # function to ``load_state_dict`` would clobber the public method and leave
        # ``_load_from_state_dict`` patched.
        nn.Module._load_from_state_dict = self.module_load_from_state_dict
        nn.Module.state_dict = self.module_state_dict
        torch.set_rng_state(self.cpu_rng_state)
        torch.cuda.set_rng_state(self.cuda_rng_state)

        # Only parameters still reachable from the top module count towards the model size;
        # anything else that was recorded is zeroed out. If no module was constructed inside
        # the context there is no top module and nothing to count.
        if self.top_module is None:
            params = frozenset()
        else:
            params = frozenset(self.top_module.parameters())
        for param in self.param_numel:
            if param not in params:
                self.param_numel[param] = 0
        self.model_numel_tensor.fill_(sum(self.param_numel.values()))

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.
        NOTE() The module may be passed to this function multiple times.

        Args:
            module (torch.nn.Module): The module whose direct (non-recursive) parameters and
                buffers are converted.
        """
        # Remember the latest module seen; presumably the outermost module ends up here
        # because its constructor finishes last -- confirm against the base class.
        self.top_module = module
        half_dtype = torch.float16 if not self.bf16 else torch.bfloat16

        def half_fn(t: torch.Tensor):
            # Only floating point tensors are converted; other dtypes are returned untouched.
            return t.to(half_dtype) if t.is_floating_point() else t

        for param in module.parameters(recurse=False):
            # avoid adapting a param to ShardedParam twice
            if hasattr(param, "colo_attr"):
                continue

            self.param_numel[param] = param.numel()

            # convert parameters to half
            param_half = half_fn(param)
            param.data = param_half
            if param.grad is not None:
                grad_half = half_fn(param.grad)
                param.grad.data = grad_half

            # move torch parameters to the target device
            target_device = self.target_device
            param.data = param.data.to(target_device)
            if param.grad is not None:
                param.grad = param.grad.to(target_device)

            param.colo_attr = ShardedParamV2(param, set_data_none=True)

            if self.shard_param:
                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)

            param.data = param.colo_attr.data_payload  # set param.data to payload

            # mark whether the param is replicated
            param.colo_attr.is_replicated = self.is_replicated

            # mark whether the param should keep not sharded
            # if True, the param is used as Zero stage 2
            param.colo_attr.keep_not_shard = not self.shard_param

            self.param_list.append(param)

        # We must cast buffers
        # If we use BN, buffers may be on CPU and Float
        # We must cast them
        cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
        for buffer in module.buffers(recurse=False):
            buffer.data = buffer.data.to(device=torch.cuda.current_device())
            buffer.data = cast_fn(buffer.data)
|
|
|
|
|
|
class ZeroContextMgr(metaclass=SingletonMeta):
    """Holds a reference to the current ``ZeroInitContext`` so its config can be swapped temporarily.

    Built with ``SingletonMeta``; presumably every ``ZeroContextMgr()`` call yields the same
    instance -- confirm against ``SingletonMeta``.
    """

    # The context whose config gets overridden; None until a context has been registered.
    current_context: Optional[ZeroInitContext] = None

    @contextlib.contextmanager
    def hijack_context_config(self, **kwargs):
        """Temporarily replace the current context's config with ``ZeroContextConfig(**kwargs)``.

        Does nothing when no context is registered. The previous config is always put back
        when the ``with`` block ends, including when the block raises.

        Args:
            **kwargs: Keyword arguments forwarded to ``ZeroContextConfig``.
        """
        # Pin the context seen on entry so the config is restored on the same object even
        # if ``current_context`` is reassigned while the block runs.
        context = self.current_context
        if context is None:
            yield
            return

        old_config = context.config
        context.config = ZeroContextConfig(**kwargs)
        try:
            yield
        finally:
            # Without this, an exception in the with-body would leave the override in place.
            context.config = old_config
|
|
|
|
|
|
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
    """Return a context manager that overrides the current zero context config with sharding off.

    The override targets the current CUDA device and sets ``shard_param=False``.

    Args:
        is_replicated (bool, optional): Replication flag placed in the override config.
            Defaults to True.

    Returns:
        AbstractContextManager: The manager produced by ``ZeroContextMgr.hijack_context_config``.
    """
    manager = ZeroContextMgr()
    current_cuda_device = torch.device("cuda", torch.cuda.current_device())
    return manager.hijack_context_config(
        target_device=current_cuda_device,
        is_replicated=is_replicated,
        shard_param=False,
    )
|
|
|
|
|
|
def no_shard_zero_decrator(is_replicated: bool = True):
    """Build a decorator that runs the wrapped callable inside ``no_shard_zero_context``.

    Args:
        is_replicated (bool, optional): Forwarded to ``no_shard_zero_context``. Defaults to True.

    Returns:
        A decorator. The function it returns calls the wrapped callable with the same
        arguments, returns its result, and keeps its metadata (``__name__``, ``__doc__``,
        ``__wrapped__``).
    """

    def _wrapper(init_func):
        # functools.wraps keeps the wrapped callable's name/docstring visible to
        # introspection and debugging tools.
        @functools.wraps(init_func)
        def _no_shard(*args, **kwargs):
            with no_shard_zero_context(is_replicated):
                return init_func(*args, **kwargs)

        return _no_shard

    return _wrapper
|