diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 25adb212f..0c92ac0c7 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -40,7 +40,7 @@ def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup: pg = _scan_for_pg_from_args(elem, {}) if pg is not None: return pg - for k, v in kwargs: + for k, v in kwargs.items(): if isinstance(v, ColoTensor): pg = v.get_process_group() return pg @@ -52,7 +52,7 @@ class ColoTensor(torch.Tensor): Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()). - + The signature of the function has to be consistent with the __new__ except for the 1st arg. The class should be initialized with a torch tensor in the following ways. 1. directly init.