From df9dcbbff65f79d04158bb8361349eaab6cf3032 Mon Sep 17 00:00:00 2001
From: Ziyue Jiang
Date: Fri, 3 Jun 2022 12:09:49 +0800
Subject: [PATCH] [Tensor] add hybrid device demo and fix bugs (#1059)

---
 colossalai/nn/parallel.py                   | 14 ++--
 colossalai/tensor/module_utils.py           | 10 ++-
 colossalai/utils/model/colo_init_context.py |  1 +
 tests/test_tensor/test_gpt.py               |  2 +-
 tests/test_tensor/test_hybrid_device.py     | 75 +++++++++++++++++++++
 5 files changed, 94 insertions(+), 8 deletions(-)
 create mode 100644 tests/test_tensor/test_hybrid_device.py

diff --git a/colossalai/nn/parallel.py b/colossalai/nn/parallel.py
index 228868f4e..a257a2377 100644
--- a/colossalai/nn/parallel.py
+++ b/colossalai/nn/parallel.py
@@ -51,11 +51,17 @@ class ColoDDP(torch.nn.Module):
         free_storage(empty_grad)
         if self.dp_world_size > 1:
             grad = grad / self.dp_world_size
-            self.comm_stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(self.comm_stream):
-                dist.all_reduce(grad, group=gpc.get_group(ParallelMode.DATA))
+            if grad.device.type != "cpu":
+                self.comm_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(self.comm_stream):
+                    group = gpc.get_group(ParallelMode.DATA)
+                    dist.all_reduce(grad, group=group)
+                    ColoDDP._save_grad(p, grad)
+                grad.record_stream(self.comm_stream)
+            else:
+                group = gpc.get_cpu_group(ParallelMode.DATA)
+                dist.all_reduce(grad, group=group)
                 ColoDDP._save_grad(p, grad)
-            grad.record_stream(self.comm_stream)
         else:
             ColoDDP._save_grad(p, grad)
         return empty_grad
diff --git a/colossalai/tensor/module_utils.py b/colossalai/tensor/module_utils.py
index 89c285e39..9fa389171 100644
--- a/colossalai/tensor/module_utils.py
+++ b/colossalai/tensor/module_utils.py
@@ -12,13 +12,17 @@ def register_colo_module(module_type: type, colo_module: ColoModule):
 
 def is_colo_module(module: torch.nn.Module):
     global _COLOSSAL_MODULES
-    return type(module) in _COLOSSAL_MODULES
+    for module_type in _COLOSSAL_MODULES.keys():
+        if isinstance(type(module), module_type):
+            return True
+    return False
 
 def get_colo_module(module: torch.nn.Module):
     global _COLOSSAL_MODULES
     if is_colo_module(module):
-        colo_module = _COLOSSAL_MODULES[type(module)]
-        return colo_module
+        for module_type, colo_module in _COLOSSAL_MODULES.items():
+            if isinstance(type(module), module_type):
+                return colo_module
     else:
         return None
 
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 058a5fb65..807f9034a 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -92,4 +92,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             setattr(submodule, param_name, colo_param)
             colo_param.shared_param_modules.append(submodule)
 
+        module.to(self._device)
         ColoModulize(module)
diff --git a/tests/test_tensor/test_gpt.py b/tests/test_tensor/test_gpt.py
index 0ee490f5b..85f8de367 100644
--- a/tests/test_tensor/test_gpt.py
+++ b/tests/test_tensor/test_gpt.py
@@ -101,4 +101,4 @@ def test_gpt(world_size, use_ddp):
 
 
 if __name__ == '__main__':
-    test_gpt(4)
+    test_gpt(4, False)
diff --git a/tests/test_tensor/test_hybrid_device.py b/tests/test_tensor/test_hybrid_device.py
new file mode 100644
index 000000000..4a7a596a8
--- /dev/null
+++ b/tests/test_tensor/test_hybrid_device.py
@@ -0,0 +1,75 @@
+from colossalai.utils import free_port, ColoInitContext, get_current_device
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, init_colo_module
+from functools import partial
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+from colossalai.nn.parallel import ColoDDP
+
+import colossalai
+import torch
+import torch.multiprocessing as mp
+import pytest
+
+class Net(torch.nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.embed = torch.nn.Embedding(20, 4)
+        self.proj = torch.nn.Linear(4, 8)
+
+    def forward(self, x):
+        # run the embedding on CPU, then move the activation back to the original device
+        current_dev = x.device
+        x = x.to('cpu')
+        x = self.embed(x)
+        x = x.to(current_dev)
+
+        x = self.proj(x)
+        return x
+
+def run_hybrid_device(use_ddp):
+    with ColoInitContext(device=get_current_device()):
+        model = Net()
+
+    real_model = model
+    if use_ddp:
+        model = ColoDDP(model)
+        real_model = model.module
+
+
+    print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
+    #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
+    parallel_action = ParallelAction(ComputePattern.TP1D)
+    init_colo_module(model, parallel_action, recursive=True, mode='col')
+
+    # use a CPU gloo group to handle the embedding
+    real_model.embed.to('cpu')
+    gloo_group_tp = gpc.get_cpu_group(ParallelMode.PARALLEL_1D)
+    real_model.embed.weight.spec.dist_spec.process_group = gloo_group_tp
+
+    print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
+    #print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
+
+    data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
+    out = model(data)
+    out.sum().backward()
+
+def run_dist(rank, world_size, port, use_ddp):
+    if use_ddp and world_size == 1:
+        return
+    tp_world_size = world_size // 2 if use_ddp else world_size
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_hybrid_device(use_ddp)
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@pytest.mark.parametrize('use_ddp', [False, True])
+@rerun_if_address_is_in_use()
+# simulates a hybrid model: embedding (CPU, DP + TP) -> linear (GPU, DP + TP)
+def _test_hybrid_device(world_size, use_ddp):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
+    mp.spawn(run_func, nprocs=world_size)
+
+if __name__ == '__main__':
+    _test_hybrid_device(1, False)
\ No newline at end of file
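
The ColoDDP hunk above makes the data-parallel gradient all-reduce device-aware: gradients that live on the GPU keep the dedicated CUDA comm stream and the default (NCCL) process group, while gradients that live on the host are reduced synchronously through a CPU (gloo) group. Below is a minimal standalone sketch of that dispatch pattern using plain torch.distributed instead of ColossalAI's gpc helpers; the allreduce_grad helper, the single-process launch, and the explicit gloo group are illustrative assumptions, not part of this patch.

    import os

    import torch
    import torch.distributed as dist


    def allreduce_grad(grad, cuda_group, cpu_group, world_size):
        # Average a gradient over data-parallel ranks, picking the process group
        # by device in the spirit of ColoDDP.grad_handle: CUDA tensors go through
        # the default (e.g. NCCL) group, CPU tensors through a gloo group.
        grad = grad / world_size
        group = cpu_group if grad.device.type == 'cpu' else cuda_group
        dist.all_reduce(grad, group=group)
        return grad


    if __name__ == '__main__':
        # Single-process demo; a real job launched with torchrun would provide
        # MASTER_ADDR/MASTER_PORT and the rank/world size instead.
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '29500')
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend=backend, rank=0, world_size=1)

        cuda_group = dist.group.WORLD
        cpu_group = dist.new_group(backend='gloo')    # dedicated group for host tensors

        cpu_grad = torch.ones(4)                      # e.g. grad of an embedding kept on CPU
        print(allreduce_grad(cpu_grad, cuda_group, cpu_group, dist.get_world_size()))

        if torch.cuda.is_available():
            gpu_grad = torch.ones(4, device='cuda')   # e.g. grad of a linear layer on GPU
            print(allreduce_grad(gpu_grad, cuda_group, cpu_group, dist.get_world_size()))

        dist.destroy_process_group()

This is essentially the split the new test exercises end to end: the embedding weight (and its gradient) stays on the CPU and takes the gloo path, while the projection layer stays on the GPU and keeps the NCCL path.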
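
The module_utils.py hunk replaces the exact type lookup (type(module) in _COLOSSAL_MODULES) with a scan over the registered module types, so that subclasses of a registered class can resolve to the same ColoModule. A small self-contained sketch of such a subclass-aware lookup follows; it matches with isinstance(module, module_type) on the instance itself, whereas the patch as committed tests isinstance(type(module), module_type), and the registry contents and handler strings here are made up purely for illustration.

    import torch

    # Hypothetical registry mapping a module class to a handler, standing in
    # for the role _COLOSSAL_MODULES plays in module_utils.py.
    _REGISTRY = {
        torch.nn.Linear: 'colo-linear-handler',
        torch.nn.Embedding: 'colo-embedding-handler',
    }


    def get_handler(module):
        # Scan the registered types and match with isinstance() so that
        # subclasses of a registered class reuse the parent's handler; an
        # exact `type(module) in _REGISTRY` test would miss them.
        for module_type, handler in _REGISTRY.items():
            if isinstance(module, module_type):
                return handler
        return None


    class MyLinear(torch.nn.Linear):
        pass                                         # a subclass an exact type match would miss


    if __name__ == '__main__':
        print(get_handler(torch.nn.Linear(4, 8)))    # colo-linear-handler
        print(get_handler(MyLinear(4, 8)))           # colo-linear-handler, via the subclass check
        print(get_handler(torch.nn.ReLU()))          # None: not registered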