@@ -1,9 +1,9 @@
 import torch
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
-from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
-from ._utils import GeneralTensor, Number, convert_to_colo_tensor
-from ._utils import reduce_input, reduce_grad
+
+from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input
 
 
 def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -69,9 +69,13 @@ def colo_addmm(input_tensor: GeneralTensor,
     if not mat2.has_compute_spec():    # No Model Parallel Applied
         assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
         assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
-        ret_tensor = ColoTensor.from_torch_tensor(
-            tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs),
-            spec=ColoTensorSpec(mat2.get_process_group()))
+        ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor,
+                                                                     mat1,
+                                                                     mat2,
+                                                                     beta=beta,
+                                                                     alpha=alpha,
+                                                                     **kargs),
+                                                  spec=ColoTensorSpec(mat2.get_process_group()))
     elif mat2.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if mat2.is_shard_1drow() and input_tensor.is_replicate():
             mode = 'row'
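
Note on the native (replicated) branch above: `torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)` computes `beta * input + alpha * (mat1 @ mat2)`. A minimal plain-PyTorch sketch of that semantics (no ColossalAI required; the tensor names and shapes are illustrative only):

```python
import torch

# addmm fuses a matmul with a scaled add: out = beta * input + alpha * (mat1 @ mat2)
inp = torch.randn(4, 6)
mat1 = torch.randn(4, 5)
mat2 = torch.randn(5, 6)
beta, alpha = 0.5, 2.0

out = torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha)
expected = beta * inp + alpha * (mat1 @ mat2)
assert torch.allclose(out, expected, atol=1e-5)
```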
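
For the `TP1D` branch, `mode = 'row'` presumably dispatches to `colo_addmm_1Drow` (touched in the first hunk): `mat2` is sharded along rows, so each rank produces a partial product that must be summed across the process group (presumably via the `reduce_input` helper imported above). A single-process sketch of the underlying math, with `world_size` standing in for the process group size:

```python
import torch

# 1D-row-sharded addmm, simulated on one process: split mat1 by columns and
# mat2 by rows, sum the per-shard partial products (the all-reduce step in a
# real tensor-parallel run), then apply the beta/alpha epilogue.
world_size = 2
inp = torch.randn(4, 6)
mat1 = torch.randn(4, 8)
mat2 = torch.randn(8, 6)
beta, alpha = 1.0, 1.0

partials = [a @ b for a, b in zip(mat1.chunk(world_size, dim=1),
                                  mat2.chunk(world_size, dim=0))]
out = beta * inp + alpha * sum(partials)    # sum(partials) ~ all-reduce of local results

assert torch.allclose(out, torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha), atol=1e-5)
```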