mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish type hint for shape consistency (#1801)
* [NFC] polish type hint for shape consistency * polish code * polish codepull/1803/head
parent
c248800359
commit
218c75fd9d
|
@ -1,17 +1,12 @@
|
|||
import math
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
|
||||
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
|
||||
|
||||
from .comm_spec import *
|
||||
|
@ -28,7 +23,7 @@ class ShapeConsistencyOptions:
|
|||
pass
|
||||
|
||||
|
||||
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
|
||||
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
|
||||
with torch.no_grad():
|
||||
|
@ -72,7 +67,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
assert isinstance(value, bool)
|
||||
self._forward_only = value
|
||||
|
||||
def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
|
||||
def get_all_all_gather_spec(self, source_spec: ShardingSpec,
|
||||
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single all-gather operation, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
|
@ -80,7 +76,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
orig_cost(Dict[str, float]): the original communication cost before this operation.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
|
||||
|
@ -92,7 +88,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
|
||||
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
|
||||
print(rst_dict)
|
||||
|
||||
Output:
|
||||
|
@ -143,7 +139,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
pass
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
|
||||
def get_all_all_to_all_spec(self, source_spec: ShardingSpec,
|
||||
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single all-to-all operation, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
|
@ -151,7 +148,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
orig_cost(Dict[str, float]): the original communication cost before this operation.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
|
||||
|
@ -163,7 +160,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
|
||||
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
|
||||
print(rst_dict)
|
||||
|
||||
Output:
|
||||
|
@ -250,7 +247,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_shard_spec(self, source_spec, orig_cost_dict):
|
||||
def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single shard operation, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
|
@ -270,7 +267,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
|
||||
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
|
||||
print(rst_dict)
|
||||
|
||||
Output:
|
||||
|
@ -331,7 +328,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
pass
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
|
||||
def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with one step transform, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
|
@ -353,7 +350,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
|
||||
return valid_spec_dict
|
||||
|
||||
def shape_consistency(self, source_spec, target_spec):
|
||||
def shape_consistency(self, source_spec: ShardingSpec,
|
||||
target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
|
||||
'''
|
||||
This method will find a path to transform source_spec to target_spec with
|
||||
a greedy algorithm.
|
||||
|
@ -459,7 +457,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
|
||||
raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
|
||||
|
||||
def apply(self, tensor_with_sharding_spec, target_spec):
|
||||
def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
|
||||
'''
|
||||
Apply target_spec to tensor with source sharding spec, the transform path is generated by the
|
||||
shape_consistency method.
|
||||
|
|
Loading…
Reference in New Issue