mirror of https://github.com/hpcaitech/ColossalAI
124 lines
4.5 KiB
Python
124 lines
4.5 KiB
Python
from typing import Dict, Iterator, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
|
|
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup, ShardSpec
|
|
|
|
from .utils import InsertPostInitMethodToModuleSubClasses
|
|
|
|
# Find named parameters, including replicas: a parameter held by several modules is reported once per module.
|
|
|
|
|
|
def _named_params_with_replica(
    module: nn.Module,
    prefix: str = '',
    recurse: bool = True,
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
    """Yield ``(qualified_name, parameter)`` pairs without de-duplicating parameters.

    Every non-``None`` entry of each visited module's ``_parameters`` dict is
    reported, so a parameter object held by several visited modules is yielded
    once per holder.

    Args:
        module: root module to inspect.
        prefix: string prepended to every yielded name.
        recurse: when True, walk ``module.named_modules(prefix=prefix)``;
            otherwise inspect only ``module`` itself.

    Yields:
        Tuples of the dotted parameter name and the parameter object.
    """
    if recurse:
        owners = module.named_modules(prefix=prefix)
    else:
        owners = [(prefix, module)]

    for owner_name, owner in owners:
        for local_name, tensor in owner._parameters.items():
            if tensor is not None:
                # Only insert the dot separator when the owner has a non-empty name.
                qualified = f'{owner_name}.{local_name}' if owner_name else local_name
                yield qualified, tensor
|
|
|
|
|
|
def ColoModulize(module):
    """
    Flag ``module`` as already processed by setting its ``_colo_visited``
    attribute to True. Nothing else on the module is touched here.
    """
    setattr(module, '_colo_visited', True)
|
|
|
|
|
|
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
    """Replace the parameters of constructed modules with ``ColoParameter`` objects.

    The conversion itself lives in ``_post_init_method``; when and how that
    hook is invoked is decided by the base class, which is not defined in
    this file section.
    """

    def __init__(self,
                 device: torch.device = torch.device('cpu'),
                 dtype: torch.dtype = torch.float,
                 default_pg: Optional[ProcessGroup] = None,
                 default_dist_spec=None):
        """
        Args:
            device (torch.device): the device on which initialized parameters reside. Defaults to torch.device('cpu').
            dtype (torch.dtype): the dtype of initialized parameters. Defaults to torch.float.
            default_pg (ProcessGroup): the default process group for all initialized parameters.
            default_dist_spec: the default distributed specifications.
        """
        super().__init__()
        self._device = device
        self._dtype = dtype

        self._register_colo_modules()
        self._default_pg = default_pg
        self._default_dist_spec = default_dist_spec

    def _register_colo_modules(self):
        """Register the Colo counterparts of ``torch.nn.Linear`` and ``torch.nn.Embedding``."""
        register_colo_module(torch.nn.Linear, ColoLinear())
        register_colo_module(torch.nn.Embedding, ColoEmbedding())

    def _pre_context_exec(self):
        """No preparation is required before entering the context."""
        pass

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.

        Every parameter of ``module`` (including those of its submodules) that is
        not already a ``ColoTensor`` is replaced by a ``ColoParameter`` placed on
        ``self._device`` with dtype ``self._dtype``. A parameter object referenced
        from several places is converted once and the same ``ColoParameter`` is
        installed at each place.

        FIXME(fjr) The module may be passed to this function multiple times?
        """

        # Modules already marked by ColoModulize have been converted; skip them.
        if hasattr(module, '_colo_visited'):
            return

        # Collect (owning submodule path, attribute name) first, because the
        # parameters are swapped out below and must not change while iterating.
        name_list = []
        for name, param in _named_params_with_replica(module):
            if isinstance(param, ColoTensor):
                continue

            # rpartition gives ('', '', name) when there is no '.', i.e. the
            # parameter lives directly on `module` and the submodule path is ''.
            module_name, _, param_name = name.rpartition('.')
            name_list.append((module_name, param_name))

        # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
        replaced_tensors = {}
        for module_name, param_name in name_list:
            submodule = module.get_submodule(module_name)
            param = submodule.get_parameter(param_name)
            if param in replaced_tensors:
                colo_param = replaced_tensors[param]
            else:
                # detaching tensor is necessary for optimizers.
                requires_grad = param.requires_grad

                # param is the global tensor.
                colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
                                           requires_grad=requires_grad)

                # if default_shard_plan exists, shard the param during initialization.
                # This can reduce the model size after initialization.
                if self._default_pg is not None:
                    colo_param.set_process_group(self._default_pg)

                if self._default_dist_spec is not None:
                    # NOTE() embedding usually can not be correctly sharded, so a param
                    # that the default plan cannot shard is deliberately left unsharded.
                    # Only ordinary errors are suppressed; KeyboardInterrupt and
                    # SystemExit still propagate.
                    try:
                        colo_param.set_dist_spec(self._default_dist_spec)
                    except Exception:
                        pass

                replaced_tensors[param] = colo_param

            # Swap the original parameter for its Colo counterpart on the owner
            # and record the owner on the shared parameter.
            delattr(submodule, param_name)
            setattr(submodule, param_name, colo_param)
            colo_param.shared_param_modules.append(submodule)

        module.to(self._device)
        ColoModulize(module)
|