mirror of https://github.com/hpcaitech/ColossalAI
parent cbb6436ff0
commit 3a51d909af
@@ -7,7 +7,7 @@ from colossalai.utils import get_current_device


 def send_tensor_meta(tensor, need_meta=True, next_rank=None):
     """Sends tensor meta information before sending a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be sent before communications. This function
     synchronizes with :func:`recv_tensor_meta`.
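For context, a minimal sketch of the meta-send pattern this docstring describes, assuming an initialized torch.distributed default process group; the helper name send_meta_sketch and the two-message protocol (dimension count, then shape) are illustrative assumptions, not ColossalAI's exact implementation:

import torch
import torch.distributed as dist

# Illustrative sketch (assumed protocol, not ColossalAI's exact code):
# send the number of dimensions first, then the shape itself, so the
# peer can size its receive buffer before the tensor payload, which is
# sent afterwards by a separate call, arrives.
def send_meta_sketch(tensor, next_rank):
    ndim = torch.tensor([tensor.dim()], dtype=torch.long, device=tensor.device)
    shape = torch.tensor(tensor.size(), dtype=torch.long, device=tensor.device)
    dist.send(ndim, dst=next_rank)
    dist.send(shape, dst=next_rank)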
@@ -36,7 +36,7 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):


 def recv_tensor_meta(tensor_shape, prev_rank=None):
     """Receives tensor meta information before receiving a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be received before communications. This function
     synchronizes with :func:`send_tensor_meta`.
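The receiving side mirrors this. A hedged counterpart sketch under the same assumptions (default process group, hypothetical helper name):

import torch
import torch.distributed as dist

# Illustrative counterpart: receive the dimension count, then the
# shape, so the caller can allocate a buffer of exactly that shape for
# the tensor payload that follows.
def recv_meta_sketch(prev_rank, device):
    ndim = torch.empty(1, dtype=torch.long, device=device)
    dist.recv(ndim, src=prev_rank)
    shape = torch.empty(int(ndim.item()), dtype=torch.long, device=device)
    dist.recv(shape, src=prev_rank)
    return torch.Size(shape.tolist())

A caller would then allocate buf = torch.empty(shape, device=device) and issue dist.recv(buf, src=prev_rank) for the payload, which is exactly why the shape must travel first.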
@@ -104,6 +104,6 @@ def gather_split_1d_tensor(tensor):
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
-    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
+    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
     dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
     return gathered
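The only change in this hunk is PEP 8 spacing in the list comprehension, but the gather pattern itself is worth spelling out. A minimal sketch, assuming the default process group in place of gpc.get_group(ParallelMode.PARALLEL_1D):

import torch
import torch.distributed as dist

# Each rank holds one contiguous chunk of a flattened 1D tensor. The
# chunks list holds views into one preallocated output buffer, so
# all_gather writes every rank's chunk straight into its slot and the
# full tensor is reassembled without a second copy.
def gather_split_1d_tensor_sketch(tensor):
    world_size = dist.get_world_size()
    numel = tensor.numel()
    gathered = torch.empty(numel * world_size, dtype=tensor.dtype,
                           device=tensor.device, requires_grad=False)
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor)
    return gathered

Gathering into slices of a single buffer, rather than into a list of fresh tensors that would then need torch.cat, is the design choice that saves one full-size allocation and copy.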