[polish] polish ColoTensor and its submodules (#2537)

pull/2567/head
HELSON 2023-02-03 11:44:10 +08:00 committed by GitHub
parent 51d4d6e718
commit 552183bb74
6 changed files with 75 additions and 65 deletions


@@ -71,7 +71,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         return tensor

     def __repr__(self):
-        return f'ColoParameter: {ColoTensor.__repr__(self)}'
+        return super(ColoParameter, self).__repr__()

     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):


@@ -189,7 +189,12 @@ class ColoTensor(torch.Tensor):
         return _convert_output(ret, colo_spec)

     def __repr__(self):
-        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
+        output_list = [super(ColoTensor, self).__repr__()]
+        output_list.append(str(self.process_group))
+        output_list.append(str(self.dist_spec))
+        if self.compute_spec is not None:
+            output_list.append(str(self.compute_spec))
+        return "\n".join(output_list)

     def _redistribute(self, dist_spec: _DistSpec) -> None:
         """_redistribute


@@ -23,7 +23,7 @@ class ComputeSpec(object):
         self.output_replicate = True

     def __repr__(self):
-        return f'Compute pattern: {self.compute_pattern}'
+        return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'

     def set_output_replicate(self, flag: bool = True):
         self.output_replicate = flag


@@ -39,11 +39,12 @@ class _DistSpec:
         return True

     def __repr__(self) -> str:
-        res_list = ["DistSpec:"]
+        attr_list = []
         for attr in dir(self):
             if not attr.startswith('__'):
-                res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
-        return ''.join(res_list)
+                attr_list.append(f'{attr}={str(getattr(self, attr))}')
+        attr_str = ", ".join(attr_list)
+        return "DistSpec(" + attr_str + ")"


 def ReplicaSpec() -> _DistSpec:
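A toy illustration of the dir()-scanning repr pattern kept by this hunk (the class and attribute names below are invented for demonstration, not taken from the patch):

    class _ToySpec:
        """Stand-in class showing the same attribute-scanning __repr__ idea."""

        def __init__(self):
            self.placement = 'REPLICATE'
            self.dims = (0,)

        def __repr__(self) -> str:
            attr_list = []
            for attr in dir(self):    # dir() yields names in sorted order
                if not attr.startswith('__'):
                    attr_list.append(f'{attr}={str(getattr(self, attr))}')
            return "ToySpec(" + ", ".join(attr_list) + ")"

    print(_ToySpec())    # ToySpec(dims=(0,), placement=REPLICATE)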


@@ -1,29 +1,36 @@
-import torch
 from typing import List, Optional
-from colossalai.logging import get_dist_logger
+
+import torch
+
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.logging import get_dist_logger


 class PyTorchProcessGroupDict(metaclass=SingletonMeta):

     def __init__(self):
         # distributed settings
+        # use this dict to record all Pytorch ProcessGroups
         self.dict = {}
+        # set a distributed logger
+        self.logger = get_dist_logger('ProcessGroup')
+
+    def log_pg_init(self, rank_list: List[int], backend: str):
+        str_list = ["Pytorch ProcessGroup Init:"]
+        str_list.append(f"backend: {backend}")
+        str_list.append(f"ranks: {rank_list}")
+        self.logger.info("\n\t".join(str_list), ranks=[0])

     def get(self, rank_list: List[int], backend: str = 'nccl'):
         """Reuse Pytorch ProcessGroup when such a group is initialized
         """
-        rank_tuple = tuple(rank_list)
         # we need to convert the passed list to a tuple
         # since List is unhashable
-        pg_key = (backend, rank_tuple)
-
-        if pg_key not in self.dict:
-
-            self.logger = get_dist_logger('ProcessGroup')
-            self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
-            self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
-        return self.dict[pg_key]
+        processgroup_key = (backend, tuple(rank_list))
+        if processgroup_key not in self.dict:
+            self.log_pg_init(rank_list=rank_list, backend=backend)
+            self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
+        return self.dict[processgroup_key]


 PYTORCHPGDICT_ = PyTorchProcessGroupDict()
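The refactor above keeps the same keyed-cache idea. Stripped of the ColossalAI specifics, the pattern looks roughly like the sketch below (the helper name is invented, and torch.distributed must already be initialized):

    import torch.distributed as dist

    _pg_cache = {}    # maps (backend, ranks-tuple) -> process group handle

    def get_cached_pg(rank_list, backend='nccl'):
        # lists are unhashable, so the ranks are converted to a tuple for the dict key
        key = (backend, tuple(rank_list))
        if key not in _pg_cache:
            _pg_cache[key] = dist.new_group(ranks=rank_list, backend=backend)
        return _pg_cache[key]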
@@ -54,10 +61,10 @@ class ProcessGroup:
             return

         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
-        if rank is None:
-            self._rank = torch.distributed.get_rank()
-        else:
-            self._rank = rank
+
+        self._rank = torch.distributed.get_rank()
+        if rank is not None:
+            assert self._rank == rank    # make sure that the global rank is correct

         if ranks is None:
             self._rank_list = list(range(torch.distributed.get_world_size()))
@@ -132,8 +139,9 @@ class ProcessGroup:

     def __repr__(self):
         if self.is_init:
-            return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
-                format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
+            ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
+            personal_str = f"             rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
+            return ranks_str + personal_str
         else:
             return "ProcessGroup not initialized"


@@ -43,7 +43,6 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
     else:
         colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)

-
     # if default_shard_plan exists, shard the param during initialization.
     # This can reduce the model size after initialization.
     # NOTE() embedding usually can not be correctly sharded. So I use except to handle
@@ -130,30 +129,27 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                     setattr(submodule, param_name, colo_param)
                 colo_param.shared_param_modules.append(submodule)

-        meta_param_flag = 0
-        meta_buffer_flag = 0
+        param_number = 0
+        meta_param_number = 0
+        buffer_number = 0
+        meta_buffer_number = 0
+
         for param in module.parameters():
-            if param.device.type=="meta":
-                meta_param_flag = 1
-            if meta_param_flag == 1 and param.device.type!="meta":
-                raise ValueError("Meta parameters and valued parameters can not be in the same model")
+            param_number += 1
+            meta_param_number += (param.device.type == 'meta')

         for buffer in module.buffers():
-            if buffer.device.type=="meta":
-                meta_buffer_flag = 1
-            if meta_buffer_flag == 1 and buffer.device.type!="meta":
-                raise ValueError("Meta buffers and valued buffers can not be in the same model")
+            buffer_number += 1
+            meta_buffer_number += (buffer.device.type == 'meta')

-        if meta_param_flag==1 and meta_buffer_flag==1:
-            pass
-        elif meta_buffer_flag==0 and meta_param_flag==1:
-            for name, buf in module.named_buffers():
-                module._buffers[name] = module._buffers[name].to(device=self._device)
-        elif meta_param_flag==0 and meta_buffer_flag==1:
-            for name, param in module.named_parameters():
-                module._parameters[name] = module._parameters[name].to(device=self._device)
-        else:
-            module.to(self._device)
+        if meta_param_number > 0 and meta_param_number != param_number:
+            raise ValueError("Meta parameters and valued parameters can not be in the same model")
+        if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
+            raise ValueError("Meta buffers and valued buffers can not be in the same model")
+
+        if meta_buffer_number == 0:
+            for buffer in module.buffers():
+                buffer.data = buffer.data.to(device=self._device)


 def post_process_colo_init_ctx(model: torch.nn.Module,
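A standalone sketch of the all-or-nothing meta-device check that replaces the old flag logic (the helper name is invented for illustration and is not part of the patch):

    import torch

    def check_meta_consistency(module: torch.nn.Module) -> bool:
        """Raise if meta and materialized tensors are mixed; return True if the parameters live on the meta device."""
        params = list(module.parameters())
        buffers = list(module.buffers())
        meta_params = sum(p.device.type == 'meta' for p in params)
        meta_buffers = sum(b.device.type == 'meta' for b in buffers)
        if 0 < meta_params < len(params):
            raise ValueError("Meta parameters and valued parameters can not be in the same model")
        if 0 < meta_buffers < len(buffers):
            raise ValueError("Meta buffers and valued buffers can not be in the same model")
        return meta_params > 0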