From d58fa705b2519a4e2c908aa9893d5111652d2480 Mon Sep 17 00:00:00 2001 From: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Date: Tue, 28 Mar 2023 10:30:30 +0800 Subject: [PATCH] [NFC] polish code style (#3268) Co-authored-by: Yuanchen Xu --- colossalai/nn/_ops/_utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/nn/_ops/_utils.py index 56bb5f465..24877bbb5 100644 --- a/colossalai/nn/_ops/_utils.py +++ b/colossalai/nn/_ops/_utils.py @@ -1,12 +1,11 @@ -import torch -from typing import Union, Optional, List -from colossalai.tensor import ColoTensor +from typing import List, Optional, Union + import torch import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn.layer.utils import divide -from colossalai.tensor import ProcessGroup, ColoTensorSpec +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup GeneralTensor = Union[ColoTensor, torch.Tensor] Number = Union[int, float] @@ -135,7 +134,7 @@ class _ReduceInput(torch.autograd.Function): class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. - + Args: input_: input matrix. process_group: parallel mode.