[NFC] polish doc style for ColoTensor (#1457)

pull/1458/head^2
Jiarui Fang 2022-08-16 09:21:05 +08:00 committed by GitHub
parent 0dbd61c29b
commit a1476ea882
9 changed files with 197 additions and 48 deletions

View File

@ -9,6 +9,7 @@ def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
"""register_colo_graph
Register an Op (Layer) to ColoGraph.
Records the input args of ColoTensor type into the Graph.
Args:
func (Callable): a function that implements the Op.
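A minimal usage sketch of the decorator; the decorated function below is hypothetical and only illustrates how input_pos/param_pos refer to argument positions:
>>> # hypothetical op: the ColoTensor input sits at position 0, the parameter at position 1
>>> @register_colo_graph(input_pos=[0], param_pos=[1])
>>> def linear_op(activation, weight):
>>>     return activation @ weight.t()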

View File

@ -99,9 +99,11 @@ class CachedParamMgr(torch.nn.Module):
@torch.no_grad()
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
"""reorder the weight according to ids' frequency in dataset before training.
"""reorder
Reorder the weights according to the ids' frequency in the dataset before training.
Also build the IndexMappingTable, aka index_mapping_table.
Execute only once before training.
Args:
ids_freq_mapping (List[int]): a list whose index is the id number and whose value is that id's frequency. If None, no reordering is performed.
warmup_ratio (float): the fraction of chunks preloaded into the CUDA cache
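A hedged sketch of driving reorder from dataset statistics; mgr, train_ids and num_embeddings are assumed names, not part of this hunk:
>>> # counts[i] is how often embedding row i occurs in the training ids
>>> counts = torch.bincount(train_ids, minlength=num_embeddings)
>>> mgr.reorder(ids_freq_mapping=counts.tolist(), warmup_ratio=0.7)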

View File

@ -16,8 +16,9 @@ class LimitBuffIndexCopyer(object):
@torch.no_grad()
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
"""copy
src tensor[src_index] -(index_select)-> tmp -()-> tgt tensor [tgt_index]
The valid part in src is continous, while in tgt is scatter.
src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
The valid rows in the src tensor are contiguous, while the rows in the tgt tensor are scattered.
Args:
dim (int): dimension along which to index
src_index (LongTensor): indices of the src tensor to select from
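For reference, a minimal sketch of the same data movement written with plain PyTorch calls, without the buffer-size limit this copyer presumably enforces:
>>> # gather the contiguous rows of src, then scatter them into tgt
>>> tmp = src.index_select(dim, src_index)
>>> tgt.index_copy_(dim, tgt_index, tmp)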

View File

@ -57,6 +57,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
Called after initialization.
Reorder the weight rows according to the ids_freq_mapping.
Then, let the weights of the Module be managed by a CachedParamMgr.
Args:
cuda_row_num (int): the number of rows that can be hosted in CUDA memory
ids_freq_mapping (List[int]): a list whose index is the id number and whose value is that id's frequency

View File

@ -51,31 +51,31 @@ def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
The signature of the function has to be consistent with the __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways.
1. directly init.
The ColoTensor can be initialized with a PyTorch tensor in the following ways.
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec=ColoTensorSpec(pg, ReplicaSpec()))
>>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
>>> # The tensor passed in is the local shard after sharding, not the global tensor.
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
2. use static method from_torch_tensor
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
Args:
data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
"""
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
"""__new__
"""
The signature of __new__ has to be consistent with that of torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
spec (TensorSpec, optional): the tensor spec of initialization.
Returns:
ColoTensor: a ColoTensor wrapping the data.
"""
@ -115,12 +115,10 @@ class ColoTensor(torch.Tensor):
"""set_process_group
Change the pg of the ColoTensor. Note that the valid use cases are limited.
It is only valid when the existing pg is DP-only and the dist spec is REPLICATE.
Args:
pg (ProcessGroup): target pg
Raises:
RuntimeError:
"""
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
# if the new pg is the same as the old pg, just returns
@ -139,6 +137,7 @@ class ColoTensor(torch.Tensor):
def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec
Set the dist spec and change the payload accordingly.
Args:
dist_spec (_DistSpec): target dist spec.
"""
@ -182,6 +181,7 @@ class ColoTensor(torch.Tensor):
"""_redistribute
Note the function will not handle the logic of backward propagation!
It is used during model tensor initializations as an internal function.
Args:
dist_spec (_DistSpec): the target dist. spec.
"""
@ -193,12 +193,14 @@ class ColoTensor(torch.Tensor):
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
"""redistribute
Redistribute the tensor among processes. The rule is like this:
1. If the pg is None, then redistributed tensor payload among TP process group. Keep the
DP process group still as replicated.
2. If the pg is not not None and not equal to the cureent process group.
First, convert the tensor as replicated among TP process group.
Second, reset the process group.
Third, conver the tensor (new replicated both among tp and dp process group) to the new dist_spec.
1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
DP process group unchanged.
2. If the pg is not None and not equal to the current process group:
First, convert the tensor to be replicated among the TP process group.
Second, reset the process group to the new pg.
Third, convert the tensor (now replicated among the TP process group) to the new dist_spec.
Args:
dist_spec (_DistSpec): the new dist spec.
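A hedged sketch of the two cases above, reusing colo_t2, shard_spec and world_size from the ColoTensor class docstring:
>>> # case 1: pg is None, only the TP placement changes (DP group untouched)
>>> replicated = colo_t2.redistribute(ReplicaSpec())
>>> # case 2: hand over to a new process group, then apply the new dist spec there
>>> colo_t3 = colo_t2.redistribute(shard_spec, pg=ProcessGroup(tp=world_size))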
@ -219,18 +221,31 @@ class ColoTensor(torch.Tensor):
def to_replicate_(self):
"""to_replicate_
An in-place member function that converts the dist spec of the tensor to REPLICATE.
"""
self._redistribute(dist_spec=ReplicaSpec())
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting dist spec of the tensor to REPLICATE
Converts the dist spec of the tensor to ReplicaSpec().
"""
return self.redistribute(ReplicaSpec())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
"""from_torch_tensor
A static method that builds a `ColoTensor` from a PyTorch tensor.
Args:
tensor (torch.Tensor): the PyTorch tensor, which is a local tensor for this rank, not a global tensor.
spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.
Returns:
ColoTensor: a ColoTensor
"""
tensor = tensor.as_subclass(ColoTensor)
tensor.__init__(tensor, spec=spec)
return tensor
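A hedged sketch emphasizing that the payload is the local shard of this rank; full_weight is a hypothetical name, while pg, world_size and shard_spec mirror the class docstring example:
>>> # each rank wraps its own slice of the weight, not the full tensor
>>> local_shard = full_weight.chunk(world_size, dim=0)[pg.tp_local_rank()]
>>> colo_w = ColoTensor.from_torch_tensor(local_shard, ColoTensorSpec(pg, shard_spec))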
@ -252,10 +267,13 @@ class ColoTensor(torch.Tensor):
return super().size(*args)
def size_global(self, *args) -> torch.Size:
"""override the torch buildin size()
"""size_global
Override the torch builtin size().
The returned shape is the shape of the global tensor, i.e. the shape the tensor would have in a replicate placement.
Returns:
ColoTensor: a tensor after viewed.
torch.Size: the global tensor shape
"""
if self.is_replicate():
return self.size_local(*args)
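A hedged illustration, assuming a tensor sharded along dim 0 whose local shard on every rank is (2, 3):
>>> colo_t2.size_local()    # torch.Size([2, 3]) on each rank
>>> colo_t2.size_global()   # torch.Size([2 * world_size, 3]), the global shape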

View File

@ -11,6 +11,7 @@ class ComputePattern(Enum):
class ComputeSpec(object):
"""ComputeSpec
The specification of the computation pattern.
Args:
compute_pattern (ComputePattern): an Enum instance for compute pattern.
"""

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Optional
from typing import List
__all__ = ['replicate', 'shard']

View File

@ -30,10 +30,13 @@ PYTORCHPGDICT_ = PyTorchProcessGroupDict()
class ProcessGroup:
"""
"""ProcessGroup
A ProcessGroup contains the group partitions for Tensor Parallelism and Data Parallelism.
NOTE, the ProcessGroup must be used after torch.distributed.initialize()
args:
NOTE, the ProcessGroup must be used after `torch.distributed` has been initialized (e.g. via `torch.distributed.init_process_group()`)
Args:
rank: the global rank of the current process.
ranks: List[int], a list of rank ids belonging to this process group.
backend: str, the backend of the process group.
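A hedged construction sketch mirroring the usages in the ColoTensor docstring above (including the tp keyword shown there); torch.distributed is assumed to be initialized beforehand:
>>> pg = ProcessGroup()                   # all ranks; with no arguments this is presumably a pure DP group
>>> tp_pg = ProcessGroup(tp=world_size)   # same keyword as in the ColoTensor example
>>> pg.rank(), pg.world_size()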
@ -101,6 +104,9 @@ class ProcessGroup:
self.is_init = True
def set_cpu_groups(self):
"""set_cpu_groups
Initialize PyTorch process groups for CPU communications.
"""
if self.has_cpu_groups:
return
@ -115,7 +121,13 @@ class ProcessGroup:
self._has_cpu_groups = True
@property
def has_cpu_groups(self):
def has_cpu_groups(self) -> bool:
"""has_cpu_groups
Whether the CPU process groups have been initialized.
Returns:
bool: True if the CPU process groups have been initialized, False otherwise.
"""
return self._has_cpu_groups
def __repr__(self):
@ -142,51 +154,158 @@ class ProcessGroup:
return False
return True
def rank(self):
def rank(self) -> int:
"""rank
The current rank in the global process group.
Returns:
int: the rank number
"""
return self._rank
def ranks_in_group(self):
def ranks_in_group(self) -> List[int]:
"""ranks_in_group
A list of rank numbers in the global process group.
Returns:
List[int]: a list of rank numbers.
"""
return self._rank_list
def world_size(self):
def world_size(self) -> int:
"""world_size
The world size of the global process group.
Returns:
int: world size
"""
return self._world_size
def tp_rank_list(self):
def tp_rank_list(self) -> List[int]:
"""tp_rank_list
The list of ranks in the TP process group containing the current rank.
Returns:
List[int]: the list of rank numbers.
"""
return self._tp_rank_list
def dp_rank_list(self):
def dp_rank_list(self) -> List[int]:
"""dp_rank_list
The list of ranks in the DP process group containing the current rank.
Returns:
List[int]: the list of rank numbers.
"""
return self._dp_rank_list
def tp_local_rank(self):
def tp_local_rank(self) -> int:
"""tp_local_rank
The local rank number in the current TP process group.
Returns:
int: tp rank number.
"""
return self._rank % self._tp_degree
def dp_local_rank(self):
def dp_local_rank(self) -> int:
"""dp_local_rank
The local rank number in the current DP process group.
Returns:
int: dp rank number.
"""
return self._rank // self._tp_degree
def dp_world_size(self):
def dp_world_size(self) -> int:
"""dp_world_size
The world size of the current DP process group.
Returns:
int: dp world size
"""
return len(self._dp_rank_list)
def tp_world_size(self):
def tp_world_size(self) -> int:
"""tp_world_size
The world size of the current TP process group.
Returns:
int: tp world size
"""
return len(self._tp_rank_list)
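A small worked example of the index arithmetic above, assuming 8 ranks and a TP degree of 4 (the modulo / floor-division layout used by tp_local_rank and dp_local_rank):
>>> # global rank 6, tp_degree = 4
>>> 6 % 4     # tp_local_rank() -> 2
>>> 6 // 4    # dp_local_rank() -> 1
>>> # so rank 6 belongs to TP group [4, 5, 6, 7] and DP group [2, 6]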
def dp_process_group(self):
# return self._dp_process_group
"""dp_process_group
the pytorch DP process group containing the current rank.
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
"""
return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
def tp_process_group(self):
# return self._tp_process_group
"""tp_process_group
the pytorch TP process group containing the current rank.
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
"""
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
def cpu_dp_process_group(self):
"""cpu_dp_process_group
the pytorch CPU DP process group containing the current rank.
An assertion fails if the CPU process groups have not been initialized.
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
"""
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
def cpu_tp_process_group(self):
"""cpu_tp_process_group
the pytorch CPU TP process group containing the current rank.
An assertion fails if the CPU process groups have not been initialized.
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
"""
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
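A hedged sketch of feeding these handles to torch.distributed collectives; grad and cpu_tensor are hypothetical tensors:
>>> import torch.distributed as dist
>>> # all-reduce a gradient inside the current DP group (NCCL, GPU tensors)
>>> dist.all_reduce(grad, group=pg.dp_process_group())
>>> # CPU-side collective through the gloo groups; requires set_cpu_groups() first
>>> pg.set_cpu_groups()
>>> dist.all_reduce(cpu_tensor, group=pg.cpu_dp_process_group())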
def get_ranks_in_dp(self):
def get_ranks_in_dp(self) -> List[int]:
"""get_ranks_in_dp
Ranks in the current DP process group.
Returns:
List[int]: a list of rank numbers.
"""
return self._dp_rank_list
def get_ranks_in_tp(self):
"""get_ranks_in_tp
Ranks in the current TP process group.
Returns:
List[int]: a list of rank numbers.
"""
return self._tp_rank_list

View File

@ -7,6 +7,12 @@ from dataclasses import dataclass
@dataclass
class ColoTensorSpec:
""" ColoTensorSpec
A data class for specifications of the `ColoTensor`.
It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
The latter two attributes are optional; if not set, they default to `ReplicaSpec()` and `None`, respectively.
"""
pg: ProcessGroup
dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
compute_attr: Optional[ComputeSpec] = None
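A hedged construction sketch combining the attributes above; the ShardSpec arguments mirror the ColoTensor docstring and ComputePattern.TP1D is an assumed Enum member:
>>> pg = ProcessGroup()
>>> default_spec = ColoTensorSpec(pg)    # dist_attr defaults to replicate, compute_attr to None
>>> sharded_spec = ColoTensorSpec(pg,
>>>                               dist_attr=ShardSpec(dims=[0], num_partitions=[world_size]),
>>>                               compute_attr=ComputeSpec(ComputePattern.TP1D))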