From 86cf6aed5b0822cdb539c54888f00553dfd12209 Mon Sep 17 00:00:00 2001
From: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Date: Tue, 18 Jul 2023 10:23:46 +0800
Subject: [PATCH] Fix/format (#4261)

* revise shardformer readme (#4246)

* [example] add llama pretraining (#4257)

* [NFC] polish colossalai/communication/p2p.py code style

---------

Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: binmakeswell
Co-authored-by: Qianran Ma
---
 colossalai/communication/p2p.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py
index 1f20fca4f..d28d14016 100644
--- a/colossalai/communication/p2p.py
+++ b/colossalai/communication/p2p.py
@@ -1,16 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import operator
+from functools import reduce
 from typing import List, Tuple, Union
+
 import torch
 import torch.distributed as dist
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
-from functools import reduce
-import operator
-from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
+
+from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
 
 TensorShape = Union[torch.Size, List[int], Tuple[int]]
 
@@ -260,7 +262,7 @@ def send_forward_recv_backward(output_tensor,
                                next_rank=None,
                                dtype=torch.float,
                                scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
-    """Batched communication operation. Sends the input tensor to the 
+    """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.
 
@@ -319,7 +321,7 @@ def send_forward_recv_forward(output_tensor,
                               next_rank=None,
                               dtype=torch.float,
                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
-    """Batched communication operation. Sends the input tensor to the 
+    """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
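
Note (not part of the patch): the docstrings touched above describe a batched pipeline point-to-point exchange, where a stage sends its forward output and simultaneously receives the gradient coming back from the next stage. The sketch below is an illustrative approximation of that pattern built only on public torch.distributed primitives; it is not ColossalAI's implementation, and the function name, the `next_rank` argument, and the assumption that the incoming gradient matches the outgoing tensor's shape and dtype are all hypothetical simplifications.

    # Illustrative sketch: a simplified send-forward / recv-backward exchange,
    # assuming torch.distributed is already initialized and `next_rank` is the
    # rank of the following pipeline stage.
    import torch
    import torch.distributed as dist


    def send_forward_recv_backward_sketch(output_tensor: torch.Tensor,
                                          next_rank: int) -> torch.Tensor:
        # Buffer for the gradient arriving from the next stage; assumed here to
        # share the shape and dtype of the tensor being sent.
        grad_tensor = torch.empty_like(output_tensor)
        ops = [
            dist.P2POp(dist.isend, output_tensor, next_rank),
            dist.P2POp(dist.irecv, grad_tensor, next_rank),
        ]
        # batch_isend_irecv launches both transfers together, which is the sense
        # of "batched communication operation" in the docstrings above.
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        return grad_tensor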