[Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method
pull/5902/head
Stephan Kö 2024-07-15 12:05:06 +08:00 committed by GitHub
parent c068ef0fa0
commit 45c49dde96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 103 additions and 73 deletions

View File

@ -1,4 +1,3 @@
from copy import deepcopy
from typing import Dict, List
from ..utils import merge_same_dim_mesh_list
@ -23,10 +22,11 @@ class DimSpec:
Otherwise, the element in shard_list means the data will be sharded in that dimension.
"""
_DIFFERENCE_DICT = None
def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
self.shard_list = shard_list
self.build_difference_2d_dict()
def __eq__(self, other):
return str(self) == str(other)
@ -39,24 +39,43 @@ class DimSpec:
target += str(dim)
return target
def _convert_str_to_shard_list(self, str_spec):
@property
def difference_dict(self):
"""
Convert str_spec into shard_list.
Returns the difference dict, and lazily initializes it when needed
Return:
difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
difference dict
"""
if self._DIFFERENCE_DICT is None:
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
return self._DIFFERENCE_DICT
def dim_diff(self, other):
"""
The difference between two DimSpec.
Argument:
str_spec(str): dim spec in str type.
other(DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two DimSpec.
Example:
dim_spec = DimSpec([0])
other_dim_spec = DimSpec([0, 1])
print(dim_spec.dim_diff(other_dim_spec))
Output:
5
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
def build_difference_2d_dict(self):
@classmethod
def _build_difference_2d_dict(cls):
"""
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
@ -67,9 +86,8 @@ class DimSpec:
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
source_shard_list = self._convert_str_to_shard_list(source_spec)
target_shard_list = self._convert_str_to_shard_list(target_spec)
source_shard_list = cls._convert_str_to_shard_list(source_spec)
target_shard_list = cls._convert_str_to_shard_list(target_spec)
# source same as target
if source_shard_list == target_shard_list:
@ -112,30 +130,27 @@ class DimSpec:
else:
difference = NAN
difference_dict[spec_pair] = difference
difference_dict[(source_spec, target_spec)] = difference
self.difference_dict = difference_dict
return difference_dict
def dim_diff(self, other):
@staticmethod
def _convert_str_to_shard_list(str_spec):
"""
The difference between two _DimSpec.
Convert str_spec into shard_list.
Argument:
other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
str_spec(str): dim spec in str type.
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
class ShardingSpec:

View File

@ -1,5 +1,4 @@
import operator
from copy import deepcopy
from functools import reduce
import torch
@ -27,10 +26,11 @@ class _DimSpec:
Otherwise, the element in shard_list means the data will be sharded in that dimension.
"""
_DIFFERENCE_DICT = None
def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
self.shard_list = shard_list
self.build_difference_2d_dict()
def __eq__(self, other):
return str(self) == str(other)
@ -43,27 +43,46 @@ class _DimSpec:
target += str(dim)
return target
def _convert_str_to_shard_list(self, str_spec):
@property
def difference_dict(self):
"""
Convert str_spec into shard_list.
Returns the difference dict, and lazily initializes it when needed
Return:
difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
difference dict
"""
if self._DIFFERENCE_DICT is None:
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
return self._DIFFERENCE_DICT
def difference(self, other):
"""
The difference between two _DimSpec.
Argument:
str_spec(str): dim spec in str type.
other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
def build_difference_2d_dict(self):
@classmethod
def _build_difference_2d_dict(cls):
"""
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
compute the difference between _DimSpec pairs.
"""
source_spec_list = ["R", "S0", "S1", "S01"]
@ -71,9 +90,8 @@ class _DimSpec:
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
source_shard_list = self._convert_str_to_shard_list(source_spec)
target_shard_list = self._convert_str_to_shard_list(target_spec)
source_shard_list = cls._convert_str_to_shard_list(source_spec)
target_shard_list = cls._convert_str_to_shard_list(target_spec)
# source same as target
if source_shard_list == target_shard_list:
@ -116,30 +134,27 @@ class _DimSpec:
else:
difference = NAN
difference_dict[spec_pair] = difference
difference_dict[(source_spec, target_spec)] = difference
self.difference_dict = difference_dict
return difference_dict
def difference(self, other):
@staticmethod
def _convert_str_to_shard_list(str_spec):
"""
The difference between two _DimSpec.
Convert str_spec into shard_list.
Argument:
other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
str_spec(str): dim spec in str type.
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
class ShardingSpecException(Exception):