mirror of https://github.com/hpcaitech/ColossalAI
parent
d66ffb4df4
commit
0653c63eaa
|
@ -32,7 +32,8 @@ def _reduce(input_, parallel_mode):
|
||||||
# skip if only one rank involved
|
# skip if only one rank involved
|
||||||
if gpc.get_world_size(parallel_mode) == 1:
|
if gpc.get_world_size(parallel_mode) == 1:
|
||||||
return input_
|
return input_
|
||||||
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
|
group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
|
||||||
|
dist.all_reduce(input_, group=group)
|
||||||
|
|
||||||
return input_
|
return input_
|
||||||
|
|
||||||
|
@ -66,7 +67,8 @@ def _gather(input_, parallel_mode, dim=-1):
|
||||||
rank = gpc.get_local_rank(parallel_mode)
|
rank = gpc.get_local_rank(parallel_mode)
|
||||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||||
tensor_list[rank] = input_
|
tensor_list[rank] = input_
|
||||||
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
|
group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
|
||||||
|
torch.distributed.all_gather(tensor_list, input_, group=group)
|
||||||
|
|
||||||
# concat
|
# concat
|
||||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||||
|
|
|
@ -35,7 +35,7 @@ class Net(torch.nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def run_hybrid_device(use_ddp):
|
def run_hybrid_device(use_ddp, mode):
|
||||||
with ColoInitContext(device=get_current_device()):
|
with ColoInitContext(device=get_current_device()):
|
||||||
model = Net()
|
model = Net()
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ def run_hybrid_device(use_ddp):
|
||||||
print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
|
print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
|
||||||
#print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
|
#print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
|
||||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||||
init_colo_module(model, parallel_action, recursive=True, mode='col')
|
init_colo_module(model, parallel_action, recursive=True, mode=mode)
|
||||||
|
|
||||||
# use cpu gloo to handle embedding
|
# use cpu gloo to handle embedding
|
||||||
real_model.embed.to('cpu')
|
real_model.embed.to('cpu')
|
||||||
|
@ -63,24 +63,24 @@ def run_hybrid_device(use_ddp):
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
def run_dist(rank, world_size, port, use_ddp):
|
def run_dist(rank, world_size, port, use_ddp, mode):
|
||||||
if use_ddp and world_size == 1:
|
if use_ddp and world_size == 1:
|
||||||
return
|
return
|
||||||
tp_world_size = world_size // 2 if use_ddp else world_size
|
tp_world_size = world_size // 2 if use_ddp else world_size
|
||||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
|
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
|
||||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
run_hybrid_device(use_ddp)
|
run_hybrid_device(use_ddp, mode)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize('world_size', [1, 4])
|
@pytest.mark.parametrize('world_size', [1, 4])
|
||||||
@pytest.mark.parametrize('use_ddp', [False, True])
|
@pytest.mark.parametrize('use_ddp', [False, True])
|
||||||
|
@pytest.mark.parametrize('mode', ['col', 'row'])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
# Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
|
# Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
|
||||||
def _test_hybrid_device(world_size, use_ddp):
|
def _test_hybrid_device(world_size, use_ddp, mode):
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
|
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp ,mode=mode)
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
_test_hybrid_device(4, True)
|
_test_hybrid_device(4, True, 'row')
|
||||||
|
|
Loading…
Reference in New Issue