mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/communication/collective.py (#1262)
parent f1cafcc73a
commit 197a2c89e2
@@ -10,10 +10,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 
 
-def all_gather(tensor: Tensor,
-               dim: int,
-               parallel_mode: ParallelMode,
-               async_op: bool = False) -> Tensor:
+def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
 
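For context on the signature polished above, a minimal usage sketch of all_gather follows. It assumes colossalai has already been launched and a 1D tensor-parallel group is initialized; the ParallelMode choice, tensor shape, and device are illustrative assumptions, not part of this commit.

import torch

from colossalai.communication.collective import all_gather
from colossalai.context import ParallelMode

# Each rank holds a (4, 8) shard; gathering along dim 0 returns a
# (4 * world_size, 8) tensor containing every rank's shard.
local_shard = torch.randn(4, 8, device="cuda")
gathered = all_gather(local_shard, dim=0, parallel_mode=ParallelMode.PARALLEL_1D)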
@@ -163,11 +160,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b
     return out
 
 
-def reduce(tensor: Tensor,
-           dst: int,
-           parallel_mode: ParallelMode,
-           op: ReduceOp = ReduceOp.SUM,
-           async_op: bool = False):
+def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
     r"""Reduce tensors across whole parallel group. Only the process with
     rank ``dst`` is going to receive the final result.
 
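Likewise, a minimal sketch of calling the reformatted reduce, under the same assumptions (an initialized 1D parallel group; the destination rank and ReduceOp are example values).

import torch
from torch.distributed import ReduceOp

from colossalai.communication.collective import reduce
from colossalai.context import ParallelMode

# Sum the local tensors across the whole parallel group; per the docstring,
# only the process with rank ``dst`` receives the final result.
local_grad = torch.ones(4, 8, device="cuda")
reduced = reduce(local_grad, dst=0, parallel_mode=ParallelMode.PARALLEL_1D, op=ReduceOp.SUM)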