[Doc] add more doc for ColoTensor. (#1458)

pull/1459/head
Jiarui Fang 2022-08-16 10:38:41 +08:00 committed by GitHub
parent a1476ea882
commit 36824a304c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 18 deletions

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import operator
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import shard
from colossalai.tensor.distspec import ShardSpec
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@ -85,13 +85,13 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
for shard_type, module in annotation_record.items():
# add row sharding spec
if shard_type == 'row':
dist_spec = shard(dims=[-1], num_partitions=[world_size])
dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D)
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', dist_spec)
setattr(module.weight, 'comp_spec', comp_spec)
elif shard_type == 'col':
weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False
setattr(module.weight, 'pg', process_group)
@ -99,7 +99,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
setattr(module.weight, 'comp_spec', weight_comp_spec)
if module.bias is not None:
bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False
setattr(module.bias, 'pg', process_group)

View File

@ -1,7 +1,7 @@
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .distspec import shard as ShardSpec
from .distspec import replicate as ReplicaSpec
from .distspec import ShardSpec
from .distspec import ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
@ -13,6 +13,6 @@ from . import distspec
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
'ShardSpec', 'ReplicaSpec'
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
'ReplicaSpec'
]

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import List
__all__ = ['replicate', 'shard']
__all__ = ['ReplicaSpec', 'ShardSpec']
class DistPlacementPattern(Enum):
@ -10,15 +10,22 @@ class DistPlacementPattern(Enum):
class _DistSpec:
"""_DistSpec
A class indicates Distributed Specification.
The DistSpec is only works for the tensor parallel process groups.
Because the dist spec of data parallel process group can be automatically deduced.
This is an internal data structrue.
The API for users should be `ShardSpec` and `ReplicaSpec`.
Args:
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
"""
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
"""_DistSpec, Distributed Specification
Args:
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
"""
self.placement = dist_placement_pattern
for k, v in meta_info.items():
setattr(self, k, v)
@ -39,11 +46,32 @@ class _DistSpec:
return ''.join(res_list)
def replicate() -> _DistSpec:
def ReplicaSpec() -> _DistSpec:
"""ReplicaSpec
A distributed specification represents the tensor is replicated among the tensor parallel process group.
Returns:
_DistSpec: an replicated dist spec instance.
"""
return _DistSpec(DistPlacementPattern.REPLICATE)
def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec:
def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
"""ShardSpec
A distributed specification represents the tensor is sharded among the tensor parallel process group.
Note:
Currently, only shard on one dimension is valid. In another word, dims should be of size 1.
Args:
dims (List[int]): a list of dimensions
num_partitions (List[int]): a list of partition number of each dimensions.
Returns:
_DistSpec: an shard dist spec instance.
"""
assert isinstance(dims, list) and isinstance(num_partitions, list)
assert len(dims) == len(num_partitions)
return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))

View File

@ -19,7 +19,7 @@ def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
def shard_param(p: ColoParameter) -> None:
pg = p.get_process_group()
p._redistribute(distspec.shard([0], [pg.tp_world_size()]))
p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()