fix format (#332)

Co-authored-by: 何晓昕 <cautious@r-205-106-25-172.comp.nus.edu.sg>
Authored by Cautiousss on 2022-03-09 10:35:05 +08:00, committed by Frank Lee
parent cbb6436ff0
commit 3a51d909af
1 changed file with 3 additions and 3 deletions

@@ -7,7 +7,7 @@ from colossalai.utils import get_current_device
 def send_tensor_meta(tensor, need_meta=True, next_rank=None):
     """Sends tensor meta information before sending a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be sent before communications. This function
     synchronizes with :func:`recv_tensor_meta`.
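
As the docstring notes, the receiver in a p2p exchange must know the incoming tensor's shape before it can allocate a buffer, so the shape travels ahead of the payload. A minimal sketch of the send side of such a handshake, assuming an already-initialized torch.distributed process group (send_shape_then_tensor is a hypothetical helper, not ColossalAI's actual API):

import torch
import torch.distributed as dist

def send_shape_then_tensor(tensor, next_rank):
    # Hypothetical sketch: send the dimension count, then the shape,
    # then the payload, so the receiver can allocate before receiving.
    shape = torch.tensor(tensor.size(), dtype=torch.long, device=tensor.device)
    ndim = torch.tensor([shape.numel()], dtype=torch.long, device=tensor.device)
    dist.send(ndim, dst=next_rank)
    dist.send(shape, dst=next_rank)
    dist.send(tensor, dst=next_rank)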
@@ -36,7 +36,7 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
 def recv_tensor_meta(tensor_shape, prev_rank=None):
     """Receives tensor meta information before receiving a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be received before communications. This function
     synchronizes with :func:`send_tensor_meta`.
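
The receiving side mirrors the handshake: dimension count first, then shape, then an empty buffer of exactly that shape is filled by the final recv. A matching sketch under the same assumptions (recv_shape_then_tensor is likewise hypothetical):

import torch
import torch.distributed as dist

def recv_shape_then_tensor(prev_rank, dtype=torch.float, device='cuda'):
    # Hypothetical sketch: receive dims and shape first, then allocate
    # a buffer of that shape and receive the payload into it.
    ndim = torch.empty(1, dtype=torch.long, device=device)
    dist.recv(ndim, src=prev_rank)
    shape = torch.empty(int(ndim.item()), dtype=torch.long, device=device)
    dist.recv(shape, src=prev_rank)
    tensor = torch.empty(shape.tolist(), dtype=dtype, device=device)
    dist.recv(tensor, src=prev_rank)
    return tensor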
@@ -104,6 +104,6 @@ def gather_split_1d_tensor(tensor):
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
-    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
+    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
     dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
     return gathered
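
The reformatted line is the interesting part of this hunk: `gathered` is allocated once, `chunks` holds per-rank views into it, and all_gather writes each rank's shard directly into place, so no concatenation is needed afterwards. A self-contained sketch of the same buffer-of-views pattern, assuming the default process group rather than ColossalAI's PARALLEL_1D group:

import torch
import torch.distributed as dist

def gather_split_1d(local_chunk):
    # Each list entry aliases a slice of `gathered`, so all_gather
    # fills the flat output buffer in place.
    world_size = dist.get_world_size()
    numel = local_chunk.numel()
    gathered = torch.empty(world_size * numel, dtype=local_chunk.dtype,
                           device=local_chunk.device)
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, local_chunk)
    return gathered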