mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish doc style for ColoTensor (#1457)
parent 0dbd61c29b
commit a1476ea882
@@ -9,6 +9,7 @@ def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
    """register_colo_graph
    Register an Op (Layer) to ColoGraph.
    Records the input args of ColoTensor type to the Graph.

    Args:
        func (Callable): a function implementing the Op.
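Below is a minimal usage sketch for the decorator documented above. It is not from the commit; the import path and the wrapped function are assumptions, and only the register_colo_graph(input_pos, param_pos) signature comes from the hunk header.

import torch.nn.functional as F
from colossalai.tensor.graph import register_colo_graph   # assumed import path

# input_pos marks the positions of activation args, param_pos the positions of parameters,
# so the ColoTensor inputs of this Op can be recorded on the graph.
@register_colo_graph(input_pos=[0], param_pos=[1, 2])
def colo_linear_op(input_tensor, weight, bias=None):      # hypothetical Op implementation
    return F.linear(input_tensor, weight, bias)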
@@ -99,9 +99,11 @@ class CachedParamMgr(torch.nn.Module):

    @torch.no_grad()
    def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
        """reorder the weight according to ids' frequency in dataset before training.
        """reorder
        reorder the weight according to ids' frequency in dataset before training.
        Also build the IndexMappingTable, aka index_mapping_table.
        Execute only once before training.

        Args:
            ids_freq_mapping (List[int]): a list whose index is the id and whose value is its frequency; if None, no reorder is performed.
            warmup_ratio (float): the ratio of chunks preloaded into the CUDA cache.
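A short sketch of how ids_freq_mapping can be produced (not from the commit). The frequency counting is plain PyTorch; the CachedParamMgr construction is commented out because its constructor arguments and import path are assumptions, and only reorder(ids_freq_mapping, warmup_ratio) comes from the hunk.

import torch

num_embeddings = 1000
train_ids = torch.randint(0, num_embeddings, (100_000,))

# index is the embedding id, value is how often it occurs in the dataset
ids_freq_mapping = torch.bincount(train_ids, minlength=num_embeddings).tolist()

# mgr = CachedParamMgr(weight=torch.randn(num_embeddings, 128), cuda_row_num=256)  # hypothetical args
# mgr.reorder(ids_freq_mapping, warmup_ratio=0.7)   # run exactly once, before training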
@@ -16,8 +16,9 @@ class LimitBuffIndexCopyer(object):
    @torch.no_grad()
    def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
        """copy
        src tensor[src_index] -(index_select)-> tmp -()-> tgt tensor [tgt_index]
        The valid part in src is continous, while in tgt is scatter.
        src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
        The valid rows in the src tensor are contiguous, while the rows in the tgt tensor are scattered.

        Args:
            dim (int): dimension along which to index
            src_index (LongTensor): indices of the src tensor to select from
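The data flow described in the docstring can be reproduced with plain PyTorch as below. This is only an illustration of index_select followed by index_copy_, not the LimitBuffIndexCopyer implementation, which additionally caps the size of the temporary buffer.

import torch

src = torch.arange(12, dtype=torch.float32).reshape(6, 2)
tgt = torch.zeros(8, 2)

src_index = torch.tensor([0, 1, 2])    # contiguous rows of src
tgt_index = torch.tensor([5, 1, 7])    # scattered rows of tgt

tmp = src.index_select(0, src_index)   # gather the selected src rows into a buffer
tgt.index_copy_(0, tgt_index, tmp)     # scatter the buffer into the tgt rows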
@@ -57,6 +57,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
        Called after initialization.
        Reorder the weight rows according to the ids_freq_mapping.
        Then, let the weights of the Module be managed by a CachedParamMgr.

        Args:
            cuda_row_num (int): number of rows that can be hosted in CUDA memory
            ids_freq_mapping (List[int]): a list whose index is the id and whose value is its frequency
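A sketch of the workflow this docstring describes, reusing the frequency counting from the earlier CachedParamMgr sketch. Nothing here is from the commit: the import path, the constructor arguments, and the method name preprocess are hypothetical, inferred only from the documented parameters.

import torch
# from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag   # assumed import path

num_embeddings, embedding_dim = 10_000, 64
train_ids = torch.randint(0, num_embeddings, (1_000_000,))
ids_freq_mapping = torch.bincount(train_ids, minlength=num_embeddings).tolist()

# bag = FreqAwareEmbeddingBag(num_embeddings, embedding_dim)
# bag.preprocess(cuda_row_num=2048, ids_freq_mapping=ids_freq_mapping)   # hypothetical method name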
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
        Tensor._base.__get__,
        Tensor.grad.__get__,
        Tensor._grad.__get__,
        Tensor.data.__get__,  # make .data returns torch.Tensor rather than ColoTensor
        Tensor.data.__get__,  # make .data return torch.Tensor rather than ColoTensor
    }
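The purpose of this no-wrap set can be illustrated with a self-contained torch.Tensor subclass. This is only a sketch of the pattern, not ColoTensor's actual __torch_function__ implementation: results are normally wrapped back into the subclass, and exempting Tensor.data.__get__ keeps .data a plain torch.Tensor.

import torch
from torch import Tensor

_NOWRAP = {Tensor.data.__get__}    # results of these calls are not wrapped back into the subclass

class MyTensor(Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _NOWRAP:
            # run without interception, so .data comes back as a plain torch.Tensor
            with torch._C.DisableTorchFunction():
                return func(*args, **kwargs)
        # default behaviour: tensor results are wrapped back into MyTensor
        return super().__torch_function__(func, types, args, kwargs)

t = torch.randn(2, 3).as_subclass(MyTensor)
print(type(t + 1).__name__)     # MyTensor
print(type(t.data).__name__)    # Tensor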
@@ -51,31 +51,31 @@ def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:


class ColoTensor(torch.Tensor):
    """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.

    The ColoTensor can be initialized with a PyTorch tensor in the following ways.

    >>> pg = ProcessGroup()
    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
    >>> # The tensor passed in is a tensor after sharding but not a global tensor.
    >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
    >>>                 dims=[0],
    >>>                 num_partitions=[world_size])
    >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)

    Args:
        data (torch.Tensor): a torch tensor used as the payload of the colotensor.
        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).

    The signature of the function has to be consistent with the __new__ except for the 1st arg.
    The class should be initialized with a torch tensor in the following ways.
    1. directly init.
    >>> pg = ProcessGroup()
    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
    >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
    >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
    >>>                 dims=[0],
    >>>                 num_partitions=[world_size])
    >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
    2. use static method from_torch_tensor
    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
    """

    def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
        """__new__
        """
        The signature of the __new__ has to be consistent with the torch.Tensor.

        Args:
            data (torch.Tensor): a torch tensor used as the payload of the colotensor.
            spec (TensorSpec, optional): the tensor spec of initialization.

        Returns:
            ColoTensor: a ColoTensor wrapping the data.
        """
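The two construction paths in the docstring above can be consolidated into the sketch below. It is not from the commit, assumes an already launched torch.distributed/colossalai environment and that the symbols are importable from colossalai.tensor, and fixes the docstring's missing closing parentheses; the exact ShardSpec/ProcessGroup keyword names may differ from the real API.

import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec

world_size = dist.get_world_size()

# 1. direct init with a replicated placement
pg = ProcessGroup()
colo_t1 = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, ReplicaSpec()))

# 2. wrap an existing local shard; the payload is this rank's shard, not the global tensor
shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
tensor_spec = ColoTensorSpec(ProcessGroup(tp_degree=world_size), shard_spec)
local_shard = torch.randn(4, 3)      # the global shape would be (4 * world_size, 3)
colo_t2 = ColoTensor.from_torch_tensor(local_shard, tensor_spec)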
@@ -115,12 +115,10 @@ class ColoTensor(torch.Tensor):
        """set_process_group
        change the pg of the ColoTensor. Note that the valid use cases are limited:
        only an existing DP pg with a REPLICATE dist spec is valid.

        Args:
            pg (ProcessGroup): target pg

        Raises:
            RuntimeError:
            RuntimeError:
        """
        assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
        # if the new pg is the same as the old pg, just return
@@ -139,6 +137,7 @@ class ColoTensor(torch.Tensor):
    def set_dist_spec(self, dist_spec: _DistSpec):
        """set_dist_spec
        set the dist spec and change the payload accordingly.

        Args:
            dist_spec (_DistSpec): target dist spec.
        """
@@ -182,6 +181,7 @@ class ColoTensor(torch.Tensor):
        """_redistribute
        Note that the function will not handle the logic of backward propagation!
        It is used during model tensor initializations as an internal function.

        Args:
            dist_spec (_DistSpec): the target dist. spec.
        """
@@ -193,12 +193,14 @@ class ColoTensor(torch.Tensor):
    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
        """redistribute
        Redistribute the tensor among processes. The rule is like this:
        1. If the pg is None, then redistributed tensor payload among TP process group. Keep the
        DP process group still as replicated.
        2. If the pg is not not None and not equal to the cureent process group.
        First, convert the tensor as replicated among TP process group.
        Second, reset the process group.
        Third, conver the tensor (new replicated both among tp and dp process group) to the new dist_spec.

        1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
        DP process group unchanged.

        2. If the pg is not None and not equal to the current process group:
        First, convert the tensor to replicated among the TP process group.
        Second, reset the process group to the new pg.
        Third, convert the tensor (now replicated among the TP process group) to the new dist_spec.

        Args:
            dist_spec (_DistSpec): the new dist spec.
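A sketch of rule 1 above (not from the commit), under the same colossalai.tensor import assumptions as earlier: a replicated tensor is redistributed to a dim-0 shard over the TP process group while the DP group stays replicated.

import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec

tp_size = dist.get_world_size()
pg = ProcessGroup(tp_degree=tp_size)

t = ColoTensor(torch.randn(4 * tp_size, 8), spec=ColoTensorSpec(pg, ReplicaSpec()))

# pg is None here, so only the dist spec changes: shard along dim 0 across the TP group
t_sharded = t.redistribute(ShardSpec(dims=[0], num_partitions=[tp_size]))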
@@ -219,18 +221,31 @@ class ColoTensor(torch.Tensor):

    def to_replicate_(self):
        """to_replicate_

        an in-place member function, converting the dist spec of the tensor to REPLICATE
        """
        self._redistribute(dist_spec=ReplicaSpec())

    def to_replicate(self) -> 'ColoTensor':
        """to_replicate
        converting dist spec of the tensor to REPLICATE

        converting the dist spec of the tensor to ReplicaSpec()
        """
        return self.redistribute(ReplicaSpec())

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
        """from_torch_tensor

        A static method that builds a `ColoTensor` from a PyTorch Tensor.

        Args:
            tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank, not a global tensor.
            spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.

        Returns:
            ColoTensor: a ColoTensor
        """
        tensor = tensor.as_subclass(ColoTensor)
        tensor.__init__(tensor, spec=spec)
        return tensor
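A brief sketch contrasting the two conversions above, under the same assumptions as the earlier ColoTensor sketches (not from the commit): to_replicate() returns a new replicated tensor through redistribute, while to_replicate_() rewrites the dist spec of the tensor in place.

import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec

pg = ProcessGroup(tp_degree=dist.get_world_size())
spec = ColoTensorSpec(pg, ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()]))
t = ColoTensor.from_torch_tensor(torch.randn(2, 8), spec)    # payload is the local shard

replicated = t.to_replicate()    # new ColoTensor gathered to the replicated layout
t.to_replicate_()                # converts t itself in place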
@@ -252,10 +267,13 @@ class ColoTensor(torch.Tensor):
        return super().size(*args)

    def size_global(self, *args) -> torch.Size:
        """override the torch buildin size()
        """size_global

        override the torch builtin size()
        the shape passed in must be in a replicate placement.

        Returns:
            ColoTensor: a tensor after viewed.
            torch.Size: the global tensor shape
        """
        if self.is_replicate():
            return self.size_local(*args)
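A sketch of the local/global distinction (not from the commit, same assumptions as above): a dim-0 shard reports its own rows through size_local() and the full row count through size_global().

import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec

world_size = dist.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
spec = ColoTensorSpec(pg, ShardSpec(dims=[0], num_partitions=[world_size]))

t = ColoTensor.from_torch_tensor(torch.randn(4, 8), spec)   # each rank holds 4 rows

print(t.size_local(0))    # 4
print(t.size_global(0))   # 4 * world_size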
@@ -11,6 +11,7 @@ class ComputePattern(Enum):
class ComputeSpec(object):
    """ComputeSpec
    The specification for the computation pattern.

    Args:
        compute_pattern (ComputePattern): an Enum instance for compute pattern.
    """
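A one-line construction sketch (not from the commit). ComputePattern.TP1D is the member name used elsewhere in Colossal-AI for 1-D tensor parallelism, but treat the import path and member name as assumptions.

from colossalai.tensor import ComputePattern, ComputeSpec

# mark a parameter as taking part in 1-D tensor-parallel computation
compute_spec = ComputeSpec(ComputePattern.TP1D)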
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Optional
from typing import List

__all__ = ['replicate', 'shard']
@@ -30,10 +30,13 @@ PYTORCHPGDICT_ = PyTorchProcessGroupDict()


class ProcessGroup:
    """
    """ProcessGroup
    Process Group contains group partition for Tensor Parallel and Data Parallel.
    NOTE, the ProcessGroup must be used after torch.distributed.initialize()
    args:

    NOTE, the ProcessGroup must be used after `torch.distributed.initialize()`


    Args:
        rank: the global rank of the current process.
        ranks: List[int], a list of rank ids belonging to this process group.
        backend: str, the backend of the process group.
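A construction sketch (not from the commit). It assumes torch.distributed has been initialized (e.g. by torchrun plus init_process_group) and that ProcessGroup accepts tp_degree/dp_degree keywords, which is inferred from the TP/DP accessors later in this diff rather than stated in this hunk.

import torch.distributed as dist
from colossalai.tensor import ProcessGroup

assert dist.is_initialized()
world_size = dist.get_world_size()

# split the world into a 2-way tensor-parallel x (world_size // 2)-way data-parallel mesh
pg = ProcessGroup(tp_degree=2, dp_degree=world_size // 2)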
@@ -101,6 +104,9 @@ class ProcessGroup:
        self.is_init = True

    def set_cpu_groups(self):
        """set_cpu_groups
        Initialize PyTorch process groups for CPU communications.
        """
        if self.has_cpu_groups:
            return
@@ -115,7 +121,13 @@ class ProcessGroup:
        self._has_cpu_groups = True

    @property
    def has_cpu_groups(self):
    def has_cpu_groups(self) -> bool:
        """has_cpu_groups
        Whether the CPU groups have been initialized.

        Returns:
            bool: CPU process groups have been initialized or not.
        """
        return self._has_cpu_groups

    def __repr__(self):
@@ -142,51 +154,158 @@ class ProcessGroup:
            return False
        return True

    def rank(self):
    def rank(self) -> int:
        """rank

        The current rank in the global process group.

        Returns:
            int: the rank number
        """
        return self._rank

    def ranks_in_group(self):
    def ranks_in_group(self) -> List[int]:
        """ranks_in_group

        a list of rank numbers in the global process group.

        Returns:
            List[int]: a list of rank numbers.
        """
        return self._rank_list

    def world_size(self):
    def world_size(self) -> int:
        """world_size

        The world size of the global process group.

        Returns:
            int: world size
        """
        return self._world_size

    def tp_rank_list(self):
    def tp_rank_list(self) -> List[int]:
        """tp_rank_list

        the rank list in the TP process group containing the current rank.

        Returns:
            List[int]: the list of rank numbers.
        """
        return self._tp_rank_list

    def dp_rank_list(self):
    def dp_rank_list(self) -> List[int]:
        """dp_rank_list

        the rank list in the DP process group containing the current rank.

        Returns:
            List[int]: the list of rank numbers.
        """
        return self._dp_rank_list

    def tp_local_rank(self):
    def tp_local_rank(self) -> int:
        """tp_local_rank

        The local rank number in the current TP process group.

        Returns:
            int: tp rank number.
        """
        return self._rank % self._tp_degree

    def dp_local_rank(self):
    def dp_local_rank(self) -> int:
        """dp_local_rank

        The local rank number in the current DP process group.

        Returns:
            int: dp rank number.
        """
        return self._rank // self._tp_degree

    def dp_world_size(self):
    def dp_world_size(self) -> int:
        """dp_world_size

        The world size of the current DP process group.

        Returns:
            int: dp world size
        """
        return len(self._dp_rank_list)

    def tp_world_size(self):
    def tp_world_size(self) -> int:
        """tp_world_size

        The world size of the current TP process group.

        Returns:
            int: tp world size
        """
        return len(self._tp_rank_list)

    def dp_process_group(self):
        # return self._dp_process_group
        """dp_process_group

        the PyTorch DP process group containing the current rank.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the PyTorch DP process group.
        """
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')

    def tp_process_group(self):
        # return self._tp_process_group
        """tp_process_group

        the PyTorch TP process group containing the current rank.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the PyTorch TP process group.
        """
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

    def cpu_dp_process_group(self):
        """cpu_dp_process_group

        the PyTorch CPU DP process group containing the current rank.

        assert fails if the cpu process groups are not initialized.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the PyTorch DP process group.
        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

    def cpu_tp_process_group(self):
        """cpu_tp_process_group

        the PyTorch CPU TP process group containing the current rank.

        assert fails if the cpu process groups are not initialized.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the PyTorch TP process group.
        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')

    def get_ranks_in_dp(self):
    def get_ranks_in_dp(self) -> List[int]:
        """get_ranks_in_dp

        ranks in the current dp process group.

        Returns:
            List[int]: a list of rank numbers.
        """
        return self._dp_rank_list

    def get_ranks_in_tp(self):
        """get_ranks_in_tp

        ranks in the current tp process group.

        Returns:
            List[int]: a list of rank numbers.
        """
        return self._tp_rank_list
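A sketch of querying the accessors documented above on each rank (not from the commit), reusing the tp_degree/dp_degree assumption from the earlier ProcessGroup sketch.

import torch
import torch.distributed as dist
from colossalai.tensor import ProcessGroup

pg = ProcessGroup(tp_degree=2, dp_degree=dist.get_world_size() // 2)

print(pg.rank(), pg.world_size())               # global rank and global world size
print(pg.tp_local_rank(), pg.tp_world_size())   # position and size inside this rank's TP group
print(pg.dp_local_rank(), pg.dp_world_size())   # position and size inside this rank's DP group
print(pg.get_ranks_in_tp(), pg.get_ranks_in_dp())

# the returned pytorch groups can be passed straight to torch.distributed collectives
dist.all_reduce(torch.ones(1).cuda(), group=pg.tp_process_group())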
@@ -7,6 +7,12 @@ from dataclasses import dataclass

@dataclass
class ColoTensorSpec:
    """ ColoTensorSpec

    A data class for specifications of the `ColoTensor`.
    It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
    The latter two attributes are optional; if not set, they default to `ReplicaSpec()` and `None` respectively.
    """
    pg: ProcessGroup
    dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
    compute_attr: Optional[ComputeSpec] = None
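A final sketch combining the three fields (not from the commit), with the same import-path assumptions as the earlier examples.

import torch.distributed as dist
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec

pg = ProcessGroup(tp_degree=dist.get_world_size())

# only pg is required; dist_attr defaults to a replicated placement and compute_attr to None
replicated_spec = ColoTensorSpec(pg)

# fully specified: shard the last dim over the TP group and mark the compute pattern
sharded_spec = ColoTensorSpec(pg,
                              dist_attr=ShardSpec(dims=[-1], num_partitions=[pg.tp_world_size()]),
                              compute_attr=ComputeSpec(ComputePattern.TP1D))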