mirror of https://github.com/hpcaitech/ColossalAI
[ColoTensor] reconfig ColoInitContext, decouple default_pg and default_dist_spec. (#1953)
parent 598d456d0e
commit 52c6ad26e0
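In short, ColoInitContext no longer takes a single default_shard_plan dict: the process group and the distributed spec are now passed separately as default_pg and default_dist_spec. A minimal usage sketch of the new interface, mirroring the test change at the bottom of this diff; it assumes a distributed environment has already been launched, and the toy model and import paths are illustrative rather than part of the commit:

import torch
import torch.nn as nn

from colossalai.tensor import ProcessGroup, ReplicaSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# the default process group and the default dist spec are now two separate arguments
default_pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
default_dist_spec = ReplicaSpec()

with ColoInitContext(device=get_current_device(),
                     default_pg=default_pg,
                     default_dist_spec=default_dist_spec):
    # every parameter created inside the context becomes a ColoParameter
    # carrying default_pg and default_dist_spec
    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))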
@@ -4,7 +4,7 @@ import torch
 from torch import nn

 from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
-from colossalai.tensor import ColoParameter, ColoTensor
+from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup, ShardSpec

 from .utils import InsertPostInitMethodToModuleSubClasses
@@ -39,18 +39,22 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

     def __init__(self,
                  device: torch.device = torch.device('cpu'),
                  dtype: torch.dtype = torch.float,
-                 default_shard_plan: Optional[Dict] = None):
+                 default_pg: Optional[ProcessGroup] = None,
+                 default_dist_spec=None):
         """
         Args:
             device (torch.device): the device where the initialized parameters reside. Defaults to torch.device('cpu').
             dtype (torch.dtype): the dtype of the initialized parameters. Defaults to torch.float.
+            default_pg (ProcessGroup): the default process group for all initialized parameters.
+            default_dist_spec: the default distributed specification.
         """
         super().__init__()
         self._device = device
         self._dtype = dtype

         self._register_colo_modules()
-        self._default_shard_plan = default_shard_plan
+        self._default_pg = default_pg
+        self._default_dist_spec = default_dist_spec

     def _register_colo_modules(self):
         register_colo_module(torch.nn.Linear, ColoLinear())
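The two new keyword arguments cover exactly what the old default_shard_plan dict bundled together: default_pg is the ProcessGroup every freshly created parameter is attached to, and default_dist_spec is the layout applied to it. A short sketch of the two spec choices that appear elsewhere in this diff (the world size is an assumed value, and constructing a ProcessGroup still requires a launched distributed environment):

from colossalai.tensor import ProcessGroup, ReplicaSpec, ShardSpec

world_size = 4                                   # assumed number of ranks
pg = ProcessGroup(tp_degree=world_size)          # one tensor-parallel group for all parameters

# keep every parameter fully replicated across the group ...
replica_spec = ReplicaSpec()

# ... or split dim 0 of every parameter into world_size shards
shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])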
@@ -68,10 +72,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         if hasattr(module, '_colo_visited'):
             return

-        if self._default_shard_plan is not None:
-            default_pg = self._default_shard_plan.get('pg', None)
-            default_shard_spec = self._default_shard_plan.get('shard_spec', None)
-
         name_list = []
         for name, param in _named_params_with_replica(module):
             if isinstance(param, ColoTensor):
@@ -96,7 +96,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             else:
                 # detaching tensor is necessary for optimizers.
                 requires_grad = param.requires_grad
                 # TODO(jiaruifang) we initialize a Default PG memory

                 # param is the global tensor.
                 colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
                                            requires_grad=requires_grad)
@@ -104,10 +105,12 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 # This can reduce the model size after initialization.
                 # NOTE() embedding usually can not be correctly sharded. So I use except to handle
                 # the param that can not be sharded by the default plan
-                if self._default_shard_plan is not None:
-                    colo_param.set_process_group(default_pg)
+                if self._default_pg is not None:
+                    colo_param.set_process_group(self._default_pg)
+
+                if self._default_dist_spec is not None:
                     try:
-                        colo_param.set_dist_spec(default_shard_spec)
+                        colo_param.set_dist_spec(self._default_dist_spec)
                     except:
                         pass

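The hunks above are the ColoInitContext changes; the hunk below updates the corresponding test. For a single parameter, what the post-init hook now does boils down to two calls, sketched standalone here (the tensor shape and the 2-rank group are assumptions, and a launched distributed environment is required):

import torch

from colossalai.tensor import ColoParameter, ProcessGroup, ReplicaSpec

pg = ProcessGroup(tp_degree=2)                   # assumed 2-rank group
param = ColoParameter(torch.randn(8, 8), requires_grad=True)

param.set_process_group(pg)                      # attach the default process group
try:
    param.set_dist_spec(ReplicaSpec())           # then apply the default dist spec
except Exception:
    # some layers (e.g. embeddings) may reject the default spec and stay as-is
    pass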
@@ -37,9 +37,12 @@ def run_colo_init_context(rank: int, world_size: int, port: int):
     # shard the parameters during init
     set_seed(42)
     shard_spec = ReplicaSpec()
-    # ShardSpec(dims=[0], num_partitions=[world_size])
-    default_shard_plan = {'pg': ProcessGroup(tp_degree=world_size), 'shard_spec': shard_spec}
-    with ColoInitContext(device=get_current_device(), default_shard_plan=default_shard_plan):
+
+    # If using ShardSpec, the assertions will fail.
+    # But it is not a bug; the initialized values are not consistent with the original ones.
+    # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
+    default_pg = ProcessGroup(tp_degree=world_size)
+    with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec):
         model2 = model_builder()

     # reshard both models
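As a follow-up to the test above, a quick sanity check of what the context guarantees once the with-block exits; it reuses model2 and get_current_device from the hunk, and the assertions are a sketch rather than part of the test:

from colossalai.tensor import ColoParameter

for name, p in model2.named_parameters():
    # every parameter should have been replaced by a ColoParameter on the requested device
    assert isinstance(p, ColoParameter), f'{name} was not converted'
    assert p.device.type == get_current_device().type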