mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] refactored shape consistency to remove redundancy (#1591)
* [autoparallel] refactored shape consistency to remove redundancy * polish code * polish code * polish codepull/1587/head^2
parent
d164449d00
commit
27fe8af60c
|
@ -1,8 +1,9 @@
|
|||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from typing import Union, Dict, List
|
||||
from typing import Union, Dict, List, Optional
|
||||
|
||||
|
||||
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
|
||||
|
@ -31,3 +32,45 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
|
|||
|
||||
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
|
||||
return sharding_spec
|
||||
|
||||
|
||||
def generate_resharding_costs(nodes: List[Node],
|
||||
sharding_specs: List[ShardingSpec],
|
||||
count_backward: Optional[bool] = True,
|
||||
dtype: Optional[torch.dtype] = None):
|
||||
'''
|
||||
Compute the resharding costs with this specific strategy.
|
||||
|
||||
Argument:
|
||||
nodes (List[Node]): a list of nodes
|
||||
sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
|
||||
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
|
||||
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
|
||||
'''
|
||||
# The resharding_cost of weight is counted due to sharing weight cases.
|
||||
resharding_costs = {}
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# shape consistency manager is a singleton class
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
for input_node, input_spec in zip(nodes, sharding_specs):
|
||||
resharding_costs[input_node] = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
# compute the resharding cost during forward phase
|
||||
_, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec)
|
||||
|
||||
if count_backward:
|
||||
# In backward phase, we should convert grad with target_spec into input_sharding_spec
|
||||
_, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
|
||||
input_spec, input_sharding_spec)
|
||||
total_resharding_cost = resharding_cost_forward + resharding_cost_backward
|
||||
else:
|
||||
total_resharding_cost = resharding_cost_forward
|
||||
|
||||
# we need multiply the size of elem dtype to get correct communication cost
|
||||
resharding_cost = total_resharding_cost * size_per_elem_bytes
|
||||
resharding_costs[input_node].append(resharding_cost)
|
||||
return resharding_costs
|
||||
|
|
|
@ -4,7 +4,6 @@ import warnings
|
|||
import torch
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from .._utils import generate_sharding_spec
|
||||
|
||||
__all__ = ['BatchNormHandler']
|
||||
|
||||
|
@ -115,15 +114,13 @@ class BatchNormHandler(OperatorHandler):
|
|||
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -156,8 +153,7 @@ class BatchNormHandler(OperatorHandler):
|
|||
new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
|
||||
new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
# the computation cost is all the same
|
||||
new_compute_cost = compute_cost
|
||||
|
||||
|
@ -192,15 +188,13 @@ class BatchNormHandler(OperatorHandler):
|
|||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -234,15 +228,13 @@ class BatchNormHandler(OperatorHandler):
|
|||
name = f'RR = RR x R'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -273,8 +265,7 @@ class BatchNormHandler(OperatorHandler):
|
|||
|
||||
def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
|
||||
dim_partition_dict_for_output = {0: mesh_dim_list}
|
||||
new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# the computation cost is all the same
|
||||
new_compute_cost = compute_cost
|
||||
|
@ -332,15 +323,13 @@ class BatchNormHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -374,15 +363,13 @@ class BatchNormHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -416,15 +403,13 @@ class BatchNormHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -459,7 +444,7 @@ class BatchNormHandler(OperatorHandler):
|
|||
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
|
||||
|
||||
Example:
|
||||
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
|
||||
norm_handler = BatchNormHandler(node, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
norm_handler.register_strategy()
|
||||
for strategy in norm_handler.strategies_vector:
|
||||
|
|
|
@ -4,7 +4,6 @@ import warnings
|
|||
import torch
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from .._utils import generate_sharding_spec
|
||||
|
||||
__all__ = ['ConvHandler']
|
||||
|
||||
|
@ -109,15 +108,13 @@ class ConvHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -158,15 +155,13 @@ class ConvHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -205,15 +200,13 @@ class ConvHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -252,15 +245,13 @@ class ConvHandler(OperatorHandler):
|
|||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -296,15 +287,13 @@ class ConvHandler(OperatorHandler):
|
|||
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -340,15 +329,13 @@ class ConvHandler(OperatorHandler):
|
|||
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -384,15 +371,13 @@ class ConvHandler(OperatorHandler):
|
|||
name = f'RR = RR x RR'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -426,15 +411,13 @@ class ConvHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -475,15 +458,13 @@ class ConvHandler(OperatorHandler):
|
|||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
|
|
@ -3,7 +3,6 @@ import torch
|
|||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from functools import reduce
|
||||
from .._utils import generate_sharding_spec
|
||||
|
||||
__all__ = ['DotHandler']
|
||||
|
||||
|
@ -29,16 +28,14 @@ class DotHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
# linear layer weight is transposed during init
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -69,17 +66,15 @@ class DotHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
# since weight of the linear layer is transposed
|
||||
# the actual dim to be sharded is 1
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -106,15 +101,13 @@ class DotHandler(OperatorHandler):
|
|||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -141,15 +134,13 @@ class DotHandler(OperatorHandler):
|
|||
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -176,15 +167,13 @@ class DotHandler(OperatorHandler):
|
|||
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim]}
|
||||
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -211,15 +200,13 @@ class DotHandler(OperatorHandler):
|
|||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -246,15 +233,13 @@ class DotHandler(OperatorHandler):
|
|||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
@ -281,15 +266,13 @@ class DotHandler(OperatorHandler):
|
|||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
|
||||
dim_partition_dict_for_input)
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
|
||||
dim_partition_dict_for_output)
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Dict, List
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from .._utils import generate_resharding_costs, generate_sharding_spec
|
||||
|
||||
from ..sharding_strategy import StrategiesVector
|
||||
|
||||
|
@ -17,24 +18,24 @@ class OperatorHandler(ABC):
|
|||
'''
|
||||
The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.
|
||||
|
||||
Argument:
|
||||
input_node(Node): the input node in node argument list.
|
||||
input_index(int): the index of input node in the node argument list.
|
||||
weight(torch.Tensor): Weight of the node.
|
||||
output_node(Node): Output_node is the output of the node.
|
||||
device_mesh(DeviceMesh): A logical view of a physical mesh.
|
||||
strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
|
||||
shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
|
||||
Args:
|
||||
node (Node): the input node in node argument list.
|
||||
device_mesh (DeviceMesh): A logical view of a physical mesh.
|
||||
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
|
||||
handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
|
||||
'''
|
||||
|
||||
def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
|
||||
shape_consistency_manager: ShapeConsistencyManager):
|
||||
def __init__(self,
|
||||
node: Node,
|
||||
device_mesh: DeviceMesh,
|
||||
strategies_vector: StrategiesVector,
|
||||
handle_backward: bool = True):
|
||||
self.node = node
|
||||
self.predecessor_node = list(node._input_nodes.keys())
|
||||
self.successor_node = list(node.users.keys())
|
||||
self.device_mesh = device_mesh
|
||||
self.strategies_vector = strategies_vector
|
||||
self.shape_consistency_manager = shape_consistency_manager
|
||||
self.handle_backward = handle_backward
|
||||
|
||||
# find the module and its parameters associated with this node
|
||||
# this can be used to compute the compute/communication/sharding cost
|
||||
|
@ -102,35 +103,23 @@ class OperatorHandler(ABC):
|
|||
|
||||
return total_memory_cost, activation_memory_cost, weight_memory_cost
|
||||
|
||||
def _generate_resharding_costs(self, sharding_spec_for_input):
|
||||
'''
|
||||
Compute the resharding costs with this specific strategy.
|
||||
|
||||
Note: The resharding_cost of weight is NOT counted.
|
||||
|
||||
Argument:
|
||||
resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
|
||||
Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
|
||||
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
|
||||
strategy.
|
||||
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
|
||||
'''
|
||||
def _generate_resharding_costs(self, sharding_specs):
|
||||
# The resharding_cost of weight is counted due to sharing weight cases.
|
||||
resharding_costs = {}
|
||||
dtype = self.node._meta_data.dtype
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
|
||||
resharding_costs[input_node] = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
# compute the resharding cost during forward phase
|
||||
_, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
|
||||
input_sharding_spec, input_spec)
|
||||
# In backward phase, we should convert grad with target_spec into input_sharding_spec
|
||||
_, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
|
||||
input_spec, input_sharding_spec)
|
||||
# we need multiply the size of elem dtype to get correct communication cost
|
||||
resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes
|
||||
resharding_costs[input_node].append(resharding_cost)
|
||||
return resharding_costs
|
||||
nodes = self.predecessor_node
|
||||
return generate_resharding_costs(nodes=nodes,
|
||||
sharding_specs=sharding_specs,
|
||||
count_backward=self.handle_backward,
|
||||
dtype=dtype)
|
||||
|
||||
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
|
||||
return generate_sharding_spec(input_=input_,
|
||||
device_mesh=self.device_mesh,
|
||||
dim_partition_dict=dim_partition_dict)
|
||||
|
||||
@abstractmethod
|
||||
def _generate_compute_cost(self, *args, **kwargs):
|
||||
"""
|
||||
Compute the flops involved in the node.
|
||||
"""
|
||||
pass
|
||||
|
|
|
@ -11,7 +11,7 @@ import math
|
|||
import torch
|
||||
import operator
|
||||
from typing import Dict, List
|
||||
from ._utils import generate_sharding_spec
|
||||
from ._utils import generate_sharding_spec, generate_resharding_costs
|
||||
|
||||
|
||||
class StrategiesConstructor:
|
||||
|
@ -21,12 +21,10 @@ class StrategiesConstructor:
|
|||
Args:
|
||||
graph (Graph): a Graph object used for analysis and strategy generation.
|
||||
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
|
||||
shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent.
|
||||
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager,
|
||||
solver_options: SolverOptions):
|
||||
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
|
||||
self.graph = graph
|
||||
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
|
||||
self.root_module = self.graph.owning_module
|
||||
|
@ -34,27 +32,8 @@ class StrategiesConstructor:
|
|||
self.device_mesh = device_mesh
|
||||
self.leaf_strategies = []
|
||||
self.strategy_map = {}
|
||||
self.shape_consistency_manager = shape_consistency_manager
|
||||
self.solver_options = solver_options
|
||||
|
||||
def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
|
||||
'''
|
||||
Compute the resharding costs with this specific strategy.
|
||||
|
||||
Argument:
|
||||
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
|
||||
'''
|
||||
resharding_costs = {}
|
||||
for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs):
|
||||
resharding_costs[input_node] = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
|
||||
input_sharding_spec, target_sharding_spec)
|
||||
resharding_costs[input_node].append(resharding_cost)
|
||||
return resharding_costs
|
||||
|
||||
def remove_duplicated_strategy(self, strategies_vector):
|
||||
'''
|
||||
In build_strategies_and_cost method, we may produce some duplicated strategies.
|
||||
|
@ -120,14 +99,13 @@ class StrategiesConstructor:
|
|||
# conv module
|
||||
if submod_type in CONV_MODULE_OP:
|
||||
# use ConvHandler to create sharding strategies for conv module node
|
||||
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
|
||||
conv_handler.register_strategy()
|
||||
|
||||
# linear module
|
||||
elif submod_type in LINEAR_MODULE_OP:
|
||||
# use DotHandler to create sharding strategies for linear module node
|
||||
dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager)
|
||||
dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
|
||||
dot_handler.register_strategy()
|
||||
|
||||
# element-wise module
|
||||
|
@ -158,7 +136,7 @@ class StrategiesConstructor:
|
|||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = node._meta_data.numel()
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
|
@ -214,7 +192,7 @@ class StrategiesConstructor:
|
|||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = node._meta_data.numel()
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
|
@ -275,7 +253,7 @@ class StrategiesConstructor:
|
|||
compute_cost = node._meta_data.numel()
|
||||
memory_cost = 0
|
||||
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
|
@ -317,7 +295,7 @@ class StrategiesConstructor:
|
|||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[new_input_sharding_spec])
|
||||
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
|
||||
compute_cost=compute_cost,
|
||||
|
@ -335,7 +313,7 @@ class StrategiesConstructor:
|
|||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
|
||||
compute_cost=compute_cost,
|
||||
|
@ -360,7 +338,7 @@ class StrategiesConstructor:
|
|||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_tensor_node] = [
|
||||
|
@ -397,7 +375,7 @@ class StrategiesConstructor:
|
|||
output_sharding_spec = input_sharding_specs
|
||||
# TODO: use meta_info_prop to profile memory cost
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
input_sharding_specs)
|
||||
|
||||
# clear the resharding cost for the output node
|
||||
|
|
|
@ -15,4 +15,7 @@ class SingletonMeta(type):
|
|||
if cls not in cls._instances:
|
||||
instance = super().__call__(*args, **kwargs)
|
||||
cls._instances[cls] = instance
|
||||
else:
|
||||
assert len(args) == 0 and len(
|
||||
kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.'
|
||||
return cls._instances[cls]
|
||||
|
|
|
@ -1,15 +1,22 @@
|
|||
import torch
|
||||
from dataclasses import dataclass
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
|
||||
from enum import Enum
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
import torch.distributed as dist
|
||||
import math
|
||||
from functools import reduce
|
||||
import operator
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
__all__ = [
|
||||
'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions',
|
||||
'set_shape_consistency_options'
|
||||
]
|
||||
|
||||
|
||||
class CollectiveCommPattern(Enum):
|
||||
ALLGATHER = 'all_gather'
|
||||
|
@ -152,14 +159,40 @@ class CommSpec:
|
|||
tensor.data = tensor
|
||||
|
||||
|
||||
class ShapeConsistencyManager:
|
||||
@dataclass
|
||||
class ShapeConsistencyOptions:
|
||||
"""
|
||||
ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency.
|
||||
"""
|
||||
# TODO: shape consistency option is not implemented yet
|
||||
pass
|
||||
|
||||
def __init__(self, consistency_option=None):
|
||||
self.consistency_option = consistency_option
|
||||
|
||||
def set_shape_consistency_options(options: ShapeConsistencyOptions):
|
||||
"""
|
||||
Configure the shape consistency manager via function call.
|
||||
"""
|
||||
manager = ShapeConsistencyManager()
|
||||
manager.options = options
|
||||
|
||||
|
||||
class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
|
||||
def __init__(self):
|
||||
self._options = None
|
||||
self.total_communication_cost = 0
|
||||
self.total_transform_steps = 0
|
||||
self.cached_spec_pairs_transform_path = {}
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
return self._options
|
||||
|
||||
@options.setter
|
||||
def options(self, options_: ShapeConsistencyOptions):
|
||||
assert isinstance(options_, ShapeConsistencyOptions)
|
||||
self._options = options_
|
||||
|
||||
def get_all_all_gather_spec(self, source_spec, orig_cost):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single all-gather operation, and
|
||||
|
|
|
@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
|
@ -31,7 +30,6 @@ def test_bn_handler():
|
|||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = BNModel(16)
|
||||
|
@ -77,10 +75,11 @@ def test_bn_handler():
|
|||
|
||||
# generate bn strategy
|
||||
strategies_vector = StrategiesVector(node=nodes[2])
|
||||
bn_handler = BatchNormHandler(node=nodes[2],
|
||||
bn_handler = BatchNormHandler(
|
||||
node=nodes[2],
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=strategies_vector,
|
||||
shape_consistency_manager=shape_consistency_manager)
|
||||
)
|
||||
bn_handler.register_strategy()
|
||||
# ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
|
||||
# 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']
|
||||
|
|
|
@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
|
@ -31,7 +30,6 @@ def test_conv_handler():
|
|||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
|
@ -77,10 +75,11 @@ def test_conv_handler():
|
|||
|
||||
# generate conv strategy
|
||||
strategies_vector = StrategiesVector(node=nodes[2])
|
||||
conv_handler = ConvHandler(node=nodes[2],
|
||||
conv_handler = ConvHandler(
|
||||
node=nodes[2],
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=strategies_vector,
|
||||
shape_consistency_manager=shape_consistency_manager)
|
||||
)
|
||||
conv_handler.register_strategy()
|
||||
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
|
||||
strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]
|
||||
|
|
|
@ -4,10 +4,7 @@ from torch.fx import GraphModule
|
|||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||
|
@ -37,7 +34,6 @@ def test_cost_graph():
|
|||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
|
@ -55,7 +51,7 @@ def test_cost_graph():
|
|||
gm.recompile()
|
||||
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
# (x, mul):{(0, 0): 0}
|
||||
|
|
|
@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
|
@ -31,7 +30,6 @@ def test_dot_handler():
|
|||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 8))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = LinearModel(8, 16)
|
||||
|
@ -76,10 +74,11 @@ def test_dot_handler():
|
|||
|
||||
# generate dot strategy
|
||||
strategies_vector = StrategiesVector(node=nodes[2])
|
||||
dot_handler = DotHandler(node=nodes[2],
|
||||
dot_handler = DotHandler(
|
||||
node=nodes[2],
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=strategies_vector,
|
||||
shape_consistency_manager=shape_consistency_manager)
|
||||
)
|
||||
strategies_vector = dot_handler.register_strategy()
|
||||
|
||||
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
|
||||
|
|
|
@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||
|
@ -34,7 +33,6 @@ def test_strategies_constructor():
|
|||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
|
@ -49,7 +47,7 @@ def test_strategies_constructor():
|
|||
gm.recompile()
|
||||
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
|
||||
assert strategies_constructor.leaf_strategies == []
|
||||
assert strategies_constructor.strategy_map == {}
|
||||
|
|
Loading…
Reference in New Issue