[ColoTensor] reconfig ColoInitContext, decouple default_pg and default_dist_spec. (#1953)

pull/1958/head
Jiarui Fang 2022-11-15 16:24:16 +08:00 committed by GitHub
parent 598d456d0e
commit 52c6ad26e0
2 changed files with 20 additions and 14 deletions


@@ -4,7 +4,7 @@ import torch
from torch import nn
from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
from colossalai.tensor import ColoParameter, ColoTensor
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup, ShardSpec
from .utils import InsertPostInitMethodToModuleSubClasses
@@ -39,18 +39,22 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
def __init__(self,
device: torch.device = torch.device('cpu'),
dtype: torch.dtype = torch.float,
default_shard_plan: Optional[Dict] = None):
default_pg: Optional[ProcessGroup] = None,
default_dist_spec=None):
"""
Args:
device (torch.device): the device on which initialized parameters reside. Defaults to torch.device('cpu').
dtype (torch.dtype): the dtype of initialized parameters. Defaults to torch.float.
default_pg (ProcessGroup): the default process group for all initialized parameters.
default_dist_spec: the default distributed specifications.
"""
super().__init__()
self._device = device
self._dtype = dtype
self._register_colo_modules()
self._default_shard_plan = default_shard_plan
self._default_pg = default_pg
self._default_dist_spec = default_dist_spec
def _register_colo_modules(self):
register_colo_module(torch.nn.Linear, ColoLinear())
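The hunk above replaces the single default_shard_plan dict with two independent arguments, default_pg and default_dist_spec. A minimal caller-side sketch of the new constructor, assuming the process already runs inside a launched distributed job; the import path, world_size value, and build_model factory are placeholders:

import torch
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils.model.colo_init_context import ColoInitContext  # import path is an assumption

world_size = 2                                            # placeholder world size
default_pg = ProcessGroup(tp_degree=world_size)
default_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])

# Before this commit both settings were bundled as
# default_shard_plan={'pg': ..., 'shard_spec': ...}; now either one can be
# passed without the other.
with ColoInitContext(device=torch.device('cuda'),
                     default_pg=default_pg,
                     default_dist_spec=default_dist_spec):
    model = build_model()                                 # hypothetical model factory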
@@ -68,10 +72,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
if hasattr(module, '_colo_visited'):
return
if self._default_shard_plan is not None:
default_pg = self._default_shard_plan.get('pg', None)
default_shard_spec = self._default_shard_plan.get('shard_spec', None)
name_list = []
for name, param in _named_params_with_replica(module):
if isinstance(param, ColoTensor):
@@ -96,7 +96,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
else:
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
# TODO(jiaruifang) we initialize a Default PG memory
# param is the global tensor.
colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
requires_grad=requires_grad)
@@ -104,10 +105,12 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
# This can reduce the model size after initialization.
# NOTE() embeddings usually cannot be correctly sharded, so the except clause
# handles params that cannot be sharded by the default plan
if self._default_shard_plan is not None:
colo_param.set_process_group(default_pg)
if self._default_pg is not None:
colo_param.set_process_group(self._default_pg)
if self._default_dist_spec is not None:
try:
colo_param.set_dist_spec(default_shard_spec)
colo_param.set_dist_spec(self._default_dist_spec)
except:
pass
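With the dict removed, the post-init hook applies the two defaults independently: the process group is always attached when one is given, while the dist spec is applied best-effort, because some layers (embeddings in particular, per the NOTE above) cannot be sharded by a generic plan. A standalone sketch of that pattern; apply_defaults is a hypothetical helper mirroring the hunk above:

from typing import Optional
from colossalai.tensor import ColoParameter, ProcessGroup

def apply_defaults(colo_param: ColoParameter,
                   default_pg: Optional[ProcessGroup] = None,
                   default_dist_spec=None) -> ColoParameter:
    # Mirrors the decoupled logic of _post_init_method in this commit.
    if default_pg is not None:
        colo_param.set_process_group(default_pg)
    if default_dist_spec is not None:
        try:
            # Sharding may legitimately fail (e.g. an embedding the default plan
            # cannot partition); the parameter then keeps its current layout.
            colo_param.set_dist_spec(default_dist_spec)
        except Exception:
            pass
    return colo_param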


@@ -37,9 +37,12 @@ def run_colo_init_context(rank: int, world_size: int, port: int):
# shard the parameters during init
set_seed(42)
shard_spec = ReplicaSpec()
# ShardSpec(dims=[0], num_partitions=[world_size])
default_shard_plan = {'pg': ProcessGroup(tp_degree=world_size), 'shard_spec': shard_spec}
with ColoInitContext(device=get_current_device(), default_shard_plan=default_shard_plan):
# If using ShardSpec, the assertions will fail.
# That is not a bug; the initialized values are just not consistent with the original ones.
# shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
default_pg = ProcessGroup(tp_degree=world_size)
with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec):
model2 = model_builder()
# reshard both models
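A quick sanity check after the context exits could confirm the defaults reached the parameters. This is a rough sketch, assuming ColoTensor exposes get_process_group() in this version and that set_process_group stores the ProcessGroup object as passed:

from colossalai.tensor import ColoParameter

for name, param in model2.named_parameters():
    # Every parameter created under the context should have been converted.
    assert isinstance(param, ColoParameter), f'{name} was not converted to a ColoParameter'
    # Each parameter should carry the default process group passed to the context.
    assert param.get_process_group() is default_pg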