diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 8d140a1dc..e3861c84f 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 
 from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
-from colossalai.tensor import ColoParameter, ColoTensor
+from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup, ShardSpec
 
 from .utils import InsertPostInitMethodToModuleSubClasses
 
@@ -39,18 +39,22 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
     def __init__(self,
                  device: torch.device = torch.device('cpu'),
                  dtype: torch.dtype = torch.float,
-                 default_shard_plan: Optional[Dict] = None):
+                 default_pg: Optional[ProcessGroup] = None,
+                 default_dist_spec=None):
         """
         Args:
             device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu').
             dtype (torch.dtype): the dtype of parameters initialized. Defults to torch.float.
+            default_pg (ProcessGroup): the default process group for all initialized parameters.
+            default_dist_spec: the default distributed specifications.
         """
         super().__init__()
         self._device = device
         self._dtype = dtype
 
         self._register_colo_modules()
-        self._default_shard_plan = default_shard_plan
+        self._default_pg = default_pg
+        self._default_dist_spec = default_dist_spec
 
     def _register_colo_modules(self):
         register_colo_module(torch.nn.Linear, ColoLinear())
@@ -68,10 +72,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         if hasattr(module, '_colo_visited'):
             return
 
-        if self._default_shard_plan is not None:
-            default_pg = self._default_shard_plan.get('pg', None)
-            default_shard_spec = self._default_shard_plan.get('shard_spec', None)
-
         name_list = []
         for name, param in _named_params_with_replica(module):
             if isinstance(param, ColoTensor):
@@ -96,7 +96,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             else:
                 # detaching tensor is necessary for optimizers.
                 requires_grad = param.requires_grad
-                # TODO(jiaruifang) we initialize a Default PG memory
+
+                # param is the global tensor.
                 colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
                                            requires_grad=requires_grad)
@@ -104,10 +105,12 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 # This can reduce the model size after initialization.
                 # NOTE() embedding usually can not be correctly sharded. So I use except to handle
                 # the param that can not be sharded by the default plan
-                if self._default_shard_plan is not None:
-                    colo_param.set_process_group(default_pg)
+                if self._default_pg is not None:
+                    colo_param.set_process_group(self._default_pg)
+
+                if self._default_dist_spec is not None:
                     try:
-                        colo_param.set_dist_spec(default_shard_spec)
+                        colo_param.set_dist_spec(self._default_dist_spec)
                     except:
                         pass
diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py
index 3e7f5b475..2f7aebed5 100644
--- a/tests/test_tensor/test_context.py
+++ b/tests/test_tensor/test_context.py
@@ -37,9 +37,12 @@ def run_colo_init_context(rank: int, world_size: int, port: int):
     # shard the parameters during init
     set_seed(42)
     shard_spec = ReplicaSpec()
-    # ShardSpec(dims=[0], num_partitions=[world_size])
-    default_shard_plan = {'pg': ProcessGroup(tp_degree=world_size), 'shard_spec': shard_spec}
-    with ColoInitContext(device=get_current_device(), default_shard_plan=default_shard_plan):
+
+    # If using ShardSpec here, the assertions below will fail.
+    # This is not a bug: the initialized values are simply not consistent with the original ones.
+    # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
+    default_pg = ProcessGroup(tp_degree=world_size)
+    with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec):
         model2 = model_builder()
 
     # reshard both models
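For reference, here is a minimal usage sketch of the API after this patch: ColoInitContext now takes a default ProcessGroup and an optional distributed spec directly, instead of the old default_shard_plan dict. The import paths mirror the files touched above (exact module layout may differ across ColossalAI versions), and the toy model, world_size value, and the assumption that the distributed environment has already been launched are illustrative, not part of the patch.

import torch
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils.model.colo_init_context import ColoInitContext

def build_model(world_size: int) -> torch.nn.Module:
    # One process group shared by every parameter created inside the context.
    default_pg = ProcessGroup(tp_degree=world_size)
    # Shard eligible parameters along dim 0 across the tensor-parallel ranks.
    default_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])

    # Parameters created here become ColoParameters bound to default_pg; the
    # dist spec is applied where possible and silently skipped (the try/except
    # above) for parameters that cannot be sharded, e.g. some embeddings.
    with ColoInitContext(device=torch.device('cuda'),
                         default_pg=default_pg,
                         default_dist_spec=default_dist_spec):
        model = torch.nn.Sequential(torch.nn.Linear(1024, 4096),
                                    torch.nn.GELU(),
                                    torch.nn.Linear(4096, 1024))
    return model

Splitting default_pg from default_dist_spec lets a caller pin parameters to a process group without committing to a shard layout (or vice versa), which the old default_shard_plan dict could only express by packing both into a single argument.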