Browse Source

[polish] polish __repr__ for ColoTensor, DistSpec, ProcessGroup (#1235)

pull/1226/head^2
HELSON 2 years ago committed by GitHub
parent
commit
f071b500b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      colossalai/tensor/colo_tensor.py
  2. 6
      colossalai/tensor/distspec.py
  3. 4
      colossalai/tensor/process_group.py

2
colossalai/tensor/colo_tensor.py

@ -158,7 +158,7 @@ class ColoTensor(torch.Tensor):
return _convert_output(ret, pg)
def __repr__(self):
return f'ColoTensor: {super().__repr__()}\n dist spec: {self.dist_spec}\n process group: {self.process_group}'
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'
def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
"""_convert_to_dist_spec

6
colossalai/tensor/distspec.py

@ -32,11 +32,11 @@ class _DistSpec:
return True
def __repr__(self) -> str:
res = "\nDistSpec:\n\t"
res_list = ["DistSpec:"]
for attr in dir(self):
if not attr.startswith('__'):
res += f'{attr}: {str(getattr(self, attr))}\n\t'
return res
res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
return ''.join(res_list)
def replicate() -> _DistSpec:

4
colossalai/tensor/process_group.py

@ -112,6 +112,10 @@ class ProcessGroup:
def has_cpu_groups(self):
return self._has_cpu_groups
def __repr__(self):
return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
def __eq__(self, obj: 'ProcessGroup') -> bool:
if not isinstance(obj, ProcessGroup):
return False

Loading…
Cancel
Save