From 197a2c89e2719b045e1e39207785173ec9220ecc Mon Sep 17 00:00:00 2001
From: Zangwei Zheng
Date: Tue, 12 Jul 2022 18:07:15 +0800
Subject: [PATCH] [NFC] polish colossalai/communication/collective.py (#1262)

---
 colossalai/communication/collective.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py
index 50fd7dcc2..2c9e9927c 100644
--- a/colossalai/communication/collective.py
+++ b/colossalai/communication/collective.py
@@ -10,10 +10,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 
 
-def all_gather(tensor: Tensor,
-               dim: int,
-               parallel_mode: ParallelMode,
-               async_op: bool = False) -> Tensor:
+def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
 
@@ -163,11 +160,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b
     return out
 
 
-def reduce(tensor: Tensor,
-           dst: int,
-           parallel_mode: ParallelMode,
-           op: ReduceOp = ReduceOp.SUM,
-           async_op: bool = False):
+def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
     r"""Reduce tensors across whole parallel group. Only the process with rank
     ``dst`` is going to receive the final result.
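
For context, below is a minimal usage sketch of the two collectives whose signatures this patch reformats, written only from the signatures and docstrings visible in the diff. It is not part of the patch. It assumes a ColossalAI distributed context has already been launched and that ParallelMode.TENSOR names a valid parallel group in this setup; the function `demo_collectives`, the tensor shapes, and the device choice are illustrative assumptions.

    import torch
    from torch.distributed import ReduceOp

    from colossalai.communication.collective import all_gather, reduce
    from colossalai.context import ParallelMode


    def demo_collectives():
        # Each rank contributes a local shard; the shape and device are illustrative only.
        local_shard = torch.randn(4, 8, device='cuda')

        # Gather the shards from every rank in the group and concatenate them along dim 0,
        # as described in the all_gather docstring.
        gathered = all_gather(local_shard, dim=0, parallel_mode=ParallelMode.TENSOR)

        # Sum-reduce across the group onto rank 0; per the reduce docstring, only the
        # process with rank ``dst`` receives the final result.
        reduced = reduce(local_shard, dst=0, parallel_mode=ParallelMode.TENSOR, op=ReduceOp.SUM)

        return gathered, reduced

Both calls are shown in their default blocking form (async_op=False); the diff does not change that behavior, only the formatting of the signatures.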