[autoparallel] refactored shape consistency to remove redundancy (#1591)

* [autoparallel] refactored shape consistency to remove redundancy

* polish code

* polish code

* polish code
Frank Lee 2022-09-13 18:30:18 +08:00 committed by GitHub
parent d164449d00
commit 27fe8af60c
13 changed files with 220 additions and 234 deletions

colossalai/auto_parallel/solver/_utils.py

@@ -1,8 +1,9 @@
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 import torch
 from torch.fx.node import Node
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
-from typing import Union, Dict, List
+from typing import Union, Dict, List, Optional

 def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,

@@ -31,3 +32,45 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
     sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
     return sharding_spec
+def generate_resharding_costs(nodes: List[Node],
+                              sharding_specs: List[ShardingSpec],
+                              count_backward: Optional[bool] = True,
+                              dtype: Optional[torch.dtype] = None):
+    '''
+    Compute the resharding costs with this specific strategy.
+
+    Argument:
+        nodes (List[Node]): a list of nodes
+        sharding_specs (List[ShardingSpec]): a list of target ShardingSpecs, one for each node.
+        count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
+        dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
+    '''
+    # the resharding cost of the weight is counted as well, to cover shared-weight cases
+    resharding_costs = {}
+    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+    # shape consistency manager is a singleton class
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    for input_node, input_spec in zip(nodes, sharding_specs):
+        resharding_costs[input_node] = []
+        for strategy in input_node.strategies_vector:
+            input_sharding_spec = strategy.output_sharding_spec
+            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
+            # compute the resharding cost during forward phase
+            _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec)
+
+            if count_backward:
+                # in the backward phase, the gradient with the target spec must be converted back into input_sharding_spec
+                _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
+                    input_spec, input_sharding_spec)
+                total_resharding_cost = resharding_cost_forward + resharding_cost_backward
+            else:
+                total_resharding_cost = resharding_cost_forward
+
+            # multiply by the element size of the dtype to get the communication cost in bytes
+            resharding_cost = total_resharding_cost * size_per_elem_bytes
+            resharding_costs[input_node].append(resharding_cost)
+    return resharding_costs

colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py

@@ -4,7 +4,6 @@ import warnings
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
-from .._utils import generate_sharding_spec

 __all__ = ['BatchNormHandler']

@@ -115,15 +114,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -156,8 +153,7 @@ class BatchNormHandler(OperatorHandler):
             new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

             dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
-            new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                                  dim_partition_dict_for_output)
+            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

             # the computation cost is all the same
             new_compute_cost = compute_cost

@@ -192,15 +188,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -234,15 +228,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'RR = RR x R'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -273,8 +265,7 @@ class BatchNormHandler(OperatorHandler):
         def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
             dim_partition_dict_for_output = {0: mesh_dim_list}
-            new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                                  dim_partition_dict_for_output)
+            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

             # the computation cost is all the same
             new_compute_cost = compute_cost

@@ -332,15 +323,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -374,15 +363,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -416,15 +403,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -459,7 +444,7 @@ class BatchNormHandler(OperatorHandler):
         Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.

         Example:
-            norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
+            norm_handler = BatchNormHandler(node, strategies_vector,
                                             self.shape_consistency_manager)
             norm_handler.register_strategy()
             for strategy in norm_handler.strategies_vector:
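The recurring change in this file replaces module-level generate_sharding_spec calls with the new _generate_sharding_spec method, which supplies the handler's own device mesh. A minimal sketch of that delegation pattern, with hypothetical stand-in names (MeshStub, HandlerSketch):

from typing import Dict, List

class MeshStub:
    # stand-in for colossalai's DeviceMesh, illustration only
    def __init__(self, shape):
        self.shape = shape

def generate_sharding_spec(input_shape, device_mesh, dim_partition_dict):
    # stand-in for the real module-level helper: just bundle the arguments
    return (tuple(input_shape), device_mesh.shape, dict(dim_partition_dict))

class HandlerSketch:
    def __init__(self, device_mesh: MeshStub):
        self.device_mesh = device_mesh

    def _generate_sharding_spec(self, input_shape, dim_partition_dict: Dict[int, List[int]]):
        # the wrapper pins self.device_mesh, so every call site above shrinks
        # from three arguments (often a wrapped line) to two
        return generate_sharding_spec(input_shape, self.device_mesh, dim_partition_dict)

handler = HandlerSketch(MeshStub(shape=(2, 2)))
print(handler._generate_sharding_spec((4, 16, 64, 64), {1: [0]}))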

colossalai/auto_parallel/solver/op_handler/conv_handler.py

@@ -4,7 +4,6 @@ import warnings
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
-from .._utils import generate_sharding_spec

 __all__ = ['ConvHandler']

@@ -109,15 +108,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -158,15 +155,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -205,15 +200,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -252,15 +245,13 @@ class ConvHandler(OperatorHandler):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -296,15 +287,13 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -340,15 +329,13 @@ class ConvHandler(OperatorHandler):
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -384,15 +371,13 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RR x RR'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -426,15 +411,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -475,15 +458,13 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
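For readers new to the strategy names: S{i} marks a tensor dimension sharded along mesh dimension i, and R marks a replicated dimension, which is exactly what the dim_partition_dict arguments encode. A small illustration with a hypothetical shard_shape helper:

def shard_shape(global_shape, dim_partition_dict, mesh_shape):
    # hypothetical helper: the per-device shape implied by a dim_partition_dict
    shape = list(global_shape)
    for tensor_dim, mesh_dims in dim_partition_dict.items():
        for mesh_dim in mesh_dims:
            shape[tensor_dim] //= mesh_shape[mesh_dim]
    return tuple(shape)

# S0S1 on a 2x2 mesh: tensor dim 0 over mesh dim 0, tensor dim 1 over mesh dim 1
print(shard_shape((4, 16, 64, 64), {0: [0], 1: [1]}, (2, 2)))   # (2, 8, 64, 64)
# S01R: tensor dim 0 sharded over both mesh dims, the rest replicated
print(shard_shape((4, 16, 64, 64), {0: [0, 1]}, (2, 2)))        # (1, 16, 64, 64)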

colossalai/auto_parallel/solver/op_handler/dot_handler.py

@@ -3,7 +3,6 @@ import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
 from functools import reduce
-from .._utils import generate_sharding_spec

 __all__ = ['DotHandler']

@@ -29,16 +28,14 @@ class DotHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         # linear layer weight is transposed during init
         dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -69,17 +66,15 @@ class DotHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         # since weight of the linear layer is transposed
         # the actual dim to be sharded is 1
         dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -106,15 +101,13 @@ class DotHandler(OperatorHandler):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -141,15 +134,13 @@ class DotHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

         dim_partition_dict_for_input = {1: [mesh_dim]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -176,15 +167,13 @@ class DotHandler(OperatorHandler):
         name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -211,15 +200,13 @@ class DotHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -246,15 +233,13 @@ class DotHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -281,15 +266,13 @@ class DotHandler(OperatorHandler):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
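The "weight is transposed" comments refer to PyTorch's nn.Linear storing its weight as (out_features, in_features), so the contracted dimension is index 1. A quick check:

import torch
import torch.nn as nn

linear = nn.Linear(in_features=8, out_features=4)
# weight is stored as (out_features, in_features), i.e. already "transposed"
print(linear.weight.shape)                 # torch.Size([4, 8])
x = torch.randn(2, 8)
y = x @ linear.weight.t() + linear.bias    # what forward computes
assert torch.allclose(y, linear(x))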

colossalai/auto_parallel/solver/op_handler/operator_handler.py

@@ -7,6 +7,7 @@ from typing import Dict, List
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
+from .._utils import generate_resharding_costs, generate_sharding_spec
 from ..sharding_strategy import StrategiesVector

@@ -17,24 +18,24 @@ class OperatorHandler(ABC):
     '''
     The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.

-    Argument:
-        input_node(Node): the input node in node argument list.
-        input_index(int): the index of input node in the node argument list.
-        weight(torch.Tensor): Weight of the node.
-        output_node(Node): Output_node is the output of the node.
-        device_mesh(DeviceMesh): A logical view of a physical mesh.
-        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
-        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
+    Args:
+        node (Node): the input node in node argument list.
+        device_mesh (DeviceMesh): A logical view of a physical mesh.
+        strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
+        handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
     '''

-    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
-                 shape_consistency_manager: ShapeConsistencyManager):
+    def __init__(self,
+                 node: Node,
+                 device_mesh: DeviceMesh,
+                 strategies_vector: StrategiesVector,
+                 handle_backward: bool = True):
         self.node = node
         self.predecessor_node = list(node._input_nodes.keys())
         self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
-        self.shape_consistency_manager = shape_consistency_manager
+        self.handle_backward = handle_backward

         # find the module and its parameters associated with this node
         # this can be used to compute the compute/communication/sharding cost

@@ -102,35 +103,23 @@ class OperatorHandler(ABC):
         return total_memory_cost, activation_memory_cost, weight_memory_cost

-    def _generate_resharding_costs(self, sharding_spec_for_input):
-        '''
-        Compute the resharding costs with this specific strategy.
-
-        Note: The resharding_cost of weight is NOT counted.
-
-        Argument:
-            resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
-                Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
-                with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
-                strategy.
-            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
-        '''
+    def _generate_resharding_costs(self, sharding_specs):
         # The resharding_cost of weight is counted due to sharing weight cases.
-        resharding_costs = {}
         dtype = self.node._meta_data.dtype
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
-            resharding_costs[input_node] = []
-            for strategy in input_node.strategies_vector:
-                input_sharding_spec = strategy.output_sharding_spec
-                assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                # compute the resharding cost during forward phase
-                _, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
-                    input_sharding_spec, input_spec)
-                # In backward phase, we should convert grad with target_spec into input_sharding_spec
-                _, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
-                    input_spec, input_sharding_spec)
-                # we need multiply the size of elem dtype to get correct communication cost
-                resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes
-                resharding_costs[input_node].append(resharding_cost)
-        return resharding_costs
+        nodes = self.predecessor_node
+        return generate_resharding_costs(nodes=nodes,
+                                         sharding_specs=sharding_specs,
+                                         count_backward=self.handle_backward,
+                                         dtype=dtype)
+
+    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+        return generate_sharding_spec(input_=input_,
+                                      device_mesh=self.device_mesh,
+                                      dim_partition_dict=dim_partition_dict)
+
+    @abstractmethod
+    def _generate_compute_cost(self, *args, **kwargs):
+        """
+        Compute the flops involved in the node.
+        """
+        pass
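The new handle_backward flag is forwarded to count_backward in generate_resharding_costs, so training counts resharding in both directions while inference only pays for the forward conversion. A toy restatement of what the flag changes (hypothetical helper name):

def total_resharding_cost(forward_cost: float, backward_cost: float, count_backward: bool = True) -> float:
    # mirrors the branch in generate_resharding_costs: training pays for both
    # activation resharding (forward) and gradient resharding (backward)
    return forward_cost + (backward_cost if count_backward else 0.0)

print(total_resharding_cost(3.0, 3.0))         # 6.0 -> training
print(total_resharding_cost(3.0, 3.0, False))  # 3.0 -> inference (handle_backward=False)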

colossalai/auto_parallel/solver/strategies_constructor.py

@@ -11,7 +11,7 @@ import math
 import torch
 import operator
 from typing import Dict, List
-from ._utils import generate_sharding_spec
+from ._utils import generate_sharding_spec, generate_resharding_costs

 class StrategiesConstructor:

@@ -21,12 +21,10 @@ class StrategiesConstructor:
     Args:
         graph (Graph): a Graph object used for analysis and strategy generation.
         device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
-        shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent.
         solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
     """

-    def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager,
-                 solver_options: SolverOptions):
+    def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
         self.graph = graph
         assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
         self.root_module = self.graph.owning_module

@@ -34,27 +32,8 @@ class StrategiesConstructor:
         self.device_mesh = device_mesh
         self.leaf_strategies = []
         self.strategy_map = {}
-        self.shape_consistency_manager = shape_consistency_manager
         self.solver_options = solver_options

-    def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
-        '''
-        Compute the resharding costs with this specific strategy.
-
-        Argument:
-            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
-        '''
-        resharding_costs = {}
-        for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs):
-            resharding_costs[input_node] = []
-            for strategy in input_node.strategies_vector:
-                input_sharding_spec = strategy.output_sharding_spec
-                assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
-                    input_sharding_spec, target_sharding_spec)
-                resharding_costs[input_node].append(resharding_cost)
-        return resharding_costs
-
     def remove_duplicated_strategy(self, strategies_vector):
         '''
         In build_strategies_and_cost method, we may produce some duplicated strategies.

@@ -120,14 +99,13 @@ class StrategiesConstructor:
             # conv module
             if submod_type in CONV_MODULE_OP:
                 # use ConvHandler to create sharding strategies for conv module node
-                conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
-                                           self.shape_consistency_manager)
+                conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
                 conv_handler.register_strategy()

             # linear module
             elif submod_type in LINEAR_MODULE_OP:
                 # use DotHandler to create sharding strategies for linear module node
-                dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager)
+                dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
                 dot_handler.register_strategy()

             # element-wise module

@@ -158,8 +136,8 @@ class StrategiesConstructor:
                 # TODO: use meta_info_prop to profile memory cost and compute cost
                 compute_cost = node._meta_data.numel()
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])

                 # to prevent the resharding happening, set their resharding cost to inf.
                 resharding_costs[input_node] = [

@@ -214,8 +192,8 @@ class StrategiesConstructor:
                 # TODO: use meta_info_prop to profile memory cost and compute cost
                 compute_cost = node._meta_data.numel()
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])

                 sharding_strategy = ShardingStrategy(name,
                                                      output_sharding_spec,

@@ -275,8 +253,8 @@ class StrategiesConstructor:
                 compute_cost = node._meta_data.numel()
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])

                 # to prevent the resharding happening, set their resharding cost to inf.
                 resharding_costs[input_node] = [

@@ -317,8 +295,8 @@ class StrategiesConstructor:
                 # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                 compute_cost = 0
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [new_input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [new_input_sharding_spec])

                 sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
                                                      compute_cost=compute_cost,
                                                      memory_cost=memory_cost,

@@ -335,8 +313,8 @@ class StrategiesConstructor:
                 # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                 compute_cost = 0
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])

                 sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
                                                      compute_cost=compute_cost,
                                                      memory_cost=memory_cost,

@@ -360,8 +338,8 @@ class StrategiesConstructor:
                 # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                 compute_cost = 0
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])

                 # to prevent the resharding happening, set their resharding cost to inf.
                 resharding_costs[input_tensor_node] = [
                     cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]

@@ -397,8 +375,8 @@ class StrategiesConstructor:
                 output_sharding_spec = input_sharding_specs

                 # TODO: use meta_info_prop to profile memory cost
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   input_sharding_specs)
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             input_sharding_specs)

                 # clear the resharding cost for the output node
                 # TODO: we may remove this in final version
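The "set resharding cost to inf" pattern that recurs around these call sites keeps only zero-cost strategies selectable for pinned nodes, so the solver cannot pick a placement that would force a reshard. For example:

import math

# keep only the zero-cost (already correctly sharded) strategies viable;
# any strategy that would trigger a reshard becomes unselectable for the solver
costs = [0, 512.0, 64.0, 0]
pinned = [cost if cost == 0 else math.inf for cost in costs]
print(pinned)   # [0, inf, inf, 0]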

colossalai/context/singleton_meta.py

@@ -15,4 +15,7 @@ class SingletonMeta(type):
         if cls not in cls._instances:
             instance = super().__call__(*args, **kwargs)
             cls._instances[cls] = instance
+        else:
+            assert len(args) == 0 and len(
+                kwargs) == 0, f'{cls.__name__} is a singleton class and an instance has been created.'
         return cls._instances[cls]
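With the assert added above, a singleton class silently returns the existing instance on later no-argument calls, but rejects any attempt to re-construct it with new arguments. A standalone demonstration (toy Manager class):

class SingletonMeta(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        else:
            assert len(args) == 0 and len(kwargs) == 0, \
                f'{cls.__name__} is a singleton class and an instance has been created.'
        return cls._instances[cls]

class Manager(metaclass=SingletonMeta):

    def __init__(self, option=None):
        self.option = option

first = Manager(option='fast')
second = Manager()        # no args: silently returns the existing instance
assert first is second
# Manager(option='slow') would now trip the assert added in this commit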

colossalai/tensor/shape_consistency.py

@@ -1,15 +1,22 @@
 import torch
+from dataclasses import dataclass
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 from enum import Enum
 from copy import deepcopy
 from typing import Dict, List, Optional, Tuple, Union
+from colossalai.context.singleton_meta import SingletonMeta
 import torch.distributed as dist
 import math
 from functools import reduce
 import operator
 from torch.distributed import ReduceOp

+__all__ = [
+    'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions',
+    'set_shape_consistency_options'
+]

 class CollectiveCommPattern(Enum):
     ALLGATHER = 'all_gather'

@@ -152,14 +159,40 @@ class CommSpec:
         tensor.data = tensor

-class ShapeConsistencyManager:
+@dataclass
+class ShapeConsistencyOptions:
+    """
+    ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency.
+    """
+    # TODO: shape consistency option is not implemented yet
+    pass

-    def __init__(self, consistency_option=None):
-        self.consistency_option = consistency_option
+def set_shape_consistency_options(options: ShapeConsistencyOptions):
+    """
+    Configure the shape consistency manager via function call.
+    """
+    manager = ShapeConsistencyManager()
+    manager.options = options

+class ShapeConsistencyManager(metaclass=SingletonMeta):
+
+    def __init__(self):
+        self._options = None
         self.total_communication_cost = 0
         self.total_transform_steps = 0
         self.cached_spec_pairs_transform_path = {}

+    @property
+    def options(self):
+        return self._options
+
+    @options.setter
+    def options(self, options_: ShapeConsistencyOptions):
+        assert isinstance(options_, ShapeConsistencyOptions)
+        self._options = options_
+
     def get_all_all_gather_spec(self, source_spec, orig_cost):
         '''
         Get all valid sharding specs from source_spec with single all-gather operation, and

@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@@ -31,7 +30,6 @@ def test_bn_handler():
    # [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = BNModel(16)
@@ -77,10 +75,11 @@ def test_bn_handler():
    # generate bn strategy
    strategies_vector = StrategiesVector(node=nodes[2])
-    bn_handler = BatchNormHandler(node=nodes[2],
-                                  device_mesh=device_mesh,
-                                  strategies_vector=strategies_vector,
-                                  shape_consistency_manager=shape_consistency_manager)
+    bn_handler = BatchNormHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
    bn_handler.register_strategy()
    # ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
    #  'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']


@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@@ -31,7 +30,6 @@ def test_conv_handler():
    # [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = ConvModel(16, 32)
@@ -77,10 +75,11 @@ def test_conv_handler():
    # generate conv strategy
    strategies_vector = StrategiesVector(node=nodes[2])
-    conv_handler = ConvHandler(node=nodes[2],
-                               device_mesh=device_mesh,
-                               strategies_vector=strategies_vector,
-                               shape_consistency_manager=shape_consistency_manager)
+    conv_handler = ConvHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
    conv_handler.register_strategy()
    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]


@@ -4,10 +4,7 @@ from torch.fx import GraphModule
import torch.nn as nn
import pytest
-from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
@@ -37,7 +34,6 @@ def test_cost_graph():
    # [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = ConvModel(16, 32)
@@ -55,7 +51,7 @@ def test_cost_graph():
    gm.recompile()

    solver_options = SolverOptions(fast=True)
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    # (x, mul):{(0, 0): 0}


@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@@ -31,7 +30,6 @@ def test_dot_handler():
    # [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 8))
-    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = LinearModel(8, 16)
@@ -76,10 +74,11 @@ def test_dot_handler():
    # generate dot strategy
    strategies_vector = StrategiesVector(node=nodes[2])
-    dot_handler = DotHandler(node=nodes[2],
-                             device_mesh=device_mesh,
-                             strategies_vector=strategies_vector,
-                             shape_consistency_manager=shape_consistency_manager)
+    dot_handler = DotHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
    strategies_vector = dot_handler.register_strategy()
    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']


@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.options import SolverOptions
@@ -34,7 +33,6 @@ def test_strategies_constructor():
    # [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = ConvModel(16, 32)
@@ -49,7 +47,7 @@ def test_strategies_constructor():
    gm.recompile()

    solver_options = SolverOptions(fast=True)
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)

    assert strategies_constructor.leaf_strategies == []
    assert strategies_constructor.strategy_map == {}