From 196d4696d0ef5360b1d8b7f99d8124d768e4e075 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 28 Mar 2023 11:23:38 +0800 Subject: [PATCH] [NFC] polish colossalai/nn/_ops/addmm.py code style (#3274) --- colossalai/nn/_ops/addmm.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py index fe2eb0c99..660b48a71 100644 --- a/colossalai/nn/_ops/addmm.py +++ b/colossalai/nn/_ops/addmm.py @@ -1,9 +1,9 @@ import torch + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor -from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec -from ._utils import GeneralTensor, Number, convert_to_colo_tensor -from ._utils import reduce_input, reduce_grad + +from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, @@ -69,9 +69,13 @@ def colo_addmm(input_tensor: GeneralTensor, if not mat2.has_compute_spec(): # No Model Parallel Applied assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' - ret_tensor = ColoTensor.from_torch_tensor( - tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs), - spec=ColoTensorSpec(mat2.get_process_group())) + ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor, + mat1, + mat2, + beta=beta, + alpha=alpha, + **kargs), + spec=ColoTensorSpec(mat2.get_process_group())) elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied if mat2.is_shard_1drow() and input_tensor.is_replicate(): mode = 'row'