From 32291dd73f92843b4f221226a95f2f7048b1766b Mon Sep 17 00:00:00 2001 From: Ziyue Jiang Date: Thu, 26 May 2022 11:50:44 +0800 Subject: [PATCH] [Tensor] add module handler for linear (#1021) * add module spec for linear * polish * polish * polish --- colossalai/tensor/__init__.py | 6 +- colossalai/tensor/module_utils.py | 92 +++++++++++++ colossalai/tensor/modules/__init__.py | 2 + colossalai/tensor/modules/colo_module.py | 51 ++++++++ colossalai/tensor/modules/linear.py | 39 ++++++ colossalai/utils/model/colo_init_context.py | 16 ++- tests/test_tensor/test_module_spec.py | 137 ++++++++++++++++++++ 7 files changed, 341 insertions(+), 2 deletions(-) create mode 100644 colossalai/tensor/module_utils.py create mode 100644 colossalai/tensor/modules/__init__.py create mode 100644 colossalai/tensor/modules/colo_module.py create mode 100644 colossalai/tensor/modules/linear.py create mode 100644 tests/test_tensor/test_module_spec.py diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index 8b8d18ce7..950f1f3a0 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -8,8 +8,12 @@ from ._ops import * from .optim.colo_optimizer import ColoOptimizer from . import distspec from .dist_spec_mgr import DistSpecManager +from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module +from .modules import ColoLinear __all__ = [ 'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction', - 'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager' + 'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager', + 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', + 'ColoLinear' ] diff --git a/colossalai/tensor/module_utils.py b/colossalai/tensor/module_utils.py new file mode 100644 index 000000000..6f449aa95 --- /dev/null +++ b/colossalai/tensor/module_utils.py @@ -0,0 +1,92 @@ +from typing import Dict +from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec +from .modules import ColoModule +import torch + +_COLOSSAL_MODULES: Dict[type, ColoModule] = {} + + +def register_colo_module(module_type: type, colo_module: ColoModule): + global _COLOSSAL_MODULES + _COLOSSAL_MODULES[module_type] = colo_module + +def is_colo_module(module: torch.nn.Module): + global _COLOSSAL_MODULES + return type(module) in _COLOSSAL_MODULES + +def get_colo_module(module: torch.nn.Module): + global _COLOSSAL_MODULES + if is_colo_module(module): + colo_module = _COLOSSAL_MODULES[type(module)] + colo_module.register() + return colo_module + else: + return None + +def check_colo_module(module: torch.nn.Module, recursive=True): + if is_colo_module(module): + colo_module = get_colo_module(module) + param_names = colo_module.get_param_names() + compute_pattern = None + for param_name in param_names: + param = module.get_parameter(param_name) + if not isinstance(param, ColoParameter): + raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.') + if param.has_spec(): + cur_compute_pattern = param.spec.parallel_action.compute_pattern + if compute_pattern is None: + compute_pattern = cur_compute_pattern + else: + if cur_compute_pattern != compute_pattern: + raise Exception(f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.') + else: + continue + + if compute_pattern is not None: + if not colo_module.has_compute_pattern(compute_pattern): + raise Exception(f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.') + + match_specs = False + allowed_specs = colo_module.get_dist_specs(compute_pattern) + for _, param_specs in allowed_specs.items(): + cur_match = True + for param_name, dist_spec in param_specs.items(): + param = module.get_parameter(param_name) + if param.has_spec(): + if dist_spec != param.spec.dist_spec: + cur_match = False + break + else: + if dist_spec is not None: + cur_match = False + break + if cur_match == True: + match_specs = True + break + if match_specs == False: + raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.') + + if recursive == True: + for submodule in module.children(): + check_colo_module(submodule, recursive=True) + +def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, label='default'): + compute_pattern = parallel_action.compute_pattern + if is_colo_module(module): + # for each param + # set DistSpec and ParallelAction + colo_module = get_colo_module(module) + if not colo_module.has_compute_pattern_with_label(compute_pattern, label=label): + raise NotImplementedError + for param_name, dist_spec in colo_module.get_dist_specs_with_label(compute_pattern, label=label).items(): + if dist_spec is None: + continue + param = module.get_parameter(param_name) + if isinstance(param, ColoParameter): + spec = TensorSpec(dist_spec, parallel_action) + param.set_spec(spec) + check_colo_module(module, recursive=False) + if recursive == True: + for submodule in module.children(): + init_colo_module(submodule, parallel_action, recursive=True, label=label) + \ No newline at end of file diff --git a/colossalai/tensor/modules/__init__.py b/colossalai/tensor/modules/__init__.py new file mode 100644 index 000000000..15f10534e --- /dev/null +++ b/colossalai/tensor/modules/__init__.py @@ -0,0 +1,2 @@ +from .colo_module import ColoModule +from .linear import ColoLinear \ No newline at end of file diff --git a/colossalai/tensor/modules/colo_module.py b/colossalai/tensor/modules/colo_module.py new file mode 100644 index 000000000..c0aa37e48 --- /dev/null +++ b/colossalai/tensor/modules/colo_module.py @@ -0,0 +1,51 @@ +from colossalai.tensor.distspec import _DistSpec +from colossalai.tensor import ComputePattern +from typing import List, Dict + + +class ColoModule(object): + def __init__(self): + self._shard_params: List[str] = [] + # Example: + # {ComputePattern.TP1D: + # 'default': + # 'weight': + # distspec.shard(xxxxx) + # 'bias': + # distspec.shard(xxxxx) + # 'row': ... + # 'col': ... + # } + self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {} + + def _register_shard_params(self, params: List[str]): + self._shard_params = params + + def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], label='default'): + assert list(dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.' + if not compute_pattern in self._allowed_patterns: + self._allowed_patterns[compute_pattern] = {} + self._allowed_patterns[compute_pattern][label] = dist_specs + + def _set_default(self, compute_pattern: ComputePattern, target_label): + self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_label] + + def has_compute_pattern(self, compute_pattern: ComputePattern): + return compute_pattern in self._allowed_patterns + + def get_dist_specs(self, compute_pattern: ComputePattern): + assert self.has_compute_pattern(compute_pattern) + return self._allowed_patterns[compute_pattern] + + def has_compute_pattern_with_label(self, compute_pattern: ComputePattern, label='default'): + return compute_pattern in self._allowed_patterns and label in self._allowed_patterns[compute_pattern] + + def get_dist_specs_with_label(self, compute_pattern: ComputePattern, label='default'): + assert self.has_compute_pattern_with_label(compute_pattern, label) + return self._allowed_patterns[compute_pattern][label] + + def get_param_names(self): + return self._shard_params + + def register(self): + raise NotImplementedError \ No newline at end of file diff --git a/colossalai/tensor/modules/linear.py b/colossalai/tensor/modules/linear.py new file mode 100644 index 000000000..1ff22fdd9 --- /dev/null +++ b/colossalai/tensor/modules/linear.py @@ -0,0 +1,39 @@ +from .colo_module import ColoModule +from colossalai.tensor import ComputePattern, distspec +from colossalai.core import global_context as gpc +from colossalai.context.parallel_mode import ParallelMode + +class ColoLinear(ColoModule): + def __init__(self): + super(ColoLinear, self).__init__() + self._register_shard_params(['weight', 'bias']) + self._register = False + + def register(self): + if self._register == False: + self._set_TP1D() + self._register = True + + def _set_TP1D(self): + # TP1D Row Linear + _compute_pattern = ComputePattern.TP1D + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), + 'bias': None + }, + label='row', + ) + + # TP1D Col Linear + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), + 'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]) + }, + label='col', + ) + + self._set_default(compute_pattern=_compute_pattern, target_label='row') diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index a39353b16..53aabc3f2 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -1,6 +1,7 @@ from .utils import InsertPostInitMethodToModuleSubClasses import torch -from colossalai.tensor import ColoTensor, ColoParameter +from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \ + ColoLinear import types from torch import nn @@ -101,6 +102,17 @@ def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.n else: object.__setattr__(self, name, value) +def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]: + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError(mod._get_name() + " has no attribute `" + + param_name + "`") + + param = getattr(mod, param_name) + return param def ColoModulize(module): """ @@ -124,6 +136,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): torch.nn.Module.__setattr__ = _setattr_with_colotensor torch.nn.Module.register_parameter = _register_parameter_with_colotensor + torch.nn.Module.get_parameter = _get_parameter_with_colotensor + register_colo_module(torch.nn.Linear, ColoLinear()) def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): """ diff --git a/tests/test_tensor/test_module_spec.py b/tests/test_tensor/test_module_spec.py new file mode 100644 index 000000000..478aa815a --- /dev/null +++ b/tests/test_tensor/test_module_spec.py @@ -0,0 +1,137 @@ +from copy import copy +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +import torch +from colossalai.context.parallel_mode import ParallelMode +from colossalai.tensor import ColoTensor, distspec + +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.core import global_context as gpc +from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, ColoLinear +from _utils import tensor_equal, tensor_shard_equal, set_seed +from tests.components_to_test.registry import non_distributed_component_funcs + +def run_simplenet_with_spec(label): + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + if rank == 0: + model_seq = model_builder(checkpoint=True) + model_seq = model_seq.cuda() + + # Make two models have the same init params + for p1, p2 in zip(model.parameters(), model_seq.parameters()): + p2.data.copy_(p1.data) + + parallel_action = ParallelAction(ComputePattern.TP1D) + init_colo_module(model, parallel_action, recursive=True, label=label) + + model = model.cuda() + for i, (data, label) in enumerate(train_dataloader): + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) + + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + # For reference + if rank == 0: + if criterion: + output_seq = model_seq(data) + loss_seq = criterion(output_seq, label) + else: + output_seq = model_seq(data, label) + loss_seq = output_seq + + if rank == 0: + with torch.no_grad(): + assert torch.allclose(loss, loss_seq, rtol=1e-2) + + loss.backward() + + if rank == 0: + loss_seq.backward() + + with torch.no_grad(): + # check param + for p1, p2 in zip(model.parameters(), model_seq.parameters()): + if p1.size() == p2.size(): + assert torch.allclose(p1, p2) + else: + if p1.size(-1) < p2.size(-1): # col + world_size = p2.size(-1) // p1.size(-1) + split_p2 = torch.chunk(p2, world_size, dim=-1)[0] + + elif p1.size(0) < p2.size(0): # row + world_size = p2.size(0) // p1.size(0) + split_p2 = torch.chunk(p2, world_size, dim=0)[0] + + assert torch.allclose(p1, split_p2) + + if i > 3: + break + +def run_linear_with_spec(label): + with ColoInitContext(device=get_current_device()): + model = torch.nn.Linear(4, 8) + + model_handy = copy(model) + + parallel_action = ParallelAction(ComputePattern.TP1D) + init_colo_module(model, parallel_action, recursive=True, label=label) + + x = torch.rand(2, 4).cuda() + out = model(x) + colo_out = model_handy(x) + assert tensor_equal(out, colo_out) + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad) + assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad) + + +def run_dist(rank, world_size, port, func): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + func('col') + func('row') + func('default') + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_module_linear_1d(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_linear_with_spec) + mp.spawn(run_func, nprocs=world_size) + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_module_simplenet(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_simplenet_with_spec) + mp.spawn(run_func, nprocs=world_size) + +if __name__ == '__main__': + test_module_simplenet(4) \ No newline at end of file