|
|
|
@ -1,17 +1,12 @@
|
|
|
|
|
import math |
|
|
|
|
import operator |
|
|
|
|
from copy import deepcopy |
|
|
|
|
from dataclasses import dataclass |
|
|
|
|
from enum import Enum |
|
|
|
|
from functools import reduce |
|
|
|
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
import torch.distributed as dist |
|
|
|
|
from torch.distributed import ReduceOp |
|
|
|
|
|
|
|
|
|
from colossalai.context.singleton_meta import SingletonMeta |
|
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec |
|
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException |
|
|
|
|
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator |
|
|
|
|
|
|
|
|
|
from .comm_spec import * |
|
|
|
@ -28,7 +23,7 @@ class ShapeConsistencyOptions:
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec): |
|
|
|
|
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor: |
|
|
|
|
shape_consistency_manager = ShapeConsistencyManager() |
|
|
|
|
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {}) |
|
|
|
|
with torch.no_grad(): |
|
|
|
@ -72,7 +67,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
assert isinstance(value, bool) |
|
|
|
|
self._forward_only = value |
|
|
|
|
|
|
|
|
|
def get_all_all_gather_spec(self, source_spec, orig_cost_dict): |
|
|
|
|
def get_all_all_gather_spec(self, source_spec: ShardingSpec, |
|
|
|
|
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: |
|
|
|
|
''' |
|
|
|
|
Get all valid sharding specs from source_spec with single all-gather operation, and |
|
|
|
|
accumulate commucation cost on origin cost which will finally be used in auto sharding solver. |
|
|
|
@ -80,7 +76,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
|
|
|
|
|
Argument: |
|
|
|
|
source_spec(ShardingSpec): the ShardingSpec of the source_spec. |
|
|
|
|
orig_cost(float): the original communication cost before this operation. |
|
|
|
|
orig_cost(Dict[str, float]): the original communication cost before this operation. |
|
|
|
|
|
|
|
|
|
Return: |
|
|
|
|
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation. |
|
|
|
@ -92,7 +88,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
# device_mesh_shape: (4, 4) |
|
|
|
|
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) |
|
|
|
|
shape_consistency_manager = ShapeConsistencyManager() |
|
|
|
|
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0) |
|
|
|
|
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) |
|
|
|
|
print(rst_dict) |
|
|
|
|
|
|
|
|
|
Output: |
|
|
|
@ -143,7 +139,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
pass |
|
|
|
|
return valid_spec_dict |
|
|
|
|
|
|
|
|
|
def get_all_all_to_all_spec(self, source_spec, orig_cost_dict): |
|
|
|
|
def get_all_all_to_all_spec(self, source_spec: ShardingSpec, |
|
|
|
|
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: |
|
|
|
|
''' |
|
|
|
|
Get all valid sharding specs from source_spec with single all-to-all operation, and |
|
|
|
|
accumulate commucation cost on origin cost which will finally be used in auto sharding solver. |
|
|
|
@ -151,7 +148,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
|
|
|
|
|
Argument: |
|
|
|
|
source_spec(ShardingSpec): the ShardingSpec of the source_spec. |
|
|
|
|
orig_cost(float): the original communication cost before this operation. |
|
|
|
|
orig_cost(Dict[str, float]): the original communication cost before this operation. |
|
|
|
|
|
|
|
|
|
Return: |
|
|
|
|
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. |
|
|
|
@ -163,7 +160,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
# device_mesh_shape: (4, 4) |
|
|
|
|
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) |
|
|
|
|
shape_consistency_manager = ShapeConsistencyManager() |
|
|
|
|
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0) |
|
|
|
|
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) |
|
|
|
|
print(rst_dict) |
|
|
|
|
|
|
|
|
|
Output: |
|
|
|
@ -250,7 +247,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
|
|
|
|
|
return valid_spec_dict |
|
|
|
|
|
|
|
|
|
def get_all_shard_spec(self, source_spec, orig_cost_dict): |
|
|
|
|
def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): |
|
|
|
|
''' |
|
|
|
|
Get all valid sharding specs from source_spec with single shard operation, and |
|
|
|
|
accumulate commucation cost on origin cost which will finally be used in auto sharding solver. |
|
|
|
@ -270,7 +267,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
# device_mesh_shape: (4, 4) |
|
|
|
|
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) |
|
|
|
|
shape_consistency_manager = ShapeConsistencyManager() |
|
|
|
|
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0) |
|
|
|
|
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) |
|
|
|
|
print(rst_dict) |
|
|
|
|
|
|
|
|
|
Output: |
|
|
|
@ -331,7 +328,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
pass |
|
|
|
|
return valid_spec_dict |
|
|
|
|
|
|
|
|
|
def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict): |
|
|
|
|
def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]: |
|
|
|
|
''' |
|
|
|
|
Get all valid sharding specs from source_spec with one step transform, and |
|
|
|
|
accumulate commucation cost on origin cost which will finally be used in auto sharding solver. |
|
|
|
@ -353,7 +350,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict)) |
|
|
|
|
return valid_spec_dict |
|
|
|
|
|
|
|
|
|
def shape_consistency(self, source_spec, target_spec): |
|
|
|
|
def shape_consistency(self, source_spec: ShardingSpec, |
|
|
|
|
target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]: |
|
|
|
|
''' |
|
|
|
|
This method will find a path to transform source_spec to target_spec with |
|
|
|
|
a greedy algorithm. |
|
|
|
@ -459,7 +457,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|
|
|
|
|
|
|
|
|
raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") |
|
|
|
|
|
|
|
|
|
def apply(self, tensor_with_sharding_spec, target_spec): |
|
|
|
|
def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor: |
|
|
|
|
''' |
|
|
|
|
Apply target_spec to tensor with source sharding spec, the transform path is generated by the |
|
|
|
|
shape_consistency method. |
|
|
|
|