mirror of https://github.com/hpcaitech/ColossalAI

[NFC] polish doc style for ColoTensor (#1457)

parent 0dbd61c29b
commit a1476ea882
@@ -9,6 +9,7 @@ def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
    """register_colo_graph
    Register an Op (Layer) to ColoGraph.
    Records the input args of ColoTensor type to the Graph.

    Args:
        func (Callable): a function implements the Op.

@@ -99,9 +99,11 @@ class CachedParamMgr(torch.nn.Module):

    @torch.no_grad()
    def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
-        """reorder the weight according to ids' frequency in dataset before training.
+        """reorder
+        reorder the weight according to the ids' frequency in the dataset before training.
        Also build the IndexMappingTable, aka index_mapping_table.
        Execute only once before training.

        Args:
            ids_freq_mapping (List[int]): a list in which the index is the id and the value is its frequency; if None, no reorder is performed
            warmup_ratio (float): the fraction of chunks preloaded into the CUDA cache

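The frequency-based reorder can be illustrated with plain PyTorch. A minimal sketch of the idea follows; the real `CachedParamMgr.reorder` additionally builds its index mapping table on the managed weight and warms up the CUDA cache, and the variable names here are illustrative only.

import torch

# A minimal sketch of frequency-based row reordering on a plain embedding
# weight; CachedParamMgr additionally preloads the hottest chunks into CUDA.
num_ids, dim = 8, 4
weight = torch.randn(num_ids, dim)

# ids_freq_mapping: index is the id, value is its frequency in the dataset
ids_freq_mapping = torch.tensor([3, 50, 1, 9, 0, 27, 4, 8])

# Hot ids first, so frequently used rows sit at the front of the table.
order = torch.argsort(ids_freq_mapping, descending=True)
reordered_weight = weight.index_select(0, order)

# index_mapping_table: maps an original id to its new row position.
index_mapping_table = torch.empty_like(order)
index_mapping_table[order] = torch.arange(num_ids)

# A lookup for original id 5 now goes through the mapping table.
assert torch.equal(reordered_weight[index_mapping_table[5]], weight[5])
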
@@ -16,8 +16,9 @@ class LimitBuffIndexCopyer(object):
    @torch.no_grad()
    def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
        """copy
-        src tensor[src_index] -(index_select)-> tmp -()-> tgt tensor [tgt_index]
-        The valid part in src is continous, while in tgt is scatter.
+        src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
+        The valid rows in the src tensor are contiguous, while the rows in the tgt tensor are scattered.

        Args:
            dim (int): dimension along which to index
            src_index (int): indices of the src tensor to select from

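The select-then-scatter pattern in this docstring can be reproduced with built-in torch ops. A minimal sketch, without the buffer-size limiting that `LimitBuffIndexCopyer` adds:

import torch

# src[src_index] -(index_select)-> tmp -(index_copy_)-> tgt[tgt_index]
dim = 0
src = torch.arange(12, dtype=torch.float32).reshape(4, 3)   # contiguous source rows
tgt = torch.zeros(6, 3)                                      # target with scattered destination rows

src_index = torch.tensor([0, 1, 2])   # contiguous rows to read from src
tgt_index = torch.tensor([5, 2, 0])   # scattered rows to write in tgt

tmp = torch.index_select(src, dim, src_index)   # src[src_index] -> tmp
tgt.index_copy_(dim, tgt_index, tmp)            # tmp -> tgt[tgt_index]

print(tgt)
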
@@ -57,6 +57,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
        Called after initialization.
        Reorder the weight rows according to the ids_freq_mapping.
        Then, let the weights of the Module be managed by a CachedParamMgr.

        Args:
            cuda_row_num (int): number of rows that can be hosted in CUDA memory
            ids_freq_mapping (List[int]): a list in which the index is the id and the value is its frequency

@@ -51,31 +51,31 @@ def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:


class ColoTensor(torch.Tensor):
    """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
-    Args:
-        data (torch.Tensor): a torch tensor used as the payload the colotensor.
-        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).

-    The signature of the function has to be consistent with the __new__ except for the 1st arg.
-    The class should be initialized with a torch tensor in the following ways.
-    1. directly init.
+    The ColoTensor can be initialized with a PyTorch tensor in the following ways.

        >>> pg = ProcessGroup()
        >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
-       >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
+       >>> # The tensor passed in is a tensor after sharding, not a global tensor.
        >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
        >>>                 dims=[0],
        >>>                 num_partitions=[world_size])
        >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
        >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)

-    2. use static method from_torch_tensor
-       >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
+    Args:
+        data (torch.Tensor): a torch tensor used as the payload of the colotensor.
+        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
    """

    def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
-        """__new__
+        """
        The signature of the __new__ has to be consistent with the torch.Tensor.

        Args:
            data (torch.Tensor): a torch tensor used as the payload of the colotensor.
            spec (TensorSpec, optional): the tensor spec of initialization.

        Returns:
            ColoTensor: a ColoTensor that wraps the data.
        """

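What the example means by passing in "a tensor after sharding" can be shown with plain PyTorch and no ColossalAI APIs; a sketch of a dim-0 shard under `ShardSpec(dims=[0], num_partitions=[world_size])`:

import torch

# Each rank would hold one row-chunk of the conceptual global tensor and
# wrap only that local payload. Variable names are illustrative only.
world_size = 4
t_ref = torch.randn(8, 3)                       # the conceptual global tensor

shards = torch.chunk(t_ref, world_size, dim=0)  # one shard per rank
rank = 1                                        # pretend we are rank 1
local_payload = shards[rank].clone()            # this is what rank 1 would wrap

print(t_ref.shape, local_payload.shape)         # torch.Size([8, 3]) torch.Size([2, 3])
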
@@ -115,12 +115,10 @@ class ColoTensor(torch.Tensor):
        """set_process_group
        change the pg of the ColoTensor. Note that the valid use cases are limited:
        it is only valid when the existing pg is DP-only and the dist spec is REPLICATE.

        Args:
            pg (ProcessGroup): target pg

-        Raises:
-            RuntimeError:
-            RuntimeError:
        """
        assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
        # if the new pg is the same as the old pg, just return

@@ -139,6 +137,7 @@ class ColoTensor(torch.Tensor):
    def set_dist_spec(self, dist_spec: _DistSpec):
        """set_dist_spec
        set dist spec and change the payloads.

        Args:
            dist_spec (_DistSpec): target dist spec.
        """

@@ -182,6 +181,7 @@ class ColoTensor(torch.Tensor):
        """_redistribute
        Note the function will not handle the logic of backward propagation!
        It is used during model tensor initializations as an internal function.

        Args:
            dist_spec (_DistSpec): the target dist. spec.
        """

@@ -193,12 +193,14 @@ class ColoTensor(torch.Tensor):
    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
        """redistribute
        Redistribute the tensor among processes. The rule is like this:
-        1. If the pg is None, then redistributed tensor payload among TP process group. Keep the
-        DP process group still as replicated.
-        2. If the pg is not not None and not equal to the cureent process group.
-        First, convert the tensor as replicated among TP process group.
-        Second, reset the process group.
-        Third, conver the tensor (new replicated both among tp and dp process group) to the new dist_spec.
+
+        1. If the pg is None, redistribute the tensor payload among the TP process group and keep the
+        DP process group unchanged.
+
+        2. If the pg is not None and not equal to the current process group:
+        first, convert the tensor to be replicated among the TP process group;
+        second, reset the process group to the new pg;
+        third, convert the tensor (now replicated among the TP process group) to the new dist_spec.

        Args:
            dist_spec (_DistSpec): the new dist spec.

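A single-process sketch of the replicate-then-reshard round trip described in step 2, using plain tensor ops; in the real implementation the gather step is a collective over the TP process group, and the lists of chunks here merely stand in for the per-rank payloads:

import torch

old_tp_size, new_tp_size = 4, 2
global_tensor = torch.arange(16.0).reshape(8, 2)

# Payloads under the old shard spec: each rank holds one row-chunk.
old_shards = list(torch.chunk(global_tensor, old_tp_size, dim=0))

# First, convert to replicated: gather every shard back into a full tensor.
replicated = torch.cat(old_shards, dim=0)

# Second, "reset the process group": here it is just a different shard count.
# Third, convert the replicated tensor to the new dist spec (re-shard).
new_shards = list(torch.chunk(replicated, new_tp_size, dim=0))

assert torch.equal(torch.cat(new_shards, dim=0), global_tensor)
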
@@ -219,18 +221,31 @@ class ColoTensor(torch.Tensor):

    def to_replicate_(self):
        """to_replicate_

        an in-place member function, converting the dist spec of the tensor to REPLICATE
        """
        self._redistribute(dist_spec=ReplicaSpec())

    def to_replicate(self) -> 'ColoTensor':
        """to_replicate
-        converting dist spec of the tensor to REPLICATE
+
+        converting the dist spec of the tensor to ReplicaSpec()
        """
        return self.redistribute(ReplicaSpec())

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
+        """from_torch_tensor
+
+        A static method that builds a `ColoTensor` from a PyTorch Tensor.
+
+        Args:
+            tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank, not a global tensor.
+            spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.
+
+        Returns:
+            ColoTensor: a ColoTensor
+        """
        tensor = tensor.as_subclass(ColoTensor)
        tensor.__init__(tensor, spec=spec)
        return tensor

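The `as_subclass` call that `from_torch_tensor` relies on is a stock `torch.Tensor` method and can be tried in isolation; `MetaTensor` and its `note` attribute below are made up for illustration only:

import torch

# as_subclass reuses the storage and only changes the Python type,
# which is why from_torch_tensor wraps without copying.
class MetaTensor(torch.Tensor):
    pass

t = torch.randn(2, 3)
m = t.as_subclass(MetaTensor)         # same storage, new subclass type
m.note = "wrapped without copying"    # illustrative extra attribute

assert isinstance(m, MetaTensor)
assert m.data_ptr() == t.data_ptr()   # no copy was made
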
@@ -252,10 +267,13 @@ class ColoTensor(torch.Tensor):
        return super().size(*args)

    def size_global(self, *args) -> torch.Size:
-        """override the torch buildin size()
+        """size_global
+
+        override the torch builtin size()
        the shape passed in must be in a replicate placement.

        Returns:
-            ColoTensor: a tensor after viewed.
+            torch.Size: the global tensor shape
        """
        if self.is_replicate():
            return self.size_local(*args)

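For a tensor sharded along dim 0 across `tp_world_size` ranks, the global size differs from the local size only in that dimension. A plain-torch sketch of the relationship that `size_global` reports:

import torch

tp_world_size = 4
global_tensor = torch.randn(8, 3)
local_shard = torch.chunk(global_tensor, tp_world_size, dim=0)[0]

local_size = local_shard.size()                       # torch.Size([2, 3])
global_size = torch.Size(
    (local_size[0] * tp_world_size, *local_size[1:])  # scale the sharded dim back up
)

assert global_size == global_tensor.size()
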
@@ -11,6 +11,7 @@ class ComputePattern(Enum):
class ComputeSpec(object):
    """ComputeSpec
    The specification for the computation pattern.

    Args:
        compute_pattern (ComputePattern): an Enum instance for compute pattern.
    """

@@ -1,5 +1,5 @@
from enum import Enum
-from typing import List, Optional
+from typing import List

__all__ = ['replicate', 'shard']

@@ -30,10 +30,13 @@ PYTORCHPGDICT_ = PyTorchProcessGroupDict()


class ProcessGroup:
-    """
+    """ProcessGroup
    Process Group contains group partition for Tensor Parallel and Data Parallel.
-    NOTE, the ProcessGroup must be used after torch.distributed.initialize()
-    args:
+
+    NOTE, the ProcessGroup must be used after `torch.distributed.initialize()`
+
+    Args:
        rank: the global rank of the current process.
        ranks: List[int], a list of rank ids belonging to this process group.
        backend: str, the backend of the process group.

@@ -101,6 +104,9 @@ class ProcessGroup:
        self.is_init = True

    def set_cpu_groups(self):
+        """set_cpu_groups
+        Initialize PyTorch process groups for CPU communications.
+        """
        if self.has_cpu_groups:
            return

@@ -115,7 +121,13 @@ class ProcessGroup:
        self._has_cpu_groups = True

    @property
-    def has_cpu_groups(self):
+    def has_cpu_groups(self) -> bool:
+        """has_cpu_groups
+        Whether the CPU process groups have been initialized.
+
+        Returns:
+            bool: cpu process groups have been initialized or not.
+        """
        return self._has_cpu_groups

    def __repr__(self):

@@ -142,51 +154,158 @@ class ProcessGroup:
            return False
        return True

-    def rank(self):
+    def rank(self) -> int:
+        """rank
+
+        The current rank in the global process group.
+
+        Returns:
+            int: the rank number
+        """
        return self._rank

-    def ranks_in_group(self):
+    def ranks_in_group(self) -> List[int]:
+        """ranks_in_group
+
+        a list of rank numbers in the global process group.
+
+        Returns:
+            List[int]: a list of rank numbers.
+        """
        return self._rank_list

-    def world_size(self):
+    def world_size(self) -> int:
+        """world_size
+
+        The world size of the global process group.
+
+        Returns:
+            int: world size
+        """
        return self._world_size

-    def tp_rank_list(self):
+    def tp_rank_list(self) -> List[int]:
+        """tp_rank_list
+
+        the rank list of the TP process group containing the current rank.
+
+        Returns:
+            List[int]: the list of rank numbers.
+        """
        return self._tp_rank_list

-    def dp_rank_list(self):
+    def dp_rank_list(self) -> List[int]:
+        """dp_rank_list
+
+        the rank list of the DP process group containing the current rank.
+
+        Returns:
+            List[int]: the list of rank numbers.
+        """
        return self._dp_rank_list

-    def tp_local_rank(self):
+    def tp_local_rank(self) -> int:
+        """tp_local_rank
+
+        The local rank number in the current TP process group.
+
+        Returns:
+            int: tp rank number.
+        """
        return self._rank % self._tp_degree

-    def dp_local_rank(self):
+    def dp_local_rank(self) -> int:
+        """dp_local_rank
+
+        The local rank number in the current DP process group.
+
+        Returns:
+            int: dp rank number.
+        """
        return self._rank // self._tp_degree

-    def dp_world_size(self):
+    def dp_world_size(self) -> int:
+        """dp_world_size
+
+        The world size of the current DP process group.
+
+        Returns:
+            int: dp world size
+        """
        return len(self._dp_rank_list)

-    def tp_world_size(self):
+    def tp_world_size(self) -> int:
+        """tp_world_size
+
+        The world size of the current TP process group.
+
+        Returns:
+            int: tp world size
+        """
        return len(self._tp_rank_list)

    def dp_process_group(self):
-        # return self._dp_process_group
+        """dp_process_group
+
+        the pytorch DP process group containing the current rank.
+
+        Returns:
+            `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
+        """
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')

    def tp_process_group(self):
-        # return self._tp_process_group
+        """tp_process_group
+
+        the pytorch TP process group containing the current rank.
+
+        Returns:
+            `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
+        """
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

    def cpu_dp_process_group(self):
+        """cpu_dp_process_group
+
+        the pytorch CPU DP process group containing the current rank.
+
+        the assert fails if the cpu process groups are not initialized.
+
+        Returns:
+            `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
+        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

    def cpu_tp_process_group(self):
+        """cpu_tp_process_group
+
+        the pytorch CPU TP process group containing the current rank.
+
+        the assert fails if the cpu process groups are not initialized.
+
+        Returns:
+            `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
+        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')

-    def get_ranks_in_dp(self):
+    def get_ranks_in_dp(self) -> List[int]:
+        """get_ranks_in_dp
+
+        ranks in the current dp process group.
+
+        Returns:
+            List[int]: a list of rank numbers.
+        """
        return self._dp_rank_list

    def get_ranks_in_tp(self):
+        """get_ranks_in_tp
+
+        ranks in the current tp process group.
+
+        Returns:
+            List[int]: a list of rank numbers.
+        """
        return self._tp_rank_list

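The modulo and integer-division arithmetic in `tp_local_rank` and `dp_local_rank` implies a row-major layout of the global ranks over a (dp_degree x tp_degree) grid: TP groups are rows, DP groups are columns. A self-contained sketch of the rank lists this layout produces, in plain Python with no ColossalAI code:

from typing import List, Tuple

def tp_dp_groups(world_size: int, tp_degree: int, rank: int) -> Tuple[List[int], List[int]]:
    # tp_local_rank = rank % tp_degree, dp_local_rank = rank // tp_degree
    assert world_size % tp_degree == 0
    tp_local_rank = rank % tp_degree
    dp_local_rank = rank // tp_degree
    tp_rank_list = [dp_local_rank * tp_degree + i for i in range(tp_degree)]
    dp_rank_list = [i * tp_degree + tp_local_rank for i in range(world_size // tp_degree)]
    return tp_rank_list, dp_rank_list

# rank 5 in a world of 8 with tp_degree=4 sits in TP group [4, 5, 6, 7] and DP group [1, 5].
print(tp_dp_groups(world_size=8, tp_degree=4, rank=5))
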
@@ -7,6 +7,12 @@ from dataclasses import dataclass

@dataclass
class ColoTensorSpec:
+    """ ColoTensorSpec
+
+    A data class for specifications of the `ColoTensor`.
+    It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
+    The latter two attributes are optional. If not set, they default to `ReplicaSpec()` and `None`.
+    """
    pg: ProcessGroup
    dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
    compute_attr: Optional[ComputeSpec] = None
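A hedged usage sketch of the dataclass above; the import path and the `ShardSpec` signature are taken from the docstring examples earlier in this diff and may differ across ColossalAI versions. It assumes torch.distributed has already been initialized.

import torch.distributed as dist
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ShardSpec

world_size = dist.get_world_size()        # ProcessGroup must be used after dist init

pg = ProcessGroup()
replica_spec = ColoTensorSpec(pg)         # dist_attr / compute_attr fall back to the defaults

shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
                       dims=[0],
                       num_partitions=[world_size])
sharded_spec = ColoTensorSpec(pg, dist_attr=shard_spec)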