[polish] polish ColoTensor and its submodules (#2537)

pull/2567/head
HELSON 2023-02-03 11:44:10 +08:00 committed by GitHub
parent 51d4d6e718
commit 552183bb74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 75 additions and 65 deletions

View File

@ -71,7 +71,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
return tensor
def __repr__(self):
return f'ColoParameter: {ColoTensor.__repr__(self)}'
return super(ColoParameter, self).__repr__()
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):

View File

@ -189,7 +189,12 @@ class ColoTensor(torch.Tensor):
return _convert_output(ret, colo_spec)
def __repr__(self):
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
output_list = [super(ColoTensor, self).__repr__()]
output_list.append(str(self.process_group))
output_list.append(str(self.dist_spec))
if self.compute_spec is not None:
output_list.append(str(self.compute_spec))
return "\n".join(output_list)
def _redistribute(self, dist_spec: _DistSpec) -> None:
"""_redistribute

View File

@ -9,9 +9,9 @@ class ComputePattern(Enum):
class ComputeSpec(object):
"""ComputeSpec
"""ComputeSpec
The Specification for compuattion pattern
Args:
compute_pattern (ComputePattern): an Enum instance for compute pattern.
"""
@ -23,7 +23,7 @@ class ComputeSpec(object):
self.output_replicate = True
def __repr__(self):
return f'Compute pattern: {self.compute_pattern}'
return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'
def set_output_replicate(self, flag: bool = True):
self.output_replicate = flag

View File

@ -11,7 +11,7 @@ class DistPlacementPattern(Enum):
class _DistSpec:
"""_DistSpec
A class indicates Distributed Specification.
The DistSpec is only works for the tensor parallel process groups.
Because the dist spec of data parallel process group can be automatically deduced.
@ -39,11 +39,12 @@ class _DistSpec:
return True
def __repr__(self) -> str:
res_list = ["DistSpec:"]
attr_list = []
for attr in dir(self):
if not attr.startswith('__'):
res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
return ''.join(res_list)
attr_list.append(f'{attr}={str(getattr(self, attr))}')
attr_str = ", ".join(attr_list)
return "DistSpec(" + attr_str + ")"
def ReplicaSpec() -> _DistSpec:

View File

@ -1,29 +1,36 @@
import torch
from typing import List, Optional
from colossalai.logging import get_dist_logger
import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger
class PyTorchProcessGroupDict(metaclass=SingletonMeta):
def __init__(self):
# distributed settings
# use this dict to record all Pytorch ProcessGroups
self.dict = {}
# set a distributed logger
self.logger = get_dist_logger('ProcessGroup')
def log_pg_init(self, rank_list: List[int], backend: str):
str_list = ["Pytorch ProcessGroup Init:"]
str_list.append(f"backend: {backend}")
str_list.append(f"ranks: {rank_list}")
self.logger.info("\n\t".join(str_list), ranks=[0])
def get(self, rank_list: List[int], backend: str = 'nccl'):
"""Reuse Pytorch ProcessGroup when such a group is initialized
"""
rank_tuple = tuple(rank_list)
# we need to convert the passed list to a tuple
# since List is unhashable
pg_key = (backend, rank_tuple)
if pg_key not in self.dict:
self.logger = get_dist_logger('ProcessGroup')
self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
return self.dict[pg_key]
processgroup_key = (backend, tuple(rank_list))
if processgroup_key not in self.dict:
self.log_pg_init(rank_list=rank_list, backend=backend)
self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
return self.dict[processgroup_key]
PYTORCHPGDICT_ = PyTorchProcessGroupDict()
@ -40,7 +47,7 @@ class ProcessGroup:
rank: the global rank of the current process.
ranks: List[int], a list of rank id belongings to this process group.
backend: str, the backend of the process group.
tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks).
"""
@ -54,10 +61,10 @@ class ProcessGroup:
return
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
if rank is None:
self._rank = torch.distributed.get_rank()
else:
self._rank = rank
self._rank = torch.distributed.get_rank()
if rank is not None:
assert self._rank == rank # make sure that the global rank is correct
if ranks is None:
self._rank_list = list(range(torch.distributed.get_world_size()))
@ -104,7 +111,7 @@ class ProcessGroup:
self.is_init = True
def set_cpu_groups(self):
"""set_cpu_groups
"""set_cpu_groups
Initialize Pytorch process groups for cpu communications.
"""
if self.has_cpu_groups:
@ -122,7 +129,7 @@ class ProcessGroup:
@property
def has_cpu_groups(self) -> bool:
"""has_cpu_groups
"""has_cpu_groups
If cpu groups have been initailized.
Returns:
@ -132,8 +139,9 @@ class ProcessGroup:
def __repr__(self):
if self.is_init:
return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
personal_str = f" rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
return ranks_str + personal_str
else:
return "ProcessGroup not initialized"
@ -155,7 +163,7 @@ class ProcessGroup:
return True
def rank(self) -> int:
"""rank
"""rank
The current rank in the global process group.
@ -165,9 +173,9 @@ class ProcessGroup:
return self._rank
def ranks_in_group(self) -> List[int]:
"""ranks_in_group
"""ranks_in_group
a list of rank number in in the global process group.
a list of rank number in in the global process group.
Returns:
List[int]: a list of rank number.
@ -177,7 +185,7 @@ class ProcessGroup:
def world_size(self) -> int:
"""world_size
The world size of the global process group.
The world size of the global process group.
Returns:
int: world size
@ -185,7 +193,7 @@ class ProcessGroup:
return self._world_size
def tp_rank_list(self) -> List[int]:
"""tp_rank_list
"""tp_rank_list
the rank list in the TP process group containing the current rank.
@ -195,7 +203,7 @@ class ProcessGroup:
return self._tp_rank_list
def dp_rank_list(self) -> List[int]:
"""dp_rank_list
"""dp_rank_list
the rank list in the DP process group containing the current rank.
@ -205,7 +213,7 @@ class ProcessGroup:
return self._dp_rank_list
def tp_local_rank(self) -> int:
"""tp_local_rank
"""tp_local_rank
The local rank number in the current TP process group.
@ -268,7 +276,7 @@ class ProcessGroup:
"""cpu_dp_process_group
the pytorch CPU DP process group containing the current rank.
assert failed if cpu process group is not initialized.
Returns:
@ -281,7 +289,7 @@ class ProcessGroup:
"""cpu_tp_process_group
the pytorch CPU TP process group containing the current rank.
assert failed if cpu process group is not initialized.
Returns:

