[NFC] polish type hint for shape consistency (#1801)

* [NFC] polish type hint for shape consistency

* polish code

* polish code
Jiarui Fang 2 years ago committed by GitHub
parent c248800359
commit 218c75fd9d

@@ -1,17 +1,12 @@
import math
import operator
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from functools import reduce
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from .comm_spec import *
@@ -28,7 +23,7 @@ class ShapeConsistencyOptions:
pass
-def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
+def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
shape_consistency_manager = ShapeConsistencyManager()
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
with torch.no_grad():
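The hunk above only adds the `-> torch.Tensor` return annotation, but it helps to spell out what `to_global` produces: a fully replicated tensor, described by the empty `dim_partition_dict` (`{}`) passed to `ShardingSpec`. A minimal single-process analogy using plain PyTorch rather than the ColossalAI API:

```python
# Single-process analogy of to_global() (illustrative only): a tensor sharded along
# dim 0 across 4 "ranks" is reassembled into one global tensor, which is what the
# all-replicated sharding spec ({} dim_partition_dict) describes.
import torch

shards = [torch.full((2, 4), float(rank)) for rank in range(4)]   # one shard per rank
global_tensor = torch.cat(shards, dim=0)                          # stands in for an all-gather on dim 0
print(global_tensor.shape)                                        # torch.Size([8, 4])
```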
@@ -72,7 +67,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
assert isinstance(value, bool)
self._forward_only = value
-def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
+def get_all_all_gather_spec(self, source_spec: ShardingSpec,
+                            orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
Get all valid sharding specs from source_spec with a single all-gather operation, and
accumulate the communication cost onto the original cost, which will finally be used in the auto sharding solver.
@@ -80,7 +76,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
Argument:
source_spec(ShardingSpec): the ShardingSpec of the source tensor.
-orig_cost(float): the original communication cost before this operation.
+orig_cost(Dict[str, float]): the original communication cost before this operation.
Return:
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
@@ -92,7 +88,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
-rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
+rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
print(rst_dict)
Output:
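The three hunks above annotate `get_all_all_gather_spec` and document its cost argument as a `{'forward', 'backward', 'total'}` dict rather than a scalar. A minimal sketch of the enumeration the method performs, on plain dicts instead of `ShardingSpec` objects; the helper name and the toy cost model are assumptions, not ColossalAI API:

```python
# Hedged sketch: layouts reachable by one all-gather, i.e. by fully releasing the
# mesh axes that shard one tensor dimension, with the cost carried in the
# {'forward', 'backward', 'total'} dict introduced above. Not the library code.
from typing import Dict, List

DimPartition = Dict[int, List[int]]     # tensor dim -> mesh axes sharding it

def all_gather_candidates(partition: DimPartition,
                          orig_cost: Dict[str, float],
                          step_cost: float) -> Dict[tuple, Dict[str, float]]:
    candidates = {}
    for dim in partition:
        released = {d: axes for d, axes in partition.items() if d != dim}
        key = tuple(sorted((d, tuple(a)) for d, a in released.items()))
        cost = dict(orig_cost)
        cost['forward'] += step_cost                       # toy cost model, not the real one
        cost['total'] = cost['forward'] + cost['backward']
        candidates[key] = cost
    return candidates

# From an S0S1-style partition {0: [0], 1: [1]}, either dim can be all-gathered.
print(all_gather_candidates({0: [0], 1: [1]},
                            {'forward': 0, 'backward': 0, 'total': 0},
                            step_cost=1.0))
```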
@@ -143,7 +139,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
pass
return valid_spec_dict
-def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
+def get_all_all_to_all_spec(self, source_spec: ShardingSpec,
+                            orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
Get all valid sharding specs from source_spec with a single all-to-all operation, and
accumulate the communication cost onto the original cost, which will finally be used in the auto sharding solver.
@@ -151,7 +148,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
Argument:
source_spec(ShardingSpec): the ShardingSpec of the source tensor.
-orig_cost(float): the original communication cost before this operation.
+orig_cost(Dict[str, float]): the original communication cost before this operation.
Return:
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
@@ -163,7 +160,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
-rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
+rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
print(rst_dict)
Output:
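The `get_all_all_to_all_spec` hunks above change only the annotations and the documented cost argument. For intuition, a single all-to-all moves one mesh axis from sharding one tensor dimension to sharding another (e.g. S0R -> RS0); a minimal sketch on plain dicts, with an illustrative helper name:

```python
# Hedged sketch of the layout move behind one all-to-all: a mesh axis stops
# sharding tensor dim `src` and starts sharding tensor dim `dst`. Illustrative only.
from typing import Dict, List

def move_mesh_axis(partition: Dict[int, List[int]], src: int, dst: int) -> Dict[int, List[int]]:
    new_partition = {d: list(axes) for d, axes in partition.items()}
    mesh_axis = new_partition[src].pop()        # release the last mesh axis on `src`
    if not new_partition[src]:
        del new_partition[src]
    new_partition.setdefault(dst, []).append(mesh_axis)
    return new_partition

# S0R -> RS0 on a 2-D tensor: {0: [0]} becomes {1: [0]}
print(move_mesh_axis({0: [0]}, src=0, dst=1))   # {1: [0]}
```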
@@ -250,7 +247,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
return valid_spec_dict
-def get_all_shard_spec(self, source_spec, orig_cost_dict):
+def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
'''
Get all valid sharding specs from source_spec with a single shard operation, and
accumulate the communication cost onto the original cost, which will finally be used in the auto sharding solver.
@@ -270,7 +267,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
-rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
+rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
print(rst_dict)
Output:
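The `get_all_shard_spec` hunks follow the same pattern (only the source-spec annotation and the example's cost argument change here). A shard step is the opposite direction: attach a still-unused mesh axis to one tensor dimension. A minimal sketch with illustrative names:

```python
# Hedged sketch: all layouts reachable by one shard step, i.e. by attaching a mesh
# axis that is not yet in use to some tensor dimension (e.g. RR -> S1R). Illustrative only.
from typing import Dict, List

def shard_candidates(partition: Dict[int, List[int]],
                     num_tensor_dims: int,
                     num_mesh_axes: int) -> List[Dict[int, List[int]]]:
    used = {axis for axes in partition.values() for axis in axes}
    results = []
    for mesh_axis in range(num_mesh_axes):
        if mesh_axis in used:
            continue
        for dim in range(num_tensor_dims):
            new_partition = {d: list(a) for d, a in partition.items()}
            new_partition.setdefault(dim, []).append(mesh_axis)
            results.append(new_partition)
    return results

# Starting from {0: [0]} on a 2-D tensor with a 2-D device mesh, mesh axis 1 can
# shard either dim: [{0: [0, 1]}, {0: [0], 1: [1]}]
print(shard_candidates({0: [0]}, num_tensor_dims=2, num_mesh_axes=2))
```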
@@ -331,7 +328,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
pass
return valid_spec_dict
-def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
+def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
'''
Get all valid sharding specs from source_spec with a one-step transform, and
accumulate the communication cost onto the original cost, which will finally be used in the auto sharding solver.
@@ -353,7 +350,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
return valid_spec_dict
-def shape_consistency(self, source_spec, target_spec):
+def shape_consistency(self, source_spec: ShardingSpec,
+                      target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
'''
This method will find a path to transform source_spec to target_spec with
a greedy algorithm.
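The new `shape_consistency` signature spells out the return value: the chain of intermediate `ShardingSpec`s, the matching `CommSpec` actions, and the total cost. A minimal sketch of such a greedy loop on a toy layout graph; the neighbor function stands in for `get_all_one_step_transform_spec`, and picking the cheapest step is a simplification of the library's own selection rule:

```python
# Hedged sketch of a greedy transform search: repeatedly take the cheapest one-step
# transform until the target layout is reached, mirroring the
# (transform_path, comm_action_sequence, total_cost) idea in the new return hint.
from typing import Callable, Dict, Hashable, List, Tuple

MAX_TRANSFORM_STEPS = 20

def greedy_path(source: Hashable,
                target: Hashable,
                neighbors: Callable[[Hashable], Dict[Hashable, float]]) -> Tuple[List[Hashable], float]:
    path, total_cost, current = [source], 0.0, source
    for _ in range(MAX_TRANSFORM_STEPS):
        if current == target:
            return path, total_cost
        options = neighbors(current)                      # one-step candidates with step costs
        current, step_cost = min(options.items(), key=lambda kv: kv[1])
        path.append(current)
        total_cost += step_cost
    raise RuntimeError(f"Could not find a valid transform path within {MAX_TRANSFORM_STEPS} steps.")

# Toy layout graph: 'S01' -> 'S0' -> 'R' with unit cost per released mesh axis.
toy = {'S01': {'S0': 1.0, 'S1': 1.0}, 'S0': {'R': 1.0}, 'S1': {'R': 1.0}, 'R': {}}
print(greedy_path('S01', 'R', lambda layout: toy[layout]))   # (['S01', 'S0', 'R'], 2.0)
```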
@@ -459,7 +457,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
raise RuntimeError(f"Could not find a valid transform path within {MAX_TRANSFORM_STEPS} steps.")
-def apply(self, tensor_with_sharding_spec, target_spec):
+def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
'''
Apply target_spec to a tensor that carries a source sharding spec; the transform path is generated by the
shape_consistency method.
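`apply` now documents that it takes and returns a `torch.Tensor`. Conceptually it replays the communication actions chosen by `shape_consistency` on the tensor; a single-process sketch where each step is a plain callable standing in for a `CommSpec` (illustrative, not the library's execution path):

```python
# Hedged single-process analogy of apply(): walk a transform path and run one
# communication step per hop. Real steps are CommSpec collectives; here each step
# is an ordinary callable for illustration.
import torch
from typing import Callable, List

def apply_transform_path(tensor: torch.Tensor,
                         steps: List[Callable[[torch.Tensor], torch.Tensor]]) -> torch.Tensor:
    for step in steps:          # each step stands in for one CommSpec along the path
        tensor = step(tensor)
    return tensor

# Two "all-gather"-like steps that each double the tensor along dim 0.
shard = torch.ones(2, 4)
path = [lambda t: torch.cat([t, t], dim=0)] * 2
print(apply_transform_path(shard, path).shape)   # torch.Size([8, 4])
```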
