
[NFC] polish type hint for shape consistency (#1801)

* [NFC] polish type hint for shape consistency

* polish code

* polish code
Jiarui Fang authored 2 years ago · committed by GitHub · commit 218c75fd9d

colossalai/tensor/shape_consistency.py (36 changed lines)

@@ -1,17 +1,12 @@
 import math
-import operator
 from copy import deepcopy
 from dataclasses import dataclass
-from enum import Enum
-from functools import reduce
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple
 
 import torch
-import torch.distributed as dist
-from torch.distributed import ReduceOp
 
 from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 
 from .comm_spec import *
@@ -28,7 +23,7 @@ class ShapeConsistencyOptions:
     pass
 
 
-def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
+def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
     shape_consistency_manager = ShapeConsistencyManager()
     global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
     with torch.no_grad():
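For context on the new `-> torch.Tensor` annotation: to_global takes a sharded tensor plus its ShardingSpec and returns the assembled global tensor (internally it targets a spec with an empty dim_partition_dict, i.e. no dimension partitioned). A minimal usage sketch; `device_mesh`, `global_shape`, `dist_tensor`, and the dim_partition_dict value are placeholders assumed for illustration, not taken from this diff:

# Hedged sketch; placeholder names are not part of the commit.
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import to_global

# dim_partition_dict maps tensor dims to device-mesh axes; {0: [0]} assumes dim 0 is sharded along mesh axis 0.
src_spec = ShardingSpec(device_mesh, global_shape, {0: [0]})
full_tensor = to_global(dist_tensor, src_spec)    # annotated return type: torch.Tensor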
@@ -72,7 +67,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         assert isinstance(value, bool)
         self._forward_only = value
 
-    def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
+    def get_all_all_gather_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with single all-gather operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -80,7 +76,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
-            orig_cost(float): the original communication cost before this operation.
+            orig_cost(Dict[str, float]): the original communication cost before this operation.
 
         Return:
             valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
@@ -92,7 +88,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # device_mesh_shape: (4, 4)
             sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
             shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
             print(rst_dict)
 
         Output:
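As the updated docstring example shows, orig_cost_dict is now a Dict[str, float] rather than a bare float. A minimal sketch of the new calling convention; `sharding_spec` here is a placeholder ShardingSpec, not constructed in this snippet:

# The zero-initialised cost dict used throughout the updated examples.
orig_cost_dict = {'forward': 0, 'backward': 0, 'total': 0}

manager = ShapeConsistencyManager()
# Returns Dict[ShardingSpec, float]: every spec reachable from `sharding_spec`
# with a single all-gather, mapped to its accumulated communication cost.
rst_dict = manager.get_all_all_gather_spec(sharding_spec, orig_cost_dict)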
@@ -143,7 +139,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 pass
         return valid_spec_dict
 
-    def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
+    def get_all_all_to_all_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with single all-to-all operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -151,7 +148,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
-            orig_cost(float): the original communication cost before this operation.
+            orig_cost(Dict[str, float]): the original communication cost before this operation.
 
         Return:
             valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
@@ -163,7 +160,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # device_mesh_shape: (4, 4)
             sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
             shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
             print(rst_dict)
 
         Output:
@@ -250,7 +247,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         return valid_spec_dict
 
-    def get_all_shard_spec(self, source_spec, orig_cost_dict):
+    def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
         '''
         Get all valid sharding specs from source_spec with single shard operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -270,7 +267,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # device_mesh_shape: (4, 4)
             sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
             shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
             print(rst_dict)
 
         Output:
@@ -331,7 +328,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 pass
         return valid_spec_dict
 
-    def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
+    def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with one step transform, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -353,7 +350,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
         return valid_spec_dict
 
-    def shape_consistency(self, source_spec, target_spec):
+    def shape_consistency(self, source_spec: ShardingSpec,
+                          target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
         '''
         This method will find a path to transform source_spec to target_spec with
         a greedy algorithm.
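The new return annotation makes the three-part result of the greedy search explicit. A hedged sketch of how a caller might unpack it; the variable names below are illustrative and not taken from the diff:

manager = ShapeConsistencyManager()
# Tuple[List[ShardingSpec], List[CommSpec], float] per the new annotation:
#   - the sequence of intermediate sharding specs along the greedy path,
#   - the communication spec to execute at each step,
#   - the accumulated communication cost of the whole path.
transform_path, comm_action_sequence, total_cost = manager.shape_consistency(source_spec, target_spec)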
@@ -459,7 +457,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
 
-    def apply(self, tensor_with_sharding_spec, target_spec):
+    def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
         '''
         Apply target_spec to tensor with source sharding spec, the transform path is generated by the
         shape_consistency method.
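Similarly, apply is now annotated as tensor in, tensor out. A hedged sketch, assuming the input tensor carries a sharding_spec attribute describing its current layout (implied by the parameter name but not shown in this hunk):

manager = ShapeConsistencyManager()
# Redistributes the tensor along the transform path found by shape_consistency()
# so that its layout matches target_spec; the result is a plain torch.Tensor.
new_tensor = manager.apply(tensor_with_sharding_spec, target_spec)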
