From f071b500b665018de983d2bcff5415b56905594b Mon Sep 17 00:00:00 2001 From: HELSON Date: Fri, 8 Jul 2022 13:25:57 +0800 Subject: [PATCH] [polish] polish __repr__ for ColoTensor, DistSpec, ProcessGroup (#1235) --- colossalai/tensor/colo_tensor.py | 2 +- colossalai/tensor/distspec.py | 6 +++--- colossalai/tensor/process_group.py | 4 ++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 7a70c4447..874612f63 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -158,7 +158,7 @@ class ColoTensor(torch.Tensor): return _convert_output(ret, pg) def __repr__(self): - return f'ColoTensor: {super().__repr__()}\n dist spec: {self.dist_spec}\n process group: {self.process_group}' + return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}' def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None: """_convert_to_dist_spec diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py index 4ca2db4c4..4796d420c 100644 --- a/colossalai/tensor/distspec.py +++ b/colossalai/tensor/distspec.py @@ -32,11 +32,11 @@ class _DistSpec: return True def __repr__(self) -> str: - res = "\nDistSpec:\n\t" + res_list = ["DistSpec:"] for attr in dir(self): if not attr.startswith('__'): - res += f'{attr}: {str(getattr(self, attr))}\n\t' - return res + res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}') + return ''.join(res_list) def replicate() -> _DistSpec: diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py index 90337864f..1482f02db 100644 --- a/colossalai/tensor/process_group.py +++ b/colossalai/tensor/process_group.py @@ -112,6 +112,10 @@ class ProcessGroup: def has_cpu_groups(self): return self._has_cpu_groups + def __repr__(self): + return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\ + format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list) + def __eq__(self, obj: 'ProcessGroup') -> bool: if not isinstance(obj, ProcessGroup): return False