From 0653c63eaacd1504f5d66f2e11f80defdb155832 Mon Sep 17 00:00:00 2001
From: Ziyue Jiang
Date: Wed, 8 Jun 2022 12:04:59 +0800
Subject: [PATCH] [Tensor] 1d row embedding (#1075)

* Add CPU 1d row embedding

* polish
---
 colossalai/nn/layer/parallel_1d/_utils.py |  6 ++++--
 tests/test_tensor/test_hybrid_device.py   | 16 ++++++++--------
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py
index 0655db4df..1212d5956 100644
--- a/colossalai/nn/layer/parallel_1d/_utils.py
+++ b/colossalai/nn/layer/parallel_1d/_utils.py
@@ -32,7 +32,8 @@ def _reduce(input_, parallel_mode):
     # skip if only one rank involved
     if gpc.get_world_size(parallel_mode) == 1:
         return input_
-    dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
+    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
+    dist.all_reduce(input_, group=group)
 
     return input_
 
@@ -66,7 +67,8 @@ def _gather(input_, parallel_mode, dim=-1):
     rank = gpc.get_local_rank(parallel_mode)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
-    torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
+    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
+    torch.distributed.all_gather(tensor_list, input_, group=group)
 
     # concat
     output = torch.cat(tensor_list, dim=dim).contiguous()
diff --git a/tests/test_tensor/test_hybrid_device.py b/tests/test_tensor/test_hybrid_device.py
index cb63b2152..290f965f7 100644
--- a/tests/test_tensor/test_hybrid_device.py
+++ b/tests/test_tensor/test_hybrid_device.py
@@ -35,7 +35,7 @@ class Net(torch.nn.Module):
         return x
 
 
-def run_hybrid_device(use_ddp):
+def run_hybrid_device(use_ddp, mode):
     with ColoInitContext(device=get_current_device()):
         model = Net()
 
@@ -47,7 +47,7 @@ def run_hybrid_device(use_ddp):
     print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
     #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
     parallel_action = ParallelAction(ComputePattern.TP1D)
-    init_colo_module(model, parallel_action, recursive=True, mode='col')
+    init_colo_module(model, parallel_action, recursive=True, mode=mode)
 
     # use cpu gloo to handle embedding
     real_model.embed.to('cpu')
@@ -63,24 +63,24 @@ def run_hybrid_device(use_ddp):
     out.sum().backward()
     optimizer.step()
 
-def run_dist(rank, world_size, port, use_ddp):
+def run_dist(rank, world_size, port, use_ddp, mode):
     if use_ddp and world_size == 1:
         return
     tp_world_size = world_size // 2 if use_ddp else world_size
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_hybrid_device(use_ddp)
-
+    run_hybrid_device(use_ddp, mode)
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @pytest.mark.parametrize('use_ddp', [False, True])
+@pytest.mark.parametrize('mode', ['col', 'row'])
 @rerun_if_address_is_in_use()
 # Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
-def _test_hybrid_device(world_size, use_ddp):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
+def _test_hybrid_device(world_size, use_ddp, mode):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, mode=mode)
     mp.spawn(run_func, nprocs=world_size)
 
 if __name__ == '__main__':
-    _test_hybrid_device(4, True)
+    _test_hybrid_device(4, True, 'row')
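
Both hunks in colossalai/nn/layer/parallel_1d/_utils.py apply the same device-aware group selection: when the input tensor lives on the CPU, the collective is issued on the gloo-backed group from gpc.get_cpu_group(parallel_mode); otherwise it stays on the default (nccl) group from gpc.get_group(parallel_mode). Below is a minimal sketch of that pattern written against plain torch.distributed, assuming a gloo group has already been created alongside the default nccl one; the names all_reduce_device_aware, gpu_group and cpu_group are illustrative and are not ColossalAI API.

import torch
import torch.distributed as dist


def all_reduce_device_aware(tensor: torch.Tensor, gpu_group, cpu_group):
    # gloo collectives expect CPU tensors and nccl collectives expect CUDA
    # tensors, so pick the process group whose backend matches the device.
    group = cpu_group if tensor.device.type == "cpu" else gpu_group
    dist.all_reduce(tensor, group=group)
    return tensor


# Hypothetical usage inside an already-initialized job: the default group is
# nccl, and a second gloo group over the same ranks serves CPU tensors such as
# the embedding weights kept on the host in the test above.
# cpu_group = dist.new_group(backend="gloo")
# all_reduce_device_aware(cpu_tensor, dist.group.WORLD, cpu_group)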