[NFC] polish colossalai/nn/_ops/addmm.py code style (#3274)

pull/3313/head
Tong Li 2 years ago committed by binmakeswell
parent 4b95464994
commit 196d4696d0

@ -1,9 +1,9 @@
import torch import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
from ._utils import reduce_input, reduce_grad
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@ -69,8 +69,12 @@ def colo_addmm(input_tensor: GeneralTensor,
if not mat2.has_compute_spec(): # No Model Parallel Applied if not mat2.has_compute_spec(): # No Model Parallel Applied
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor( ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor,
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs), mat1,
mat2,
beta=beta,
alpha=alpha,
**kargs),
spec=ColoTensorSpec(mat2.get_process_group())) spec=ColoTensorSpec(mat2.get_process_group()))
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.is_shard_1drow() and input_tensor.is_replicate(): if mat2.is_shard_1drow() and input_tensor.is_replicate():

Loading…
Cancel
Save