[polish] polish __repr__ for ColoTensor, DistSpec, ProcessGroup (#1235)

pull/1226/head^2
HELSON 2 years ago committed by GitHub
parent 0453776def
commit f071b500b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -158,7 +158,7 @@ class ColoTensor(torch.Tensor):
return _convert_output(ret, pg)
def __repr__(self):
return f'ColoTensor: {super().__repr__()}\n dist spec: {self.dist_spec}\n process group: {self.process_group}'
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'
def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
"""_convert_to_dist_spec

@ -32,11 +32,11 @@ class _DistSpec:
return True
def __repr__(self) -> str:
res = "\nDistSpec:\n\t"
res_list = ["DistSpec:"]
for attr in dir(self):
if not attr.startswith('__'):
res += f'{attr}: {str(getattr(self, attr))}\n\t'
return res
res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
return ''.join(res_list)
def replicate() -> _DistSpec:

@ -112,6 +112,10 @@ class ProcessGroup:
def has_cpu_groups(self):
return self._has_cpu_groups
def __repr__(self):
return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
def __eq__(self, obj: 'ProcessGroup') -> bool:
if not isinstance(obj, ProcessGroup):
return False

Loading…
Cancel
Save