diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 95e9d4090..8d140a1dc 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -1,4 +1,4 @@
-from typing import Iterator, Tuple, Union
+from typing import Dict, Iterator, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -36,7 +36,10 @@ def ColoModulize(module):
 
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
-    def __init__(self, device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float):
+    def __init__(self,
+                 device: torch.device = torch.device('cpu'),
+                 dtype: torch.dtype = torch.float,
+                 default_shard_plan: Optional[Dict] = None):
         """
         Args:
             device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu').
@@ -47,6 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._dtype = dtype
 
         self._register_colo_modules()
+        self._default_shard_plan = default_shard_plan
 
     def _register_colo_modules(self):
         register_colo_module(torch.nn.Linear, ColoLinear())
@@ -64,6 +68,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         if hasattr(module, '_colo_visited'):
             return
 
+        if self._default_shard_plan is not None:
+            default_pg = self._default_shard_plan.get('pg', None)
+            default_shard_spec = self._default_shard_plan.get('shard_spec', None)
+
         name_list = []
         for name, param in _named_params_with_replica(module):
             if isinstance(param, ColoTensor):
@@ -91,7 +99,18 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 # TODO(jiaruifang) we initialize a Default PG memory
                 colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
                                            requires_grad=requires_grad)
-                # add mapping record
+
+                # If a default_shard_plan is given, shard the param during initialization.
+                # This reduces the model's memory footprint right after initialization.
+                # NOTE: embeddings usually cannot be sharded correctly, so a try/except is
+                # used to skip params that the default plan cannot shard.
+                if self._default_shard_plan is not None:
+                    colo_param.set_process_group(default_pg)
+                    try:
+                        colo_param.set_dist_spec(default_shard_spec)
+                    except:
+                        pass
+
                 replaced_tensors[param] = colo_param
                 delattr(submodule, param_name)
                 setattr(submodule, param_name, colo_param)
diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py
index 0dc9b8c49..3e7f5b475 100644
--- a/tests/test_tensor/test_context.py
+++ b/tests/test_tensor/test_context.py
@@ -1,5 +1,66 @@
+from functools import partial
+
 import pytest
 import torch
+import torch.multiprocessing as mp
 
+import colossalai
+from colossalai.tensor import (
+    ColoParameter,
+    ColoTensorSpec,
+    ComputePattern,
+    ComputeSpec,
+    ProcessGroup,
+    ReplicaSpec,
+    ShardSpec,
+)
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_tensor.common_utils import set_seed
+
+
+def run_colo_init_context(rank: int, world_size: int, port: int):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    # seed every process identically so the params are consistent across processes and exactly replicated
+    set_seed(42)
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    # keep parameters replicated during init
+    with ColoInitContext(device=get_current_device()):
+        model1 = model_builder()
+
+    # initialize with a default shard plan (a ReplicaSpec here; a ShardSpec could be used instead)
+    set_seed(42)
+    shard_spec = ReplicaSpec()
+    # ShardSpec(dims=[0], num_partitions=[world_size])
+    default_shard_plan = {'pg': ProcessGroup(tp_degree=world_size), 'shard_spec': shard_spec}
+    with ColoInitContext(device=get_current_device(), default_shard_plan=default_shard_plan):
+        model2 = model_builder()
+
+    # reshard both models
+    new_shard = ShardSpec(dims=[-1], num_partitions=[world_size])
+    for p1, p2 in zip(model1.parameters(), model2.parameters()):
+        p1: ColoParameter = p1
+        p1.set_process_group(ProcessGroup(tp_degree=world_size))
+        p1.set_dist_spec(new_shard)
+        p2.set_dist_spec(new_shard)
+
+    for p1, p2 in zip(model1.parameters(), model2.parameters()):
+        assert (torch.allclose(p1, p2))
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_colo_init_context(world_size):
+    run_func = partial(run_colo_init_context, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_colo_init_context(2)
diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py
index 7aedb0d5e..85008c67a 100644
--- a/tests/test_tensor/test_sharded_linear.py
+++ b/tests/test_tensor/test_sharded_linear.py
@@ -1,5 +1,4 @@
 from functools import partial
-from lib2to3 import pgen2
 
 import pytest
 import torch
diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py
index ad5a83e57..9ea274fd1 100644
--- a/tests/test_tensor/test_tp_with_zero.py
+++ b/tests/test_tensor/test_tp_with_zero.py
@@ -18,7 +18,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
+from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
 from tests.test_tensor.model.test_gpt2 import init_megatron_spec
 
 
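
Note (not part of the patch): a minimal usage sketch of the new default_shard_plan argument, assuming a distributed context has already been set up via colossalai.launch; build_model is a hypothetical user-defined model builder, not an API introduced by this patch.

    import torch.distributed as dist

    from colossalai.tensor import ProcessGroup, ShardSpec
    from colossalai.utils.cuda import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext

    world_size = dist.get_world_size()

    # shard every parameter along dim 0 across a tensor-parallel process group at init time;
    # params the spec cannot shard (e.g. embeddings) stay replicated, per the try/except in the patch
    default_shard_plan = {
        'pg': ProcessGroup(tp_degree=world_size),
        'shard_spec': ShardSpec(dims=[0], num_partitions=[world_size]),
    }

    with ColoInitContext(device=get_current_device(), default_shard_plan=default_shard_plan):
        model = build_model()    # hypothetical builder, e.g. the gpt2 component used in the test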