You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/legacy/zero/init_ctx/init_context.py

272 lines
11 KiB

import contextlib
import functools
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.legacy.zero.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
from colossalai.legacy.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.legacy.zero.sharded_param import ShardedParamV2
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
@dataclass
class ZeroContextConfig:
"""The configuration used to control zero context initialization.
Args:
target_device (torch.device): The device where param data are after exiting the context.
is_replicated (bool, optional): Whether the param is replicated across data parallel group.
Some parameters are not replicated, e.g. parameters in MOE experts.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
"""
target_device: torch.device
is_replicated: bool = True
shard_param: bool = False
def __post_init__(self):
if self.shard_param:
assert self.is_replicated, "Non-replicated parameters can't be sharded."
if self.is_replicated and not self.shard_param:
assert self.target_device.type == "cuda", "Replicated no-shard parameters should be located in cuda."
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""A context to initialize model.
1. Convert the model to fp16.
2. The parameters of the module are adapted to type ShardedParameter.
3. Shard the param and grad according to flags.
Args:
target_device (torch.device): The device where param data are after exiting the context.
shard_strategy (BaseShardStrategy): Shard strategy instance.
seed (int, optional): Random seed for weight initialization
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
"""
def __init__(
self,
target_device: torch.device,
shard_strategy: BaseShardStrategy,
seed: int = 2**10 - 1,
shard_param: bool = False,
default_dtype: Optional[torch.dtype] = None,
bf16: bool = False,
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long),
):
super().__init__(default_dtype=default_dtype)
self.shard_strategy = shard_strategy
self.param_list = []
self.model_numel_tensor = model_numel_tensor
self.seed = seed
self.bf16 = bf16
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
ZeroContextMgr().current_context = self
self.param_numel = {}
self.top_module = None
@property
def target_device(self):
return self.config.target_device
@property
def is_replicated(self):
return self.config.is_replicated
@property
def shard_param(self):
return self.config.shard_param
@staticmethod
def calc_fanin_fanout(tensor: torch.Tensor):
"""We use this function to substitute fan-in and fan-out calculation in torch.nn.init.
This can help us get correct fan-in and fan-out for sharded tensor.
"""
assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters"
# get correct shape of input tensor
if not hasattr(tensor, "colo_attr") or not tensor.colo_attr.param_is_sharded:
tensor_shape = tensor.shape
else:
tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape
dimensions = len(tensor_shape)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
num_input_fmaps = tensor_shape[1]
num_output_fmaps = tensor_shape[0]
receptive_field_size = 1
if dimensions > 2:
# math.prod is not always available, accumulate the product manually
# we could use functools.reduce but that is not supported by TorchScript
for s in tensor_shape[2:]:
receptive_field_size *= s
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def _pre_context_exec(self):
"""
The Callback function when entering the context
"""
self.logger = get_dist_logger("ZeroInitContext")
# substitute fan-in and fan-out calculation
self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout
self.module_load_from_state_dict = nn.Module._load_from_state_dict
shard_strategy = self.shard_strategy if self.config.shard_param else None
nn.Module._load_from_state_dict = functools.partialmethod(
ShardedModelV2._colo_load_from_state_dict, shard_strategy=shard_strategy
)
self.module_state_dict = nn.Module.state_dict
nn.Module.state_dict = functools.partialmethod(
ShardedModelV2._colo_state_dict,
shard_strategy=shard_strategy,
state_dict_func=self.module_state_dict,
process_group=self.dp_process_group,
)
# reserve rng states
self.cpu_rng_state = torch.get_rng_state()
self.cuda_rng_state = torch.cuda.get_rng_state()
# set new seed for initialization, since we initialize sharded tensor separately
# we don't want all processes have the same seed
# otherwise all sharded tensors are same after init
offset = self.seed + 1 # we want to have more 1 in binary format seed
torch.manual_seed(self.seed + offset * dist.get_rank())
def _post_context_exec(self):
"""The callback function when exiting context."""
# broadcast replicated no-shard parameters
src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
for param in self.param_list:
assert hasattr(param, "colo_attr")
if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
param.colo_attr.set_data_none()
del self.param_list
nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
nn.Module.load_state_dict = self.module_load_from_state_dict
nn.Module.state_dict = self.module_state_dict
torch.set_rng_state(self.cpu_rng_state)
torch.cuda.set_rng_state(self.cuda_rng_state)
params = frozenset(self.top_module.parameters())
for param in self.param_numel.keys():
if param not in params:
self.param_numel[param] = 0
self.model_numel_tensor.fill_(sum(self.param_numel.values()))
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
"""
self.top_module = module
half_dtype = torch.float16 if not self.bf16 else torch.bfloat16
def half_fn(t: torch.Tensor):
return t.to(half_dtype) if t.is_floating_point() else t
for param in module.parameters(recurse=False):
# avoid adapting a param to ShardedParam twice
if hasattr(param, "colo_attr"):
continue
self.param_numel[param] = param.numel()
# convert parameters to half
param_half = half_fn(param)
param.data = param_half
if param.grad is not None:
grad_half = half_fn(param.grad)
param.grad.data = grad_half
# move torch parameters to the target device
target_device = self.target_device
param.data = param.data.to(target_device)
if param.grad is not None:
param.grad = param.grad.to(target_device)
param.colo_attr = ShardedParamV2(param, set_data_none=True)
if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.data_payload # set param.data to payload
# mark whether the param is replicated
param.colo_attr.is_replicated = self.is_replicated
# mark whether the param should keep not sharded
# if True, the param is used as Zero stage 2
param.colo_attr.keep_not_shard = not self.shard_param
self.param_list.append(param)
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
# We must cast them
cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
for buffer in module.buffers(recurse=False):
buffer.data = buffer.data.to(device=torch.cuda.current_device())
buffer.data = cast_fn(buffer.data)
class ZeroContextMgr(metaclass=SingletonMeta):
current_context: Optional[ZeroInitContext] = None
@contextlib.contextmanager
def hijack_context_config(self, **kwargs):
if self.current_context is None:
yield
else:
old_config = self.current_context.config
self.current_context.config = ZeroContextConfig(**kwargs)
yield
self.current_context.config = old_config
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
return ZeroContextMgr().hijack_context_config(
target_device=torch.device("cuda", torch.cuda.current_device()), is_replicated=is_replicated, shard_param=False
)
def no_shard_zero_decrator(is_replicated: bool = True):
def _wrapper(init_func):
def _no_shard(*args, **kwargs):
with no_shard_zero_context(is_replicated):
ret = init_func(*args, **kwargs)
return ret
return _no_shard
return _wrapper