mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] refactored shape consistency to remove redundancy (#1591)
* [autoparallel] refactored shape consistency to remove redundancy
* polish code
* polish code
* polish code

pull/1587/head^2

parent d164449d00
commit 27fe8af60c
@@ -1,8 +1,9 @@
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 import torch
 from torch.fx.node import Node
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
-from typing import Union, Dict, List
+from typing import Union, Dict, List, Optional


 def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,

@@ -31,3 +32,45 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
     sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
     return sharding_spec
+
+
+def generate_resharding_costs(nodes: List[Node],
+                              sharding_specs: List[ShardingSpec],
+                              count_backward: Optional[bool] = True,
+                              dtype: Optional[torch.dtype] = None):
+    '''
+    Compute the resharding costs with this specific strategy.
+
+    Argument:
+        nodes (List[Node]): a list of nodes
+        sharding_specs (List[ShardingSpec]): a list of ShardingSpec objects, one for each node.
+        count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
+        dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
+    '''
+    # The resharding_cost of weight is counted due to sharing weight cases.
+    resharding_costs = {}
+    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+    # shape consistency manager is a singleton class
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    for input_node, input_spec in zip(nodes, sharding_specs):
+        resharding_costs[input_node] = []
+        for strategy in input_node.strategies_vector:
+            input_sharding_spec = strategy.output_sharding_spec
+            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
+            # compute the resharding cost during the forward phase
+            _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec)
+
+            if count_backward:
+                # in the backward phase, the gradient with the target spec has to be converted back into input_sharding_spec
+                _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
+                    input_spec, input_sharding_spec)
+                total_resharding_cost = resharding_cost_forward + resharding_cost_backward
+            else:
+                total_resharding_cost = resharding_cost_forward
+
+            # multiply by the element size of the dtype to get the correct communication cost
+            resharding_cost = total_resharding_cost * size_per_elem_bytes
+            resharding_costs[input_node].append(resharding_cost)
+    return resharding_costs
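The new generate_resharding_costs helper above centralizes the cost computation that the individual operator handlers and the strategies constructor previously duplicated. A minimal usage sketch follows; the absolute module path and the placeholder names (prev_node, target_spec) are assumptions for illustration, not part of this commit:

    import torch
    from colossalai.auto_parallel.solver._utils import generate_resharding_costs

    # prev_node: a torch.fx Node whose strategies_vector has already been populated
    # target_spec: the ShardingSpec the current strategy wants for that input
    costs = generate_resharding_costs(nodes=[prev_node],
                                      sharding_specs=[target_spec],
                                      count_backward=False,    # forward-only cost, e.g. for inference
                                      dtype=torch.float32)     # element size of the dtype scales the byte cost
    # costs[prev_node][j] is the cost of converting the output spec of prev_node's j-th
    # strategy into target_spec

The sketch assumes a traced graph with populated strategy vectors and only shows the call shape.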
@@ -4,7 +4,6 @@ import warnings
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
-from .._utils import generate_sharding_spec

 __all__ = ['BatchNormHandler']


@@ -115,15 +114,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -156,8 +153,7 @@ class BatchNormHandler(OperatorHandler):
             new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

             dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
-            new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                                  dim_partition_dict_for_output)
+            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
             # the computation cost is all the same
             new_compute_cost = compute_cost

@@ -192,15 +188,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -234,15 +228,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'RR = RR x R'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -273,8 +265,7 @@ class BatchNormHandler(OperatorHandler):

         def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
             dim_partition_dict_for_output = {0: mesh_dim_list}
-            new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                                  dim_partition_dict_for_output)
+            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

             # the computation cost is all the same
             new_compute_cost = compute_cost

@@ -332,15 +323,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -374,15 +363,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -416,15 +403,13 @@ class BatchNormHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -459,7 +444,7 @@ class BatchNormHandler(OperatorHandler):
         Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.

         Example:
-            norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
+            norm_handler = BatchNormHandler(node, strategies_vector,
                                             self.shape_consistency_manager)
             norm_handler.register_strategy()
             for strategy in norm_handler.strategies_vector:
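In each strategy hunk above, the dim_partition_dict is the machine-readable form of the strategy name: a key is a tensor dimension, its value lists the device-mesh dimensions that shard it, and any dimension that does not appear stays replicated. An illustrative mapping for the batch-norm input, with concrete mesh indices 0 and 1 standing in for mesh_dim_0 and mesh_dim_1 (the comments are interpretation; the dictionaries mirror the hunks):

    # 'RS0 = RS0 x S0'    -> batch dim replicated, channel dim sharded over mesh dim 0
    dim_partition_dict_for_input = {1: [0]}

    # 'RS01 = RS01 x S01' -> channel dim sharded over both mesh dims
    dim_partition_dict_for_input = {1: [0, 1]}

    # 'S0R = S0R x R WITH SYNC_BN' -> batch dim sharded over mesh dim 0, channel replicated
    dim_partition_dict_for_input = {0: [0]}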
@@ -4,7 +4,6 @@ import warnings
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
-from .._utils import generate_sharding_spec

 __all__ = ['ConvHandler']


@@ -109,15 +108,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -158,15 +155,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -205,15 +200,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -252,15 +245,13 @@ class ConvHandler(OperatorHandler):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -296,15 +287,13 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -340,15 +329,13 @@ class ConvHandler(OperatorHandler):
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -384,15 +371,13 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RR x RR'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -426,15 +411,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -475,15 +458,13 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
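Unlike the batch-norm handler, every conv strategy above passes both the input spec and the weight spec to _generate_resharding_costs, so the cost of moving a (possibly shared) weight between sharding specs is counted as well. After this refactor the per-strategy pattern reduces to the following sketch (identifiers as in the hunks; the concrete dictionaries are examples only):

    dim_partition_dict_for_input = {0: [mesh_dim_0]}
    sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

    dim_partition_dict_for_weight = {1: [mesh_dim_1]}
    sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

    # resharding cost now accounts for both operands
    resharding_costs = self._generate_resharding_costs(
        [sharding_spec_for_input, sharding_spec_for_weight])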
@@ -3,7 +3,6 @@ import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
 from functools import reduce
-from .._utils import generate_sharding_spec

 __all__ = ['DotHandler']


@@ -29,16 +28,14 @@ class DotHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         # linear layer weight is transposed during init
         dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -69,17 +66,15 @@ class DotHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         # since weight of the linear layer is transposed
         # the actual dim to be sharded is 1
         dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -106,15 +101,13 @@ class DotHandler(OperatorHandler):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -141,15 +134,13 @@ class DotHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

         dim_partition_dict_for_input = {1: [mesh_dim]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -176,15 +167,13 @@ class DotHandler(OperatorHandler):
         name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -211,15 +200,13 @@ class DotHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -246,15 +233,13 @@ class DotHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

@@ -281,15 +266,13 @@ class DotHandler(OperatorHandler):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
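The "weight is transposed" comments above refer to torch.nn.Linear storing its weight with shape (out_features, in_features). The output-feature dimension of the layer therefore corresponds to key 0 of the weight's dim_partition_dict, and the reduction (input-feature) dimension to key 1, as in this small illustration (mesh indices are illustrative):

    # y = x @ W.T + b, with W of shape (out_features, in_features)
    dim_partition_dict_for_weight = {0: [1]}    # shard the output features over mesh dim 1
    dim_partition_dict_for_weight = {1: [0]}    # shard the reduction dimension over mesh dim 0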
@@ -7,6 +7,7 @@ from typing import Dict, List
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
+from .._utils import generate_resharding_costs, generate_sharding_spec
 from ..sharding_strategy import StrategiesVector


@@ -17,24 +18,24 @@ class OperatorHandler(ABC):
     '''
     The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.

-    Argument:
-        input_node(Node): the input node in node argument list.
-        input_index(int): the index of input node in the node argument list.
-        weight(torch.Tensor): Weight of the node.
-        output_node(Node): Output_node is the output of the node.
-        device_mesh(DeviceMesh): A logical view of a physical mesh.
-        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
-        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
+    Args:
+        node (Node): the input node in node argument list.
+        device_mesh (DeviceMesh): A logical view of a physical mesh.
+        strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
+        handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
     '''

-    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
-                 shape_consistency_manager: ShapeConsistencyManager):
+    def __init__(self,
+                 node: Node,
+                 device_mesh: DeviceMesh,
+                 strategies_vector: StrategiesVector,
+                 handle_backward: bool = True):
         self.node = node
         self.predecessor_node = list(node._input_nodes.keys())
         self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
-        self.shape_consistency_manager = shape_consistency_manager
+        self.handle_backward = handle_backward

         # find the module and its parameters associated with this node
         # this can be used to compute the compute/communication/sharding cost

@@ -102,35 +103,23 @@ class OperatorHandler(ABC):

         return total_memory_cost, activation_memory_cost, weight_memory_cost

-    def _generate_resharding_costs(self, sharding_spec_for_input):
-        '''
-        Compute the resharding costs with this specific strategy.
-
-        Note: The resharding_cost of weight is NOT counted.
-
-        Argument:
-            resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
-                Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
-                with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
-                strategy.
-            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
-        '''
+    def _generate_resharding_costs(self, sharding_specs):
         # The resharding_cost of weight is counted due to sharing weight cases.
-        resharding_costs = {}
         dtype = self.node._meta_data.dtype
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
-            resharding_costs[input_node] = []
-            for strategy in input_node.strategies_vector:
-                input_sharding_spec = strategy.output_sharding_spec
-                assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                # compute the resharding cost during forward phase
-                _, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
-                    input_sharding_spec, input_spec)
-                # In backward phase, we should convert grad with target_spec into input_sharding_spec
-                _, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
-                    input_spec, input_sharding_spec)
-                # we need multiply the size of elem dtype to get correct communication cost
-                resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes
-                resharding_costs[input_node].append(resharding_cost)
-        return resharding_costs
+        nodes = self.predecessor_node
+        return generate_resharding_costs(nodes=nodes,
+                                         sharding_specs=sharding_specs,
+                                         count_backward=self.handle_backward,
+                                         dtype=dtype)
+
+    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+        return generate_sharding_spec(input_=input_,
+                                      device_mesh=self.device_mesh,
+                                      dim_partition_dict=dim_partition_dict)
+
+    @abstractmethod
+    def _generate_compute_cost(self, *args, **kwargs):
+        """
+        Compute the flops involved in the node.
+        """
+        pass
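After this refactor a handler no longer receives a ShapeConsistencyManager: it is built from the node, the device mesh, the strategies vector and an optional handle_backward flag, and the singleton manager is fetched inside generate_resharding_costs. A hedged construction sketch, assuming the concrete handlers simply inherit OperatorHandler.__init__ as the hunks below suggest (the handle_backward=False value is illustrative, e.g. for an inference-only search):

    conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
    conv_handler.register_strategy()

    # skip the backward resharding term
    dot_handler = DotHandler(node, self.device_mesh, strategies_vector, handle_backward=False)
    dot_handler.register_strategy()

The same simplification applies one level up: as the final hunks show, StrategiesConstructor(graph, device_mesh, solver_options) drops the shape_consistency_manager argument entirely.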
@ -11,7 +11,7 @@ import math
|
||||||
import torch
|
import torch
|
||||||
import operator
|
import operator
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from ._utils import generate_sharding_spec
|
from ._utils import generate_sharding_spec, generate_resharding_costs
|
||||||
|
|
||||||
|
|
||||||
class StrategiesConstructor:
|
class StrategiesConstructor:
|
||||||
|
@ -21,12 +21,10 @@ class StrategiesConstructor:
|
||||||
Args:
|
Args:
|
||||||
graph (Graph): a Graph object used for analysis and strategy generation.
|
graph (Graph): a Graph object used for analysis and strategy generation.
|
||||||
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
|
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
|
||||||
shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent.
|
|
||||||
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
|
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager,
|
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
|
||||||
solver_options: SolverOptions):
|
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
|
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
|
||||||
self.root_module = self.graph.owning_module
|
self.root_module = self.graph.owning_module
|
||||||
|
@ -34,27 +32,8 @@ class StrategiesConstructor:
|
||||||
self.device_mesh = device_mesh
|
self.device_mesh = device_mesh
|
||||||
self.leaf_strategies = []
|
self.leaf_strategies = []
|
||||||
self.strategy_map = {}
|
self.strategy_map = {}
|
||||||
self.shape_consistency_manager = shape_consistency_manager
|
|
||||||
self.solver_options = solver_options
|
self.solver_options = solver_options
|
||||||
|
|
||||||
def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
|
|
||||||
'''
|
|
||||||
Compute the resharding costs with this specific strategy.
|
|
||||||
|
|
||||||
Argument:
|
|
||||||
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
|
|
||||||
'''
|
|
||||||
resharding_costs = {}
|
|
||||||
for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs):
|
|
||||||
resharding_costs[input_node] = []
|
|
||||||
for strategy in input_node.strategies_vector:
|
|
||||||
input_sharding_spec = strategy.output_sharding_spec
|
|
||||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
|
||||||
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
|
|
||||||
input_sharding_spec, target_sharding_spec)
|
|
||||||
resharding_costs[input_node].append(resharding_cost)
|
|
||||||
return resharding_costs
|
|
||||||
|
|
||||||
def remove_duplicated_strategy(self, strategies_vector):
|
def remove_duplicated_strategy(self, strategies_vector):
|
||||||
'''
|
'''
|
||||||
In build_strategies_and_cost method, we may produce some duplicated strategies.
|
In build_strategies_and_cost method, we may produce some duplicated strategies.
|
||||||
|
@ -120,14 +99,13 @@ class StrategiesConstructor:
|
||||||
# conv module
|
# conv module
|
||||||
if submod_type in CONV_MODULE_OP:
|
if submod_type in CONV_MODULE_OP:
|
||||||
# use ConvHandler to create sharding strategies for conv module node
|
# use ConvHandler to create sharding strategies for conv module node
|
||||||
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
|
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
|
||||||
self.shape_consistency_manager)
|
|
||||||
conv_handler.register_strategy()
|
conv_handler.register_strategy()
|
||||||
|
|
||||||
# linear module
|
# linear module
|
||||||
elif submod_type in LINEAR_MODULE_OP:
|
elif submod_type in LINEAR_MODULE_OP:
|
||||||
# use DotHandler to create sharding strategies for linear module node
|
# use DotHandler to create sharding strategies for linear module node
|
||||||
dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager)
|
dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
|
||||||
dot_handler.register_strategy()
|
dot_handler.register_strategy()
|
||||||
|
|
||||||
# element-wise module
|
# element-wise module
|
||||||
|
@ -158,8 +136,8 @@ class StrategiesConstructor:
|
||||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||||
compute_cost = node._meta_data.numel()
|
compute_cost = node._meta_data.numel()
|
||||||
memory_cost = 0
|
memory_cost = 0
|
||||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||||
[input_sharding_spec])
|
[input_sharding_spec])
|
||||||
|
|
||||||
# to prevent the resharding happening, set their resharding cost to inf.
|
# to prevent the resharding happening, set their resharding cost to inf.
|
||||||
resharding_costs[input_node] = [
|
resharding_costs[input_node] = [
|
||||||
|
@ -214,8 +192,8 @@ class StrategiesConstructor:
|
||||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||||
compute_cost = node._meta_data.numel()
|
compute_cost = node._meta_data.numel()
|
||||||
memory_cost = 0
|
memory_cost = 0
|
||||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||||
[input_sharding_spec])
|
[input_sharding_spec])
|
||||||
|
|
||||||
sharding_strategy = ShardingStrategy(name,
|
sharding_strategy = ShardingStrategy(name,
|
||||||
output_sharding_spec,
|
output_sharding_spec,
|
||||||
|
@ -275,8 +253,8 @@ class StrategiesConstructor:
                 compute_cost = node._meta_data.numel()
                 memory_cost = 0

-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                              [input_sharding_spec])

                 # to prevent the resharding happening, set their resharding cost to inf.
                 resharding_costs[input_node] = [
@ -317,8 +295,8 @@ class StrategiesConstructor:
                 # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                 compute_cost = 0
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                              [new_input_sharding_spec])
                 sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
                                                      compute_cost=compute_cost,
                                                      memory_cost=memory_cost,
@ -335,8 +313,8 @@ class StrategiesConstructor:
                 # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                 compute_cost = 0
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                              [input_sharding_spec])
                 sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
                                                      compute_cost=compute_cost,
                                                      memory_cost=memory_cost,
@ -360,8 +338,8 @@ class StrategiesConstructor:
                 # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                 compute_cost = 0
                 memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                              [input_sharding_spec])
                 # to prevent the resharding happening, set their resharding cost to inf.
                 resharding_costs[input_tensor_node] = [
                     cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
@ -397,8 +375,8 @@ class StrategiesConstructor:
             output_sharding_spec = input_sharding_specs
             # TODO: use meta_info_prop to profile memory cost
             memory_cost = 0
-            resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
+            resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                          input_sharding_specs)

             # clear the resharding cost for the output node
             # TODO: we may remove this in final version
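The seven hunks above all make the same substitution: the private StrategiesConstructor._generate_resharding_costs method is replaced by the shared module-level helper generate_resharding_costs. A minimal sketch of the call, assuming the helper is importable from the solver's utility module (the import path is an assumption; the variable names come from the hunks above):

    from colossalai.auto_parallel.solver._utils import generate_resharding_costs

    # strategies_vector and input_sharding_spec are locals of the strategy loop shown above
    resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                 [input_sharding_spec])
    # resharding_costs maps each predecessor node to one cost per candidate strategy;
    # entries can still be masked afterwards, e.g. forced to math.inf as in the hunks above.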
@ -15,4 +15,7 @@ class SingletonMeta(type):
         if cls not in cls._instances:
             instance = super().__call__(*args, **kwargs)
             cls._instances[cls] = instance
+        else:
+            assert len(args) == 0 and len(
+                kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.'
         return cls._instances[cls]
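The added else branch makes repeated construction explicit: a second call with arguments now fails loudly instead of silently ignoring them. A small usage sketch (the Config class is hypothetical, only for illustration):

    from colossalai.context.singleton_meta import SingletonMeta

    class Config(metaclass=SingletonMeta):
        def __init__(self, value=0):
            self.value = value

    a = Config(value=1)   # first call creates the instance
    b = Config()          # later calls return the same object, __init__ is not re-run
    assert a is b
    # Config(value=2)     # would now trip the new assertion, since an instance already exists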
@ -1,15 +1,22 @@
 import torch
+from dataclasses import dataclass
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 from enum import Enum
 from copy import deepcopy
 from typing import Dict, List, Optional, Tuple, Union
+from colossalai.context.singleton_meta import SingletonMeta
 import torch.distributed as dist
 import math
 from functools import reduce
 import operator
 from torch.distributed import ReduceOp

+__all__ = [
+    'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions',
+    'set_shape_consistency_options'
+]
+

 class CollectiveCommPattern(Enum):
     ALLGATHER = 'all_gather'
@ -152,14 +159,40 @@ class CommSpec:
         tensor.data = tensor


-class ShapeConsistencyManager:
+@dataclass
+class ShapeConsistencyOptions:
+    """
+    ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency.
+    """
+    # TODO: shape consistency option is not implemented yet
+    pass
+
+
+def set_shape_consistency_options(options: ShapeConsistencyOptions):
+    """
+    Configure the shape consistency manager via function call.
+    """
+    manager = ShapeConsistencyManager()
+    manager.options = options
+
+
+class ShapeConsistencyManager(metaclass=SingletonMeta):

-    def __init__(self, consistency_option=None):
-        self.consistency_option = consistency_option
+    def __init__(self):
+        self._options = None
         self.total_communication_cost = 0
         self.total_transform_steps = 0
         self.cached_spec_pairs_transform_path = {}

+    @property
+    def options(self):
+        return self._options
+
+    @options.setter
+    def options(self, options_: ShapeConsistencyOptions):
+        assert isinstance(options_, ShapeConsistencyOptions)
+        self._options = options_
+
     def get_all_all_gather_spec(self, source_spec, orig_cost):
         '''
         Get all valid sharding specs from source_spec with single all-gather operation, and
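With ShapeConsistencyManager now built on SingletonMeta, the manager no longer has to be threaded through constructors; any module can obtain the shared instance directly. A minimal configuration sketch based on the new API above (ShapeConsistencyOptions is still an empty placeholder dataclass in this commit):

    from colossalai.tensor.shape_consistency import (ShapeConsistencyManager, ShapeConsistencyOptions,
                                                     set_shape_consistency_options)

    options = ShapeConsistencyOptions()
    set_shape_consistency_options(options)

    # any later instantiation returns the same, already-configured manager
    manager = ShapeConsistencyManager()
    assert manager.options is options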
@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh

@ -31,7 +30,6 @@ def test_bn_handler():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()

     tracer = ColoTracer()
     model = BNModel(16)
@ -77,10 +75,11 @@ def test_bn_handler():

     # generate bn strategy
     strategies_vector = StrategiesVector(node=nodes[2])
-    bn_handler = BatchNormHandler(node=nodes[2],
-                                  device_mesh=device_mesh,
-                                  strategies_vector=strategies_vector,
-                                  shape_consistency_manager=shape_consistency_manager)
+    bn_handler = BatchNormHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
     bn_handler.register_strategy()
     # ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
     # 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']

@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh

@ -31,7 +30,6 @@ def test_conv_handler():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()

     tracer = ColoTracer()
     model = ConvModel(16, 32)
@ -77,10 +75,11 @@ def test_conv_handler():

     # generate conv strategy
     strategies_vector = StrategiesVector(node=nodes[2])
-    conv_handler = ConvHandler(node=nodes[2],
-                               device_mesh=device_mesh,
-                               strategies_vector=strategies_vector,
-                               shape_consistency_manager=shape_consistency_manager)
+    conv_handler = ConvHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
     conv_handler.register_strategy()
     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
     strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]

@ -4,10 +4,7 @@ from torch.fx import GraphModule
 import torch.nn as nn
 import pytest

-from colossalai.fx.proxy import ColoProxy
 from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
 from colossalai.auto_parallel.solver.cost_graph import CostGraph
@ -37,7 +34,6 @@ def test_cost_graph():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()

     tracer = ColoTracer()
     model = ConvModel(16, 32)
@ -55,7 +51,7 @@ def test_cost_graph():
     gm.recompile()

     solver_options = SolverOptions(fast=True)
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()

     # (x, mul):{(0, 0): 0}
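As the updated test shows, StrategiesConstructor drops the manager argument entirely. A short sketch of the resulting driver code, assuming graph and device_mesh are prepared exactly as in the test above:

    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    # resharding costs are computed through the shared singleton manager under the hood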
@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh

@ -31,7 +30,6 @@ def test_dot_handler():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 8))
-    shape_consistency_manager = ShapeConsistencyManager()

     tracer = ColoTracer()
     model = LinearModel(8, 16)
@ -76,10 +74,11 @@ def test_dot_handler():

     # generate dot strategy
     strategies_vector = StrategiesVector(node=nodes[2])
-    dot_handler = DotHandler(node=nodes[2],
-                             device_mesh=device_mesh,
-                             strategies_vector=strategies_vector,
-                             shape_consistency_manager=shape_consistency_manager)
+    dot_handler = DotHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
     strategies_vector = dot_handler.register_strategy()

     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']

@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
 from colossalai.auto_parallel.solver.options import SolverOptions
@ -34,7 +33,6 @@ def test_strategies_constructor():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()

     tracer = ColoTracer()
     model = ConvModel(16, 32)
@ -49,7 +47,7 @@ def test_strategies_constructor():
     gm.recompile()

     solver_options = SolverOptions(fast=True)
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)

     assert strategies_constructor.leaf_strategies == []
     assert strategies_constructor.strategy_map == {}