mirror of https://github.com/hpcaitech/ColossalAI
[polish] polish ColoTensor and its submodules (#2537)
parent 51d4d6e718
commit 552183bb74
@@ -71,7 +71,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         return tensor
 
     def __repr__(self):
-        return f'ColoParameter: {ColoTensor.__repr__(self)}'
+        return super(ColoParameter, self).__repr__()
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
@@ -189,7 +189,12 @@ class ColoTensor(torch.Tensor):
         return _convert_output(ret, colo_spec)
 
     def __repr__(self):
-        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
+        output_list = [super(ColoTensor, self).__repr__()]
+        output_list.append(str(self.process_group))
+        output_list.append(str(self.dist_spec))
+        if self.compute_spec is not None:
+            output_list.append(str(self.compute_spec))
+        return "\n".join(output_list)
 
     def _redistribute(self, dist_spec: _DistSpec) -> None:
        """_redistribute
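The rewrite above builds the ColoTensor string piece by piece and only appends the compute spec when one is set, instead of dumping every field through a single f-string. A minimal standalone sketch of that pattern (ReprSketch and its fields are invented for illustration, they are not ColossalAI classes):

class ReprSketch:

    def __init__(self, base_repr, process_group, dist_spec, compute_spec=None):
        self.base_repr = base_repr
        self.process_group = process_group
        self.dist_spec = dist_spec
        self.compute_spec = compute_spec

    def __repr__(self):
        output_list = [self.base_repr]
        output_list.append(str(self.process_group))
        output_list.append(str(self.dist_spec))
        # an unset compute spec no longer shows up as a literal "None" line
        if self.compute_spec is not None:
            output_list.append(str(self.compute_spec))
        return "\n".join(output_list)

print(ReprSketch("tensor([...])", "ProcessGroup(ranks=[0, 1])", "DistSpec(placement=REPLICATE)"))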
@@ -9,9 +9,9 @@ class ComputePattern(Enum):
 
 
 class ComputeSpec(object):
-    """ComputeSpec
+    """ComputeSpec
     The Specification for compuattion pattern
 
     Args:
         compute_pattern (ComputePattern): an Enum instance for compute pattern.
     """
@@ -23,7 +23,7 @@ class ComputeSpec(object):
         self.output_replicate = True
 
     def __repr__(self):
-        return f'Compute pattern: {self.compute_pattern}'
+        return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'
 
     def set_output_replicate(self, flag: bool = True):
         self.output_replicate = flag
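The effect of the new ComputeSpec.__repr__ is easiest to see on the printed form; a tiny reproduction with a local stand-in enum (this ComputePattern is redefined here for the example, not imported from ColossalAI):

from enum import Enum

class ComputePattern(Enum):
    TP1D = 0    # stand-in member for the sketch

class ComputeSpecSketch:

    def __init__(self, compute_pattern):
        self.compute_pattern = compute_pattern
        self.output_replicate = True

    def __repr__(self):
        return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'

print(ComputeSpecSketch(ComputePattern.TP1D))
# before: Compute pattern: ComputePattern.TP1D
# after:  ComputeSpec(pattern=ComputePattern.TP1D, replicate_output=True)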
@@ -11,7 +11,7 @@ class DistPlacementPattern(Enum):
 
 class _DistSpec:
     """_DistSpec
-
+
     A class indicates Distributed Specification.
     The DistSpec is only works for the tensor parallel process groups.
     Because the dist spec of data parallel process group can be automatically deduced.
@@ -39,11 +39,12 @@ class _DistSpec:
         return True
 
     def __repr__(self) -> str:
-        res_list = ["DistSpec:"]
+        attr_list = []
         for attr in dir(self):
             if not attr.startswith('__'):
-                res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
-        return ''.join(res_list)
+                attr_list.append(f'{attr}={str(getattr(self, attr))}')
+        attr_str = ", ".join(attr_list)
+        return "DistSpec(" + attr_str + ")"
 
 
 def ReplicaSpec() -> _DistSpec:
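The polished _DistSpec.__repr__ collects every public attribute into a one-line DistSpec(key=value, ...) string instead of the old tab-indented dump. A self-contained sketch of the same dir()-based pattern (SpecSketch and its two fields are placeholders, not the ColossalAI spec class):

class SpecSketch:

    def __init__(self, placement, num_partitions):
        self.placement = placement
        self.num_partitions = num_partitions

    def __repr__(self) -> str:
        attr_list = []
        # dir() yields attribute names sorted alphabetically; skipping dunders leaves the spec fields
        for attr in dir(self):
            if not attr.startswith('__'):
                attr_list.append(f'{attr}={str(getattr(self, attr))}')
        attr_str = ", ".join(attr_list)
        return "DistSpec(" + attr_str + ")"

print(SpecSketch('SHARD', 4))
# DistSpec(num_partitions=4, placement=SHARD)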
@@ -1,29 +1,36 @@
-import torch
 from typing import List, Optional
-from colossalai.logging import get_dist_logger
+
+import torch
+
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.logging import get_dist_logger
 
 
 class PyTorchProcessGroupDict(metaclass=SingletonMeta):
 
     def __init__(self):
         # distributed settings
+        # use this dict to record all Pytorch ProcessGroups
         self.dict = {}
+        # set a distributed logger
+        self.logger = get_dist_logger('ProcessGroup')
+
+    def log_pg_init(self, rank_list: List[int], backend: str):
+        str_list = ["Pytorch ProcessGroup Init:"]
+        str_list.append(f"backend: {backend}")
+        str_list.append(f"ranks: {rank_list}")
+        self.logger.info("\n\t".join(str_list), ranks=[0])
 
     def get(self, rank_list: List[int], backend: str = 'nccl'):
         """Reuse Pytorch ProcessGroup when such a group is initialized
         """
-        rank_tuple = tuple(rank_list)
-        # we need to convert the passed list to a tuple
-        # since List is unhashable
-        pg_key = (backend, rank_tuple)
-
-        if pg_key not in self.dict:
-
-            self.logger = get_dist_logger('ProcessGroup')
-            self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
-            self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
-        return self.dict[pg_key]
+        processgroup_key = (backend, tuple(rank_list))
+        if processgroup_key not in self.dict:
+            self.log_pg_init(rank_list=rank_list, backend=backend)
+            self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
+        return self.dict[processgroup_key]
 
 
 PYTORCHPGDICT_ = PyTorchProcessGroupDict()
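PyTorchProcessGroupDict.get() above memoizes process groups keyed by (backend, tuple(rank_list)), so asking twice for the same rank list reuses one torch ProcessGroup instead of creating a duplicate. A rough standalone sketch of that caching idea with group creation stubbed out, so it runs without torch.distributed being initialized (_FakeGroup and ProcessGroupCache are invented names, not ColossalAI API):

from typing import List

class _FakeGroup:
    """Stub for the handle torch.distributed.new_group() would normally return."""

    def __init__(self, ranks, backend):
        self.ranks, self.backend = ranks, backend

class ProcessGroupCache:

    def __init__(self):
        self._groups = {}

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        # a list is unhashable, so the cache key is the backend plus a tuple of ranks
        key = (backend, tuple(rank_list))
        if key not in self._groups:
            self._groups[key] = _FakeGroup(list(rank_list), backend)
        return self._groups[key]

cache = ProcessGroupCache()
assert cache.get([0, 1]) is cache.get([0, 1])                        # same ranks -> cached group reused
assert cache.get([0, 1]) is not cache.get([0, 1], backend='gloo')    # different backend -> new group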
@@ -40,7 +47,7 @@ class ProcessGroup:
         rank: the global rank of the current process.
         ranks: List[int], a list of rank id belongings to this process group.
         backend: str, the backend of the process group.
-        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
+        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
         dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks).
     """
 
@@ -54,10 +61,10 @@ class ProcessGroup:
             return
 
         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
-        if rank is None:
-            self._rank = torch.distributed.get_rank()
-        else:
-            self._rank = rank
+
+        self._rank = torch.distributed.get_rank()
+        if rank is not None:
+            assert self._rank == rank    # make sure that the global rank is correct
 
         if ranks is None:
             self._rank_list = list(range(torch.distributed.get_world_size()))
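With the change above the constructor always adopts the rank reported by torch.distributed and treats a user-supplied rank purely as a consistency check. A condensed sketch of that logic as a free function (check_global_rank is a made-up helper, not ColossalAI API; it must run after torch.distributed.init_process_group):

import torch

def check_global_rank(user_rank=None) -> int:
    # the authoritative rank always comes from the initialized process group
    assert torch.distributed.is_initialized(), "ProcessGroup must be used after distributed initialized"
    rank = torch.distributed.get_rank()
    if user_rank is not None:
        # a caller-provided rank is validated, never trusted over get_rank()
        assert rank == user_rank, f"expected rank {user_rank}, got {rank} from torch.distributed"
    return rank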
@@ -104,7 +111,7 @@ class ProcessGroup:
         self.is_init = True
 
     def set_cpu_groups(self):
-        """set_cpu_groups
+        """set_cpu_groups
         Initialize Pytorch process groups for cpu communications.
         """
         if self.has_cpu_groups:
@@ -122,7 +129,7 @@ class ProcessGroup:
 
     @property
     def has_cpu_groups(self) -> bool:
-        """has_cpu_groups
+        """has_cpu_groups
         If cpu groups have been initailized.
 
         Returns:
@@ -132,8 +139,9 @@ class ProcessGroup:
 
     def __repr__(self):
         if self.is_init:
-            return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
-                format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
+            ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
+            personal_str = f" rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
+            return ranks_str + personal_str
         else:
             return "ProcessGroup not initialized"
 
@@ -155,7 +163,7 @@ class ProcessGroup:
         return True
 
     def rank(self) -> int:
-        """rank
+        """rank
 
         The current rank in the global process group.
 
@@ -165,9 +173,9 @@ class ProcessGroup:
         return self._rank
 
     def ranks_in_group(self) -> List[int]:
-        """ranks_in_group
+        """ranks_in_group
 
-        a list of rank number in in the global process group.
+        a list of rank number in in the global process group.
 
         Returns:
             List[int]: a list of rank number.
@@ -177,7 +185,7 @@ class ProcessGroup:
     def world_size(self) -> int:
         """world_size
 
-        The world size of the global process group.
+        The world size of the global process group.
 
         Returns:
             int: world size
@@ -185,7 +193,7 @@ class ProcessGroup:
         return self._world_size
 
     def tp_rank_list(self) -> List[int]:
-        """tp_rank_list
+        """tp_rank_list
 
         the rank list in the TP process group containing the current rank.
 
@@ -195,7 +203,7 @@ class ProcessGroup:
         return self._tp_rank_list
 
     def dp_rank_list(self) -> List[int]:
-        """dp_rank_list
+        """dp_rank_list
 
         the rank list in the DP process group containing the current rank.
 
@@ -205,7 +213,7 @@ class ProcessGroup:
         return self._dp_rank_list
 
     def tp_local_rank(self) -> int:
-        """tp_local_rank
+        """tp_local_rank
 
         The local rank number in the current TP process group.
 
@@ -268,7 +276,7 @@ class ProcessGroup:
         """cpu_dp_process_group
 
         the pytorch CPU DP process group containing the current rank.
-
+
         assert failed if cpu process group is not initialized.
 
         Returns:
@@ -281,7 +289,7 @@ class ProcessGroup:
         """cpu_tp_process_group
 
         the pytorch CPU TP process group containing the current rank.
-
+
         assert failed if cpu process group is not initialized.
 
         Returns:
@@ -37,12 +37,11 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
     # param is the global tensor.
-
 
     if param.device.type == "meta":
         colo_param = ColoParameter(param, requires_grad=requires_grad)
-    else:
+    else:
         colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
 
     # if default_shard_plan exists, shard the param during initialization.
     # This can reduce the model size after initialization.
@@ -129,32 +128,29 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 delattr(submodule, param_name)
                 setattr(submodule, param_name, colo_param)
                 colo_param.shared_param_modules.append(submodule)
 
-        meta_param_flag = 0
-        meta_buffer_flag = 0
+        param_number = 0
+        meta_param_number = 0
+        buffer_number = 0
+        meta_buffer_number = 0
+
         for param in module.parameters():
-            if param.device.type=="meta":
-                meta_param_flag = 1
-            if meta_param_flag == 1 and param.device.type!="meta":
-                raise ValueError("Meta parameters and valued parameters can not be in the same model")
+            param_number += 1
+            meta_param_number += (param.device.type == 'meta')
 
         for buffer in module.buffers():
-            if buffer.device.type=="meta":
-                meta_buffer_flag = 1
-            if meta_buffer_flag == 1 and buffer.device.type!="meta":
-                raise ValueError("Meta buffers and valued buffers can not be in the same model")
-
-        if meta_param_flag==1 and meta_buffer_flag==1:
-            pass
-        elif meta_buffer_flag==0 and meta_param_flag==1:
-            for name, buf in module.named_buffers():
-                module._buffers[name] = module._buffers[name].to(device=self._device)
-        elif meta_param_flag==0 and meta_buffer_flag==1:
-            for name, param in module.named_parameters():
-                module._parameters[name] = module._parameters[name].to(device=self._device)
-        else:
-            module.to(self._device)
+            buffer_number += 1
+            meta_buffer_number += (buffer.device.type == 'meta')
+
+        if meta_param_number > 0 and meta_param_number != param_number:
+            raise ValueError("Meta parameters and valued parameters can not be in the same model")
+        if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
+            raise ValueError("Meta buffers and valued buffers can not be in the same model")
+
+        if meta_buffer_number == 0:
+            for buffer in module.buffers():
+                buffer.data = buffer.data.to(device=self._device)
 
 
 def post_process_colo_init_ctx(model: torch.nn.Module,
                                device: torch.device = torch.device('cpu'),
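The rewritten block above swaps the flag-and-branch logic for simple counters: a model must be either fully materialized or fully on the meta device, and buffers are only moved to the context device when none of them are meta. A self-contained sketch of the counting check on a plain torch module (validate_meta_consistency is an invented helper, not part of ColossalAI):

import torch.nn as nn

def validate_meta_consistency(module: nn.Module) -> None:
    # count parameters/buffers and how many of them still live on the meta device
    param_number = meta_param_number = 0
    buffer_number = meta_buffer_number = 0
    for param in module.parameters():
        param_number += 1
        meta_param_number += (param.device.type == 'meta')
    for buffer in module.buffers():
        buffer_number += 1
        meta_buffer_number += (buffer.device.type == 'meta')

    # reject a mix of meta and materialized tensors, mirroring the checks in the diff
    if meta_param_number > 0 and meta_param_number != param_number:
        raise ValueError("Meta parameters and valued parameters can not be in the same model")
    if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
        raise ValueError("Meta buffers and valued buffers can not be in the same model")

validate_meta_consistency(nn.Linear(4, 4))                   # fully materialized: passes
validate_meta_consistency(nn.Linear(4, 4, device='meta'))    # fully meta: passes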