[hotfix] add kwargs for colo_addmm (#2171)

pull/2136/head
Tongping Liu 2022-12-22 00:25:30 -05:00 committed by GitHub
parent a110933d65
commit ab54fed292
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 2 additions and 2 deletions

View File

@@ -55,7 +55,7 @@ def colo_addmm(input_tensor: GeneralTensor,
mat2: ColoTensor,
beta: Number = 1,
alpha: Number = 1,
*args) -> ColoTensor:
**kargs) -> ColoTensor:
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
@@ -70,7 +70,7 @@ def colo_addmm(input_tensor: GeneralTensor,
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor(
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha),
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs),
spec=ColoTensorSpec(mat2.get_process_group()))
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.is_shard_1drow() and input_tensor.is_replicate():