mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] add kwargs for colo_addmm (#2171)
parent
a110933d65
commit
ab54fed292
|
@ -55,7 +55,7 @@ def colo_addmm(input_tensor: GeneralTensor,
|
|||
mat2: ColoTensor,
|
||||
beta: Number = 1,
|
||||
alpha: Number = 1,
|
||||
*args) -> ColoTensor:
|
||||
**kargs) -> ColoTensor:
|
||||
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
|
||||
This method computes a linear.
|
||||
"""
|
||||
|
@ -70,7 +70,7 @@ def colo_addmm(input_tensor: GeneralTensor,
|
|||
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
|
||||
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
|
||||
ret_tensor = ColoTensor.from_torch_tensor(
|
||||
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha),
|
||||
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs),
|
||||
spec=ColoTensorSpec(mat2.get_process_group()))
|
||||
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if mat2.is_shard_1drow() and input_tensor.is_replicate():
|
||||
|
|
Loading…
Reference in New Issue