View File

@ -37,12 +37,11 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
# param is the global tensor.
if param.device.type == "meta":
colo_param = ColoParameter(param, requires_grad=requires_grad)
else:
else:
colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
# if default_shard_plan exists, shard the param during initialization.
# This can reduce the model size after initialization.
@ -129,32 +128,29 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
delattr(submodule, param_name)
setattr(submodule, param_name, colo_param)
colo_param.shared_param_modules.append(submodule)
meta_param_flag = 0
meta_buffer_flag = 0
param_number = 0
meta_param_number = 0
buffer_number = 0
meta_buffer_number = 0
for param in module.parameters():
if param.device.type=="meta":
meta_param_flag = 1
if meta_param_flag == 1 and param.device.type!="meta":
raise ValueError("Meta parameters and valued parameters can not be in the same model")
param_number += 1
meta_param_number += (param.device.type == 'meta')
for buffer in module.buffers():
if buffer.device.type=="meta":
meta_buffer_flag = 1
if meta_buffer_flag == 1 and buffer.device.type!="meta":
raise ValueError("Meta buffers and valued buffers can not be in the same model")
if meta_param_flag==1 and meta_buffer_flag==1:
pass
elif meta_buffer_flag==0 and meta_param_flag==1:
for name, buf in module.named_buffers():
module._buffers[name] = module._buffers[name].to(device=self._device)
elif meta_param_flag==0 and meta_buffer_flag==1:
for name, param in module.named_parameters():
module._parameters[name] = module._parameters[name].to(device=self._device)
else:
module.to(self._device)
buffer_number += 1
meta_buffer_number += (buffer.device.type == 'meta')
if meta_param_number > 0 and meta_param_number != param_number:
raise ValueError("Meta parameters and valued parameters can not be in the same model")
if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
raise ValueError("Meta buffers and valued buffers can not be in the same model")
if meta_buffer_number == 0:
for buffer in module.buffers():
buffer.data = buffer.data.to(device=self._device)
def post_process_colo_init_ctx(model: torch.nn.Module,
device: torch.device = torch.device('cpu'),