mirror of https://github.com/hpcaitech/ColossalAI
hotfix colotensor _scan_for_pg_from_args (#1276)
parent
0cf8e8e91c
commit
7aadcbd070
|
@ -40,7 +40,7 @@ def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
|
|||
pg = _scan_for_pg_from_args(elem, {})
|
||||
if pg is not None:
|
||||
return pg
|
||||
for k, v in kwargs:
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, ColoTensor):
|
||||
pg = v.get_process_group()
|
||||
return pg
|
||||
|
@ -52,7 +52,7 @@ class ColoTensor(torch.Tensor):
|
|||
Args:
|
||||
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
||||
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
|
||||
|
||||
|
||||
The signature of the function has to be consistent with the __new__ except for the 1st arg.
|
||||
The class should be initialized with a torch tensor in the following ways.
|
||||
1. directly init.
|
||||
|
|
Loading…
Reference in New Issue