mirror of https://github.com/hpcaitech/ColossalAI
[Doc] add more doc for ColoTensor. (#1458)
parent a1476ea882
commit 36824a304c
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import operator
 from colossalai.tensor import ProcessGroup
-from colossalai.tensor.distspec import shard
+from colossalai.tensor.distspec import ShardSpec
 from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
 
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@@ -85,13 +85,13 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
     for shard_type, module in annotation_record.items():
         # add row sharding spec
         if shard_type == 'row':
-            dist_spec = shard(dims=[-1], num_partitions=[world_size])
+            dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
             comp_spec = ComputeSpec(ComputePattern.TP1D)
            setattr(module.weight, 'pg', process_group)
            setattr(module.weight, 'dist_spec', dist_spec)
            setattr(module.weight, 'comp_spec', comp_spec)
         elif shard_type == 'col':
-            weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
+            weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
             weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
             weight_comp_spec.output_replicate = False
             setattr(module.weight, 'pg', process_group)
@@ -99,7 +99,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
             setattr(module.weight, 'comp_spec', weight_comp_spec)
 
             if module.bias is not None:
-                bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
+                bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
                 bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
                 bias_comp_spec.output_replicate = False
                 setattr(module.bias, 'pg', process_group)
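For reference, the 'row' branch above boils down to the following annotation pattern. This is a minimal sketch, not code from the repository: annotate_row_shard is a hypothetical helper, and the Linear module, process group, and world size are assumed to be supplied by the caller (in the pass they come from the traced graph and the tensor-parallel ProcessGroup).

import torch.nn as nn
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

def annotate_row_shard(module: nn.Linear, pg: ProcessGroup, world_size: int) -> None:
    # Shard the weight along its last dimension into `world_size` partitions
    # and mark it for 1D tensor-parallel compute, as the 'row' branch does.
    dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
    comp_spec = ComputeSpec(ComputePattern.TP1D)
    setattr(module.weight, 'pg', pg)
    setattr(module.weight, 'dist_spec', dist_spec)
    setattr(module.weight, 'comp_spec', comp_spec)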
@@ -1,7 +1,7 @@
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
-from .distspec import shard as ShardSpec
-from .distspec import replicate as ReplicaSpec
+from .distspec import ShardSpec
+from .distspec import ReplicaSpec
 
 from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
@@ -13,6 +13,6 @@ from . import distspec
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
-    'ShardSpec', 'ReplicaSpec'
+    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
+    'ReplicaSpec'
 ]
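Since the package __init__ now re-exports the renamed factories directly, the package-level names and the distspec module names refer to the same objects. A quick sketch to illustrate; the _ShardSpec alias exists only for the comparison:

from colossalai.tensor import ShardSpec, ReplicaSpec
from colossalai.tensor.distspec import ShardSpec as _ShardSpec

assert ShardSpec is _ShardSpec    # the package-level name is the distspec factory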
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import List
 
-__all__ = ['replicate', 'shard']
+__all__ = ['ReplicaSpec', 'ShardSpec']
 
 
 class DistPlacementPattern(Enum):
@@ -10,15 +10,22 @@ class DistPlacementPattern:
 
 
 class _DistSpec:
+    """_DistSpec
 
-    def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
-        """_DistSpec, Distributed Specification
+    A class that indicates a distributed specification.
+    The DistSpec only works for the tensor parallel process groups,
+    because the dist spec of the data parallel process group can be deduced automatically.
+    This is an internal data structure.
+    The user-facing APIs are `ShardSpec` and `ReplicaSpec`.
 
-        Args:
-            dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
-                The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
-            process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
-        """
+    Args:
+        dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
+            The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
+        process_group (Optional[ProcessGroup], optional): the process group containing the processes. Defaults to None.
+    """
+
+    def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
+
         self.placement = dist_placement_pattern
         for k, v in meta_info.items():
             setattr(self, k, v)
@@ -39,11 +46,32 @@ class _DistSpec:
         return ''.join(res_list)
 
 
-def replicate() -> _DistSpec:
+def ReplicaSpec() -> _DistSpec:
+    """ReplicaSpec
+
+    A distributed specification indicating that the tensor is replicated among the tensor parallel process group.
+
+    Returns:
+        _DistSpec: a replicated dist spec instance.
+    """
     return _DistSpec(DistPlacementPattern.REPLICATE)
 
 
-def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec:
+def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
+    """ShardSpec
+
+    A distributed specification indicating that the tensor is sharded among the tensor parallel process group.
+
+    Note:
+        Currently, sharding on only one dimension is valid; in other words, dims should be of size 1.
+
+    Args:
+        dims (List[int]): a list of dimensions along which to shard.
+        num_partitions (List[int]): a list of the number of partitions for each dimension.
+
+    Returns:
+        _DistSpec: a sharded dist spec instance.
+    """
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
     return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))
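Following the new docstrings, a minimal usage sketch of the two factories; the partition count 4 is an arbitrary example, and in practice it would be the tensor-parallel world size:

from colossalai.tensor import ShardSpec, ReplicaSpec

# Shard dimension 0 into 4 partitions; only a single sharded dimension is
# valid at the moment, so both lists have length 1.
row_shard = ShardSpec(dims=[0], num_partitions=[4])

# Replicate the tensor among the tensor parallel process group.
replica = ReplicaSpec()

print(row_shard.placement)    # DistPlacementPattern.SHARD
print(replica.placement)      # DistPlacementPattern.REPLICATE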
@@ -19,7 +19,7 @@ def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
 
 def shard_param(p: ColoParameter) -> None:
     pg = p.get_process_group()
-    p._redistribute(distspec.shard([0], [pg.tp_world_size()]))
+    p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
     p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()
 
 
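Call sites that reach the factory through the distspec module namespace, like the test helper above, only need the name swap. A small sketch; the partition count 2 stands in for pg.tp_world_size():

from colossalai.tensor import distspec

# Before this commit:  spec = distspec.shard([0], [2])
spec = distspec.ShardSpec([0], [2])    # same arguments, renamed factory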