[Doc] add more doc for ColoTensor. (#1458)

pull/1459/head
Jiarui Fang 2022-08-16 10:38:41 +08:00 committed by GitHub
parent a1476ea882
commit 36824a304c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 18 deletions

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import operator import operator
from colossalai.tensor import ProcessGroup from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import shard from colossalai.tensor.distspec import ShardSpec
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@ -85,13 +85,13 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
for shard_type, module in annotation_record.items(): for shard_type, module in annotation_record.items():
# add row sharding spec # add row sharding spec
if shard_type == 'row': if shard_type == 'row':
dist_spec = shard(dims=[-1], num_partitions=[world_size]) dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D) comp_spec = ComputeSpec(ComputePattern.TP1D)
setattr(module.weight, 'pg', process_group) setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', dist_spec) setattr(module.weight, 'dist_spec', dist_spec)
setattr(module.weight, 'comp_spec', comp_spec) setattr(module.weight, 'comp_spec', comp_spec)
elif shard_type == 'col': elif shard_type == 'col':
weight_dist_spec = shard(dims=[0], num_partitions=[world_size]) weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D) weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False weight_comp_spec.output_replicate = False
setattr(module.weight, 'pg', process_group) setattr(module.weight, 'pg', process_group)
@ -99,7 +99,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
setattr(module.weight, 'comp_spec', weight_comp_spec) setattr(module.weight, 'comp_spec', weight_comp_spec)
if module.bias is not None: if module.bias is not None:
bias_dist_spec = shard(dims=[0], num_partitions=[world_size]) bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D) bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False bias_comp_spec.output_replicate = False
setattr(module.bias, 'pg', process_group) setattr(module.bias, 'pg', process_group)

View File

@ -1,7 +1,7 @@
from .process_group import ProcessGroup from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec from .tensor_spec import ColoTensorSpec
from .distspec import shard as ShardSpec from .distspec import ShardSpec
from .distspec import replicate as ReplicaSpec from .distspec import ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor from .colo_tensor import ColoTensor
@ -13,6 +13,6 @@ from . import distspec
__all__ = [ __all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter', 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
'ShardSpec', 'ReplicaSpec' 'ReplicaSpec'
] ]

View File

@ -1,7 +1,7 @@
from enum import Enum from enum import Enum
from typing import List from typing import List
__all__ = ['replicate', 'shard'] __all__ = ['ReplicaSpec', 'ShardSpec']
class DistPlacementPattern(Enum): class DistPlacementPattern(Enum):
@ -10,15 +10,22 @@ class DistPlacementPattern(Enum):
class _DistSpec: class _DistSpec:
"""_DistSpec
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info): A class indicates Distributed Specification.
"""_DistSpec, Distributed Specification The DistSpec is only works for the tensor parallel process groups.
Because the dist spec of data parallel process group can be automatically deduced.
This is an internal data structrue.
The API for users should be `ShardSpec` and `ReplicaSpec`.
Args: Args:
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes. dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard. The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None. process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
""" """
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
self.placement = dist_placement_pattern self.placement = dist_placement_pattern
for k, v in meta_info.items(): for k, v in meta_info.items():
setattr(self, k, v) setattr(self, k, v)
@ -39,11 +46,32 @@ class _DistSpec:
return ''.join(res_list) return ''.join(res_list)
def replicate() -> _DistSpec: def ReplicaSpec() -> _DistSpec:
"""ReplicaSpec
A distributed specification represents the tensor is replicated among the tensor parallel process group.
Returns:
_DistSpec: an replicated dist spec instance.
"""
return _DistSpec(DistPlacementPattern.REPLICATE) return _DistSpec(DistPlacementPattern.REPLICATE)
def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec: def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
"""ShardSpec
A distributed specification represents the tensor is sharded among the tensor parallel process group.
Note:
Currently, only shard on one dimension is valid. In another word, dims should be of size 1.
Args:
dims (List[int]): a list of dimensions
num_partitions (List[int]): a list of partition number of each dimensions.
Returns:
_DistSpec: an shard dist spec instance.
"""
assert isinstance(dims, list) and isinstance(num_partitions, list) assert isinstance(dims, list) and isinstance(num_partitions, list)
assert len(dims) == len(num_partitions) assert len(dims) == len(num_partitions)
return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions)) return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))

View File

@ -19,7 +19,7 @@ def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
def shard_param(p: ColoParameter) -> None: def shard_param(p: ColoParameter) -> None:
pg = p.get_process_group() pg = p.get_process_group()
p._redistribute(distspec.shard([0], [pg.tp_world_size()])) p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach() p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()