mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/communication/utils.py code style (#656)
parent 5ab9a71299
commit c336cd3066
@@ -77,9 +77,7 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     end_index = start_index + partition_size
     if new_buffer:
-        data = torch.empty(partition_size, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
+        data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
         data.copy_(tensor.view(-1)[start_index:end_index])
     else:
         data = tensor.view(-1)[start_index:end_index]
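The change above is purely stylistic (NFC): the multi-line `torch.empty(...)` call is collapsed onto one line. For context, `split_tensor_into_1d_equal_chunks` slices a flattened tensor into equal per-rank partitions and returns the slice owned by the current 1D-parallel rank. Below is a minimal standalone sketch of that slicing pattern, with plain `torch.distributed` standing in for ColossalAI's `gpc`/`ParallelMode.PARALLEL_1D` context; the function name and the `partition_size = numel // world_size` step are assumptions for illustration, not the library's code.

```python
import torch
import torch.distributed as dist

def split_into_rank_chunk(tensor, new_buffer=False):
    # Hypothetical stand-in for split_tensor_into_1d_equal_chunks: uses the
    # default process group instead of gpc / ParallelMode.PARALLEL_1D.
    world_size = dist.get_world_size()
    # Assumption: the tensor's element count divides evenly across ranks.
    partition_size = torch.numel(tensor) // world_size
    start_index = partition_size * dist.get_rank()
    end_index = start_index + partition_size
    if new_buffer:
        # Copy the local slice into a fresh buffer on the current CUDA device.
        data = torch.empty(partition_size, dtype=tensor.dtype,
                           device=torch.cuda.current_device(), requires_grad=False)
        data.copy_(tensor.view(-1)[start_index:end_index])
    else:
        # Return a view into the original storage (no copy).
        data = tensor.view(-1)[start_index:end_index]
    return data
```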
@@ -97,9 +95,7 @@ def gather_split_1d_tensor(tensor):
     world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
     numel = torch.numel(tensor)
     numel_gathered = world_size * numel
-    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
     chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
     dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
     return gathered
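`gather_split_1d_tensor` is the inverse operation: every rank contributes its local chunk and an all-gather reassembles the full flat tensor. The following is a hedged sketch of the same pattern against the default process group (again a stand-in with an assumed name, not the ColossalAI implementation).

```python
import torch
import torch.distributed as dist

def gather_rank_chunks(tensor):
    # Hypothetical counterpart to gather_split_1d_tensor, using the default group.
    world_size = dist.get_world_size()
    numel = torch.numel(tensor)
    # Pre-allocate one flat output buffer and expose per-rank views into it,
    # so all_gather writes every rank's chunk directly into place.
    gathered = torch.empty(world_size * numel, dtype=tensor.dtype,
                           device=torch.cuda.current_device(), requires_grad=False)
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor)
    return gathered
```

Round-tripping a tensor through the two sketches on every rank reproduces the original flat tensor, provided its element count divides evenly by the world size.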