hotfix colotensor _scan_for_pg_from_args (#1276)

pull/1277/head
ver217 2022-07-12 20:46:31 +08:00 committed by GitHub
parent 0cf8e8e91c
commit 7aadcbd070
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -40,7 +40,7 @@ def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
pg = _scan_for_pg_from_args(elem, {})
if pg is not None:
return pg
for k, v in kwargs:
for k, v in kwargs.items():
if isinstance(v, ColoTensor):
pg = v.get_process_group()
return pg
@ -52,7 +52,7 @@ class ColoTensor(torch.Tensor):
Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
The signature of the function has to be consistent with the __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways.
1. directly init.