diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/nn/_ops/_utils.py
index 56bb5f465..24877bbb5 100644
--- a/colossalai/nn/_ops/_utils.py
+++ b/colossalai/nn/_ops/_utils.py
@@ -1,12 +1,11 @@
-import torch
-from typing import Union, Optional, List
-from colossalai.tensor import ColoTensor
+from typing import List, Optional, Union
+
 import torch
 import torch.distributed as dist
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor import ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
 
 GeneralTensor = Union[ColoTensor, torch.Tensor]
 Number = Union[int, float]
@@ -135,7 +134,7 @@ class _ReduceInput(torch.autograd.Function):
 class _SplitForwardGatherBackward(torch.autograd.Function):
     """
     Split the input and keep only the corresponding chuck to the rank.
-    
+
     Args:
         input_: input matrix.
         process_group: parallel mode.
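
For context (not part of the diff): the `_SplitForwardGatherBackward` autograd function touched above follows the common tensor-parallel pattern of keeping only this rank's chunk in the forward pass and all-gathering the gradient chunks in the backward pass. The sketch below is a minimal, illustrative version of that pattern, not ColossalAI's actual implementation; the class name, the choice of the last dimension as the split axis, and the plain `torch.distributed` group argument are assumptions for the example.

```python
import torch
import torch.distributed as dist


class SplitForwardGatherBackward(torch.autograd.Function):
    """Illustrative sketch: forward keeps this rank's chunk of the last dim,
    backward all-gathers the gradient chunks back into the full tensor."""

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        world_size = dist.get_world_size(group)
        rank = dist.get_rank(group)
        # Split along the last dimension and keep only this rank's chunk.
        return torch.chunk(input_, world_size, dim=-1)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.group
        world_size = dist.get_world_size(group)
        # Gather gradient chunks from every rank and concatenate them
        # back along the last dimension to match the original input shape.
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output.contiguous(), group=group)
        return torch.cat(chunks, dim=-1), None
```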