@@ -1,17 +1,12 @@
 import math
 import operator
 from copy import deepcopy
 from dataclasses import dataclass
 from enum import Enum
 from functools import reduce
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple

 import torch
 import torch.distributed as dist
 from torch.distributed import ReduceOp

 from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator

 from .comm_spec import *
@@ -28,7 +23,7 @@ class ShapeConsistencyOptions:
     pass


-def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
+def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
     shape_consistency_manager = ShapeConsistencyManager()
     global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
     with torch.no_grad():
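
The added return annotation documents that to_global materializes the full, unsharded tensor described by an empty dim_partition_dict. As a rough single-process illustration of that relationship (plain PyTorch with hypothetical helper names, not the library routine), splitting a tensor into per-device shards and concatenating them back recovers the global view:

import torch

def shard_along_dim(global_tensor, dim, num_devices):
    # split the global tensor into equal shards, one per device, along `dim`
    return list(torch.chunk(global_tensor, num_devices, dim=dim))

def gather_to_global(shards, dim):
    # the inverse operation: concatenating every shard restores the global view
    return torch.cat(shards, dim=dim)

global_tensor = torch.arange(16.0).reshape(4, 4)
shards = shard_along_dim(global_tensor, dim=0, num_devices=4)
assert torch.equal(gather_to_global(shards, dim=0), global_tensor)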
@@ -72,7 +67,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         assert isinstance(value, bool)
         self._forward_only = value

-    def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
+    def get_all_all_gather_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with single all-gather operation, and
         accumulate communication cost on origin cost which will finally be used in auto sharding solver.
@@ -80,7 +76,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):

         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
-            orig_cost(float): the original communication cost before this operation.
+            orig_cost(Dict[str, float]): the original communication cost before this operation.

         Return:
             valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
@@ -92,7 +88,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # device_mesh_shape: (4, 4)
             sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
             shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
             print(rst_dict)

         Output:
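
To make the enumeration in the docstring example concrete, the following stand-alone sketch (plain Python with a hypothetical function name, not the library's get_all_all_gather_spec) shows how dropping the last mesh dimension from each sharded tensor dimension yields the candidate dim_partition_dicts reachable by a single all-gather:

from copy import deepcopy

def one_step_all_gather_candidates(dim_partition_dict):
    # dim_partition_dict maps a tensor dim to the list of mesh dims it is sharded over,
    # e.g. {0: [0], 1: [1]} shards tensor dim 0 on mesh dim 0 and tensor dim 1 on mesh dim 1
    candidates = []
    for tensor_dim, mesh_dims in dim_partition_dict.items():
        new_dict = deepcopy(dim_partition_dict)
        remaining = mesh_dims[:-1]    # an all-gather releases the last mesh dim of this tensor dim
        if remaining:
            new_dict[tensor_dim] = remaining
        else:
            del new_dict[tensor_dim]
        candidates.append(new_dict)
    return candidates

print(one_step_all_gather_candidates({0: [0], 1: [1]}))
# -> [{1: [1]}, {0: [0]}]: gather tensor dim 0, or gather tensor dim 1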
@@ -143,7 +139,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 pass
         return valid_spec_dict

-    def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
+    def get_all_all_to_all_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with single all-to-all operation, and
         accumulate communication cost on origin cost which will finally be used in auto sharding solver.
@@ -151,7 +148,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):

         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
-            orig_cost(float): the original communication cost before this operation.
+            orig_cost(Dict[str, float]): the original communication cost before this operation.

         Return:
             valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
@@ -163,7 +160,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # device_mesh_shape: (4, 4)
             sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
             shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
             print(rst_dict)

         Output:
@@ -250,7 +247,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):

         return valid_spec_dict

-    def get_all_shard_spec(self, source_spec, orig_cost_dict):
+    def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
         '''
         Get all valid sharding specs from source_spec with single shard operation, and
         accumulate communication cost on origin cost which will finally be used in auto sharding solver.
@@ -270,7 +267,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # device_mesh_shape: (4, 4)
             sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
             shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
             print(rst_dict)

         Output:
@@ -331,7 +328,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 pass
         return valid_spec_dict

-    def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
+    def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with one step transform, and
         accumulate communication cost on origin cost which will finally be used in auto sharding solver.
@@ -353,7 +350,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
         return valid_spec_dict

-    def shape_consistency(self, source_spec, target_spec):
+    def shape_consistency(self, source_spec: ShardingSpec,
+                          target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
         '''
         This method will find a path to transform source_spec to target_spec with
         a greedy algorithm.
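
The docstring above describes a greedy search over one-step transforms. A rough illustration of that strategy follows (a toy model with hypothetical helper names, unit costs, and a symmetric-difference heuristic; not ColossalAI's actual shape_consistency routine):

def one_step_neighbors(spec):
    # hypothetical stand-in for get_all_one_step_transform_spec on a 2D mesh and 2D tensor:
    # a spec is a frozenset of (tensor_dim, mesh_dim) shardings, and one step either
    # removes a sharding ("all-gather") or adds one ("shard"), each with unit cost
    universe = {(dim, mesh_dim) for dim in range(2) for mesh_dim in range(2)}
    neighbors = {}
    for pair in universe:
        if pair in spec:
            neighbors[spec - {pair}] = 1.0
        else:
            neighbors[spec | {pair}] = 1.0
    return neighbors

def greedy_transform_path(source, target, max_steps=20):
    path, total_cost, current = [source], 0.0, source
    for _ in range(max_steps):
        if current == target:
            return path, total_cost
        # greedily pick the neighbour whose symmetric difference with the target is smallest
        current, step_cost = min(one_step_neighbors(current).items(),
                                 key=lambda kv: len(kv[0] ^ target))
        path.append(current)
        total_cost += step_cost
    raise RuntimeError(f"no transform path found within {max_steps} steps")

src = frozenset({(0, 0), (1, 1)})    # roughly an S0,S1-style sharding
tgt = frozenset()                    # fully replicated
print(greedy_transform_path(src, tgt))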
@@ -459,7 +457,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):

         raise RuntimeError(f"Could not find a valid transform path within {MAX_TRANSFORM_STEPS} steps.")

-    def apply(self, tensor_with_sharding_spec, target_spec):
+    def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
         '''
         Apply target_spec to tensor with source sharding spec, the transform path is generated by the
         shape_consistency method.
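
Putting the annotated entry points together, end-to-end usage might look roughly like the sketch below. The DeviceMesh import path, the init_process_group flag, the {tensor_dim: [mesh_dims]} layout of dim_partition_dict, and the convention of attaching sharding_spec to the tensor are assumptions based on the docstring examples in this file, not verified against this exact revision.

import torch

from colossalai.device.device_mesh import DeviceMesh            # import path assumed
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager

# assumes a distributed run with 4 ranks has already been launched
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

entire_shape = torch.Size((8, 8))
src_spec = ShardingSpec(device_mesh, entire_shape, {0: [0]})     # sharded along tensor dim 0 on mesh dim 0
tgt_spec = ShardingSpec(device_mesh, entire_shape, {})           # fully replicated

manager = ShapeConsistencyManager()
local_shard = torch.randn(4, 8).cuda()
local_shard.sharding_spec = src_spec    # apply() reads the spec attached to the tensor (assumed convention)
replicated = manager.apply(local_shard, tgt_spec)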