[Tensor] 1d row embedding (#1075)

* Add CPU 1d row embedding

* polish
pull/1081/head
Ziyue Jiang 2022-06-08 12:04:59 +08:00 committed by GitHub
parent d66ffb4df4
commit 0653c63eaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 10 deletions

View File

@ -32,7 +32,8 @@ def _reduce(input_, parallel_mode):
# skip if only one rank involved # skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1: if gpc.get_world_size(parallel_mode) == 1:
return input_ return input_
dist.all_reduce(input_, group=gpc.get_group(parallel_mode)) group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
dist.all_reduce(input_, group=group)
return input_ return input_
@ -66,7 +67,8 @@ def _gather(input_, parallel_mode, dim=-1):
rank = gpc.get_local_rank(parallel_mode) rank = gpc.get_local_rank(parallel_mode)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_ tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode)) group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
torch.distributed.all_gather(tensor_list, input_, group=group)
# concat # concat
output = torch.cat(tensor_list, dim=dim).contiguous() output = torch.cat(tensor_list, dim=dim).contiguous()

View File

@ -35,7 +35,7 @@ class Net(torch.nn.Module):
return x return x
def run_hybrid_device(use_ddp): def run_hybrid_device(use_ddp, mode):
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = Net() model = Net()
@ -47,7 +47,7 @@ def run_hybrid_device(use_ddp):
print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}') print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
#print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}') #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
parallel_action = ParallelAction(ComputePattern.TP1D) parallel_action = ParallelAction(ComputePattern.TP1D)
init_colo_module(model, parallel_action, recursive=True, mode='col') init_colo_module(model, parallel_action, recursive=True, mode=mode)
# use cpu gloo to handle embedding # use cpu gloo to handle embedding
real_model.embed.to('cpu') real_model.embed.to('cpu')
@ -63,24 +63,24 @@ def run_hybrid_device(use_ddp):
out.sum().backward() out.sum().backward()
optimizer.step() optimizer.step()
def run_dist(rank, world_size, port, use_ddp): def run_dist(rank, world_size, port, use_ddp, mode):
if use_ddp and world_size == 1: if use_ddp and world_size == 1:
return return
tp_world_size = world_size // 2 if use_ddp else world_size tp_world_size = world_size // 2 if use_ddp else world_size
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_hybrid_device(use_ddp) run_hybrid_device(use_ddp, mode)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True]) @pytest.mark.parametrize('use_ddp', [False, True])
@pytest.mark.parametrize('mode', ['col', 'row'])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
# Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP) # Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
def _test_hybrid_device(world_size, use_ddp): def _test_hybrid_device(world_size, use_ddp, mode):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp ,mode=mode)
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
_test_hybrid_device(4, True) _test_hybrid_device(4, True, 'row')