mirror of https://github.com/hpcaitech/ColossalAI
[Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once. This change considerably speeds up the construction of DimSpec objects. The difference_dict is the same for every DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
pull/5902/head
parent
c068ef0fa0
commit
45c49dde96
|
@ -1,4 +1,3 @@
|
|||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
|
||||
from ..utils import merge_same_dim_mesh_list
|
||||
|
@ -23,10 +22,11 @@ class DimSpec:
|
|||
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
||||
"""
|
||||
|
||||
_DIFFERENCE_DICT = None
|
||||
|
||||
def __init__(self, shard_list):
|
||||
self.is_replica = len(shard_list) == 0
|
||||
self.shard_list = shard_list
|
||||
self.build_difference_2d_dict()
|
||||
|
||||
def __eq__(self, other):
|
||||
return str(self) == str(other)
|
||||
|
@ -39,24 +39,43 @@ class DimSpec:
|
|||
target += str(dim)
|
||||
return target
|
||||
|
||||
def _convert_str_to_shard_list(self, str_spec):
|
||||
@property
|
||||
def difference_dict(self):
|
||||
"""
|
||||
Convert str_spec into shard_list.
|
||||
Returns the difference dict, and lazily initializes it when needed
|
||||
|
||||
Return:
|
||||
difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
|
||||
difference dict
|
||||
"""
|
||||
if self._DIFFERENCE_DICT is None:
|
||||
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
|
||||
|
||||
return self._DIFFERENCE_DICT
|
||||
|
||||
def dim_diff(self, other):
|
||||
"""
|
||||
The difference between two DimSpec.
|
||||
|
||||
Argument:
|
||||
str_spec(str): dim spec in str type.
|
||||
other(DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two DimSpec.
|
||||
|
||||
Example:
|
||||
dim_spec = DimSpec([0])
|
||||
other_dim_spec = DimSpec([0, 1])
|
||||
print(dim_spec.dim_diff(other_dim_spec))
|
||||
|
||||
Output:
|
||||
5
|
||||
"""
|
||||
difference = self.difference_dict[(str(self), str(other))]
|
||||
return difference
|
||||
|
||||
if str_spec == "R":
|
||||
return []
|
||||
if str_spec == "S0":
|
||||
return [0]
|
||||
if str_spec == "S1":
|
||||
return [1]
|
||||
if str_spec == "S01":
|
||||
return [0, 1]
|
||||
|
||||
def build_difference_2d_dict(self):
|
||||
@classmethod
|
||||
def _build_difference_2d_dict(cls):
|
||||
"""
|
||||
Build a difference mapping for 2D device mesh case. It will be used to
|
||||
compute the difference between DimSpec pairs.
|
||||
|
@ -67,9 +86,8 @@ class DimSpec:
|
|||
difference_dict = {}
|
||||
for source_spec in source_spec_list:
|
||||
for target_spec in target_spec_list:
|
||||
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
|
||||
source_shard_list = self._convert_str_to_shard_list(source_spec)
|
||||
target_shard_list = self._convert_str_to_shard_list(target_spec)
|
||||
source_shard_list = cls._convert_str_to_shard_list(source_spec)
|
||||
target_shard_list = cls._convert_str_to_shard_list(target_spec)
|
||||
|
||||
# source same as target
|
||||
if source_shard_list == target_shard_list:
|
||||
|
@ -112,30 +130,27 @@ class DimSpec:
|
|||
|
||||
else:
|
||||
difference = NAN
|
||||
difference_dict[spec_pair] = difference
|
||||
difference_dict[(source_spec, target_spec)] = difference
|
||||
|
||||
self.difference_dict = difference_dict
|
||||
return difference_dict
|
||||
|
||||
def dim_diff(self, other):
|
||||
@staticmethod
|
||||
def _convert_str_to_shard_list(str_spec):
|
||||
"""
|
||||
The difference between two _DimSpec.
|
||||
Convert str_spec into shard_list.
|
||||
|
||||
Argument:
|
||||
other(_DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two _DimSpec.
|
||||
|
||||
Example:
|
||||
dim_spec = _DimSpec([0])
|
||||
other_dim_spec = _DimSpec([0, 1])
|
||||
print(dim_spec.difference(other_dim_spec))
|
||||
|
||||
Output:
|
||||
5
|
||||
str_spec(str): dim spec in str type.
|
||||
"""
|
||||
difference = self.difference_dict[(str(self), str(other))]
|
||||
return difference
|
||||
|
||||
if str_spec == "R":
|
||||
return []
|
||||
if str_spec == "S0":
|
||||
return [0]
|
||||
if str_spec == "S1":
|
||||
return [1]
|
||||
if str_spec == "S01":
|
||||
return [0, 1]
|
||||
|
||||
|
||||
class ShardingSpec:
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import operator
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
|
@ -27,10 +26,11 @@ class _DimSpec:
|
|||
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
||||
"""
|
||||
|
||||
_DIFFERENCE_DICT = None
|
||||
|
||||
def __init__(self, shard_list):
|
||||
self.is_replica = len(shard_list) == 0
|
||||
self.shard_list = shard_list
|
||||
self.build_difference_2d_dict()
|
||||
|
||||
def __eq__(self, other):
|
||||
return str(self) == str(other)
|
||||
|
@ -43,27 +43,46 @@ class _DimSpec:
|
|||
target += str(dim)
|
||||
return target
|
||||
|
||||
def _convert_str_to_shard_list(self, str_spec):
|
||||
@property
|
||||
def difference_dict(self):
|
||||
"""
|
||||
Convert str_spec into shard_list.
|
||||
Returns the difference dict, and lazily initializes it when needed
|
||||
|
||||
Return:
|
||||
difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
|
||||
difference dict
|
||||
"""
|
||||
if self._DIFFERENCE_DICT is None:
|
||||
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
|
||||
|
||||
return self._DIFFERENCE_DICT
|
||||
|
||||
def difference(self, other):
|
||||
"""
|
||||
The difference between two _DimSpec.
|
||||
|
||||
Argument:
|
||||
str_spec(str): dim spec in str type.
|
||||
other(_DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two _DimSpec.
|
||||
|
||||
Example:
|
||||
dim_spec = _DimSpec([0])
|
||||
other_dim_spec = _DimSpec([0, 1])
|
||||
print(dim_spec.difference(other_dim_spec))
|
||||
|
||||
Output:
|
||||
5
|
||||
"""
|
||||
difference = self.difference_dict[(str(self), str(other))]
|
||||
return difference
|
||||
|
||||
if str_spec == "R":
|
||||
return []
|
||||
if str_spec == "S0":
|
||||
return [0]
|
||||
if str_spec == "S1":
|
||||
return [1]
|
||||
if str_spec == "S01":
|
||||
return [0, 1]
|
||||
|
||||
def build_difference_2d_dict(self):
|
||||
@classmethod
|
||||
def _build_difference_2d_dict(cls):
|
||||
"""
|
||||
Build a difference mapping for 2D device mesh case. It will be used to
|
||||
compute the difference between DimSpec pairs.
|
||||
compute the difference between _DimSpec pairs.
|
||||
"""
|
||||
|
||||
source_spec_list = ["R", "S0", "S1", "S01"]
|
||||
|
@ -71,9 +90,8 @@ class _DimSpec:
|
|||
difference_dict = {}
|
||||
for source_spec in source_spec_list:
|
||||
for target_spec in target_spec_list:
|
||||
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
|
||||
source_shard_list = self._convert_str_to_shard_list(source_spec)
|
||||
target_shard_list = self._convert_str_to_shard_list(target_spec)
|
||||
source_shard_list = cls._convert_str_to_shard_list(source_spec)
|
||||
target_shard_list = cls._convert_str_to_shard_list(target_spec)
|
||||
|
||||
# source same as target
|
||||
if source_shard_list == target_shard_list:
|
||||
|
@ -116,30 +134,27 @@ class _DimSpec:
|
|||
|
||||
else:
|
||||
difference = NAN
|
||||
difference_dict[spec_pair] = difference
|
||||
difference_dict[(source_spec, target_spec)] = difference
|
||||
|
||||
self.difference_dict = difference_dict
|
||||
return difference_dict
|
||||
|
||||
def difference(self, other):
|
||||
@staticmethod
|
||||
def _convert_str_to_shard_list(str_spec):
|
||||
"""
|
||||
The difference between two _DimSpec.
|
||||
Convert str_spec into shard_list.
|
||||
|
||||
Argument:
|
||||
other(_DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two _DimSpec.
|
||||
|
||||
Example:
|
||||
dim_spec = _DimSpec([0])
|
||||
other_dim_spec = _DimSpec([0, 1])
|
||||
print(dim_spec.difference(other_dim_spec))
|
||||
|
||||
Output:
|
||||
5
|
||||
str_spec(str): dim spec in str type.
|
||||
"""
|
||||
difference = self.difference_dict[(str(self), str(other))]
|
||||
return difference
|
||||
|
||||
if str_spec == "R":
|
||||
return []
|
||||
if str_spec == "S0":
|
||||
return [0]
|
||||
if str_spec == "S1":
|
||||
return [1]
|
||||
if str_spec == "S01":
|
||||
return [0, 1]
|
||||
|
||||
|
||||
class ShardingSpecException(Exception):
|
||||
|
|
Loading…
Reference in New Issue