diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py
index 4ec5ad9e9..d5d28db0f 100644
--- a/colossalai/tensor/shape_consistency.py
+++ b/colossalai/tensor/shape_consistency.py
@@ -1,17 +1,12 @@
 import math
-import operator
 from copy import deepcopy
 from dataclasses import dataclass
-from enum import Enum
-from functools import reduce
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple
 
 import torch
-import torch.distributed as dist
-from torch.distributed import ReduceOp
 
 from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 
 from .comm_spec import *
@@ -28,7 +23,7 @@ class ShapeConsistencyOptions:
     pass
 
 
-def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
+def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
     shape_consistency_manager = ShapeConsistencyManager()
     global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
     with torch.no_grad():
@@ -72,7 +67,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         assert isinstance(value, bool)
         self._forward_only = value
 
-    def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
+    def get_all_all_gather_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with single all-gather operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -80,7 +76,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
-            orig_cost(float): the original communication cost before this operation.
+            orig_cost(Dict[str, float]): the original communication cost before this operation.
 
         Return:
             valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
@@ -92,7 +88,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             #     device_mesh_shape: (4, 4)
             sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
             shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
             print(rst_dict)
 
         Output:
@@ -143,7 +139,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 pass
         return valid_spec_dict
 
-    def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
+    def get_all_all_to_all_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with single all-to-all operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -151,7 +148,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
-            orig_cost(float): the original communication cost before this operation.
+            orig_cost(Dict[str, float]): the original communication cost before this operation.
 
         Return:
             valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
@@ -163,7 +160,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             #     device_mesh_shape: (4, 4)
             sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
             shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
             print(rst_dict)
 
         Output:
@@ -250,7 +247,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         return valid_spec_dict
 
-    def get_all_shard_spec(self, source_spec, orig_cost_dict):
+    def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
         '''
         Get all valid sharding specs from source_spec with single shard operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -270,7 +267,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             #     device_mesh_shape: (4, 4)
             sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
             shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
             print(rst_dict)
 
         Output:
@@ -331,7 +328,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 pass
        return valid_spec_dict
 
-    def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
+    def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with one step transform, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -353,7 +350,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
         return valid_spec_dict
 
-    def shape_consistency(self, source_spec, target_spec):
+    def shape_consistency(self, source_spec: ShardingSpec,
+                          target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
         '''
         This method will find a path to transform source_spec to target_spec with
         a greedy algorithm.
@@ -459,7 +457,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
 
-    def apply(self, tensor_with_sharding_spec, target_spec):
+    def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
         '''
         Apply target_spec to tensor with source sharding spec,
         the transform path is generated by the shape_consistency method.
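Note: the sketch below is not part of the patch; it only illustrates the conventions the patch annotates, namely the cost dict keyed by 'forward', 'backward' and 'total' (replacing the bare float shown in the old docstrings) and the typed shape_consistency entry point. The DeviceMesh import path and constructor, as well as the concrete shapes and partition dicts, are assumptions for illustration and are not taken from this diff.

# Minimal usage sketch under the assumptions stated above.
import torch

from colossalai.device.device_mesh import DeviceMesh  # assumed import path
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager

# A 2x2 logical mesh over 4 devices (illustrative values, assumed DeviceMesh signature).
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# Source spec: shard tensor dim 0 on mesh axis 0 and dim 1 on mesh axis 1 (S0,S1,R).
entire_shape = torch.Size((64, 32, 16))
source_spec = ShardingSpec(device_mesh, entire_shape, {0: [0], 1: [1]})

manager = ShapeConsistencyManager()

# The per-step enumerators now take and accumulate a cost dict
# instead of a single float, as shown in the updated docstrings.
orig_cost_dict = {'forward': 0, 'backward': 0, 'total': 0}
all_gather_specs = manager.get_all_all_gather_spec(source_spec, orig_cost_dict)
print(all_gather_specs)

# shape_consistency now advertises its return type: the transform path of
# ShardingSpecs, the matching CommSpecs, and the total transformation cost.
target_spec = ShardingSpec(device_mesh, entire_shape, {0: [0, 1]})  # S01,R,R
transform_path, comm_action_sequence, total_cost = manager.shape_consistency(source_spec, target_spec)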