[autoparallel] add bcast op handler (#1600)

* [autoparallel] add bcast op handler

* polish code

* add more BCAST FUNC OP

* polish code

* add exception handler

* polish
YuliangLiu0306 2022-09-16 11:33:01 +08:00 committed by GitHub
parent 3abf98a633
commit eac1b79371
11 changed files with 322 additions and 49 deletions

View File

@ -4,6 +4,10 @@ from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List, Optional
import warnings
from functools import reduce
import functools
import operator
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@ -29,6 +33,11 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
raise TypeError(
f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
)
for dim_index, sharding_index_list in dim_partition_dict.items():
sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec
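A minimal standalone sketch of the divisibility check added above (the mesh and tensor sizes are hypothetical): a tensor dimension that is sharded across several mesh axes must be divisible by the product of their sizes, otherwise the assertion fires.

import operator
from functools import reduce

mesh_shape = (2, 2)                  # hypothetical 2x2 device mesh
shape = (64, 32)                     # hypothetical tensor shape
dim_partition_dict = {0: [0, 1]}     # shard tensor dim 0 across both mesh axes

for dim_index, mesh_dims in dim_partition_dict.items():
    sharding_size = reduce(operator.mul, [mesh_shape[d] for d in mesh_dims], 1)
    # 64 % 4 == 0, so this partitioning passes the new check
    assert shape[dim_index] % sharding_size == 0, \
        f'cannot shard dimension {dim_index} into {sharding_size} partitions'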
@ -74,3 +83,18 @@ def generate_resharding_costs(nodes: List[Node],
resharding_cost = total_resharding_cost * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def exception_handler(func):
"""
A decorator which runs the wrapped function and converts any exception it raises into a warning, so that an invalid sharding strategy is skipped instead of aborting strategy construction.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
warnings.warn(f'{e}')
return wrapper
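A minimal usage sketch of the decorator above with a hypothetical strategy function (the import path is the one added by this commit): any exception raised inside the decorated function is converted into a warning, so strategy construction continues instead of aborting.

import warnings
from colossalai.auto_parallel.solver._utils import exception_handler

@exception_handler
def register_split_strategy(dim_size, sharding_size):
    # hypothetical strategy body: an invalid sharding raises instead of being registered
    assert dim_size % sharding_size == 0, \
        f'cannot shard a dimension of size {dim_size} into {sharding_size} partitions'

register_split_strategy(64, 4)       # valid: runs silently
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    register_split_strategy(30, 4)   # invalid: the assertion becomes a warning, not a crash
assert 'cannot shard' in str(caught[0].message)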

View File

@ -3,16 +3,19 @@ import operator
__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP'
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP'
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
# TODO: flatten should not be added into this group
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
torch.nn.functional.dropout, torch.flatten
]
RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d
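The new BCAST_FUNC_OP group above collects binary operators whose two operands may have different shapes, so the output sharding cannot simply be copied from one input as it is for ELEMENTWISE_FUNC_OP. A minimal sketch of the broadcasting behaviour these handlers must account for (shapes are hypothetical, not taken from this commit):

import operator
import torch

lhs = torch.rand(4, 16, 64, 64)
rhs = torch.rand(64, 1)              # lower-rank operand, broadcast over the leading dims
out = operator.add(lhs, rhs)         # same dispatch target as `lhs + rhs`
assert out.shape == torch.Size([4, 16, 64, 64])

# a scalar rhs degenerates to the element-wise case, which is why such nodes
# are merged like element-wise ops later in this commit
assert torch.add(lhs, 1).shape == lhs.shape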

View File

@ -3,5 +3,6 @@ from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .batch_norm_handler import BatchNormHandler
from .reshape_handler import ReshapeHandler
from .bcast_op_handler import BcastOpHandler
__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler']
__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler']

View File

@ -1,9 +1,9 @@
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.solver._utils import exception_handler
__all__ = ['BatchNormHandler']
@ -110,6 +110,7 @@ class BatchNormHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
@exception_handler
def split_input_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
@ -184,6 +185,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
@ -224,6 +226,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
def non_split(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RR x R'
@ -319,6 +322,7 @@ class BatchNormHandler(OperatorHandler):
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
@exception_handler
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
@ -359,6 +363,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
@ -399,6 +404,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

View File

@ -0,0 +1,164 @@
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.solver._utils import exception_handler
__all__ = ['BcastOpHandler']
class BcastOpHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of broadcast operators (such as operator.add).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert len(self.predecessor_node) == 2
self.lhs_data = self.predecessor_node[0]._meta_data
self.rhs_data = self.predecessor_node[1]._meta_data
self.lhs = self.predecessor_node[0]
self.rhs = self.predecessor_node[1]
self.output_data = self.node._meta_data
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
shape = list(input_.shape)
# padding the shape to the same length as output_data
while len(shape) < self.output_data.dim():
shape.insert(0, 1)
shape = torch.Size(shape)
# if the sharding happens on a size one dimension, we should record it as R.
processed_dim_partition_dict = deepcopy(dim_partition_dict)
for dim_index, _ in dim_partition_dict.items():
if shape[dim_index] == 1:
processed_dim_partition_dict.pop(dim_index)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=shape,
dim_partition_dict=processed_dim_partition_dict)
return sharding_spec
def _generate_resharding_costs(self, sharding_specs):
# The resharding_cost of weight is counted due to sharing weight cases.
dtype = self.node._meta_data.dtype
nodes = self.predecessor_node
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
for input_node, input_spec in zip(nodes, sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
# if the input has fewer dimensions than the target input, we left-pad its shape to the same rank as the target,
# then use the padded input sharding spec to compute the resharding cost.
if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
new_entire_shape = list(input_sharding_spec.entire_shape)
while len(new_entire_shape) < len(input_spec.entire_shape):
new_entire_shape.insert(0, 1)
new_entire_shape = torch.Size(new_entire_shape)
new_device_mesh = input_sharding_spec.device_mesh
new_dim_partition_dict = input_sharding_spec.dim_partition_dict
input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
entire_shape=new_entire_shape,
dim_partition_dict=new_dim_partition_dict)
# compute the resharding cost during forward phase
_, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
_, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
input_spec, input_sharding_spec)
total_resharding_cost = resharding_cost_forward + resharding_cost_backward
# we need to multiply by the element size of the dtype to get the correct communication cost
resharding_cost = total_resharding_cost * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# use mesh_dim_0, mesh_dim_1 instead of the constants 0, 1 here for N-D device mesh scalability.
output_sharding_spec_list = []
output_dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(self.output_data.dim()):
for j in range(i + 1, self.output_data.dim()):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
output_dim_partition_list.append(dim_partition_dict_0)
output_dim_partition_list.append(dim_partition_dict_1)
# enumerate all the 1D sharding cases
for i in range(self.output_data.dim()):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_dict_1 = {i: [mesh_dim_1]}
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
output_dim_partition_list.append(dim_partition_dict_0)
output_dim_partition_list.append(dim_partition_dict_1)
output_dim_partition_list.append(dim_partition_dict_flatten)
# add empty dict for fully replicated case
output_dim_partition_list.append({})
check_duplicated_list = []
for output_dim_partition_dict in output_dim_partition_list:
output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
sharding_seq = output_sharding_spec.sharding_sequence
if sharding_seq not in check_duplicated_list:
check_duplicated_list.append(sharding_seq)
output_sharding_spec_list.append(output_sharding_spec)
return output_sharding_spec_list
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@exception_handler
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input)
name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the computation cost of this strategy
sharding_dims = []
for mesh_dims in dim_partition_dict_for_output.values():
for mesh_dim in mesh_dims:
sharding_dims.append(self.device_mesh.shape[mesh_dim])
sharding_size = reduce(operator.mul, sharding_dims, 1)
memory_cost = self.output_data.numel() / sharding_size
compute_cost = memory_cost
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=output_sharding_spec,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
output_sharding_specs = self._enumerate_all_possible_output(0, 1)
for output_sharding_spec in output_sharding_specs:
self._register_strategy(output_sharding_spec)
return self.strategies_vector
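A standalone sketch of the padding rule used by _generate_sharding_spec above (the helper name and shapes are illustrative, not part of this commit): a lower-rank operand is left-padded with size-1 dimensions up to the output rank, and any sharding requested on a size-1 (broadcast) dimension is dropped back to replication.

from copy import deepcopy
import torch

def pad_and_filter(shape, output_rank, dim_partition_dict):
    # left-pad the operand shape to the output rank, mirroring the handler above
    shape = list(shape)
    while len(shape) < output_rank:
        shape.insert(0, 1)
    # drop shardings that land on size-one (broadcast) dimensions
    processed = deepcopy(dim_partition_dict)
    for dim_index in dim_partition_dict:
        if shape[dim_index] == 1:
            processed.pop(dim_index)
    return torch.Size(shape), processed

# a (64, 1) rhs broadcast against a 4-dim output, asked to shard dims 2 and 3
print(pad_and_filter((64, 1), 4, {2: [0], 3: [1]}))
# -> (torch.Size([1, 1, 64, 1]), {2: [0]})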

View File

@ -4,13 +4,14 @@ import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.solver._utils import exception_handler
__all__ = ['ConvHandler']
class ConvHandler(OperatorHandler):
"""
A OperatorHandler which deals with the sharding strategies of Convolution.
An OperatorHandler which deals with the sharding strategies of Convolution.
"""
def __init__(self, *args, **kwargs):
@ -104,6 +105,7 @@ class ConvHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@exception_handler
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@ -151,6 +153,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
@ -196,6 +199,7 @@ class ConvHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
@ -241,6 +245,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@ -283,6 +288,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
@ -325,6 +331,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
@ -367,6 +374,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def non_split(self):
name = f'RR = RR x RR'
@ -407,6 +415,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
@ -454,6 +463,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
@ -554,48 +564,24 @@ class ConvHandler(OperatorHandler):
RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
'''
# SS = SR x RS
try:
self.split_input_batch_weight_out_channel(0, 1)
except Exception as e:
warnings.warn(f'{e}')
try:
self.split_input_batch_weight_out_channel(1, 0)
except Exception as e:
warnings.warn(f'{e}')
self.split_input_batch_weight_out_channel(0, 1)
self.split_input_batch_weight_out_channel(1, 0)
# SR = SR x RR
self.split_input_batch(0)
self.split_input_batch(1)
# SR = SS x SR
try:
self.split_input_both_dim_weight_in_channel(0, 1)
except Exception as e:
warnings.warn(f'{e}')
try:
self.split_input_both_dim_weight_in_channel(1, 0)
except Exception as e:
warnings.warn(f'{e}')
self.split_input_both_dim_weight_in_channel(0, 1)
self.split_input_both_dim_weight_in_channel(1, 0)
# RS = RS x SS
try:
self.split_input_in_channel_weight_both_channel(0, 1)
except Exception as e:
warnings.warn(f'{e}')
try:
self.split_input_in_channel_weight_both_channel(1, 0)
except Exception as e:
warnings.warn(f'{e}')
self.split_input_in_channel_weight_both_channel(0, 1)
self.split_input_in_channel_weight_both_channel(1, 0)
# RR = RS x SR
try:
self.split_input_in_channel_weight_in_channel(0)
except Exception as e:
warnings.warn(f'{e}')
try:
self.split_input_in_channel_weight_in_channel(1)
except Exception as e:
warnings.warn(f'{e}')
self.split_input_in_channel_weight_in_channel(0)
self.split_input_in_channel_weight_in_channel(1)
# RS = RR x RS
self.split_weight_out_channel(0)
@ -608,12 +594,7 @@ class ConvHandler(OperatorHandler):
self.split_1d_parallel_on_input_batch(0, 1)
# RR = RS01 x S01R
try:
self.split_1d_parallel_on_in_channel(0, 1)
except Exception as e:
warnings.warn(f'{e}')
# print(f'strategies num is :{len(self.strategies_vector)}')
self.split_1d_parallel_on_in_channel(0, 1)
return self.strategies_vector

View File

@ -6,10 +6,12 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
from .operator_handler import OperatorHandler
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from functools import reduce
from colossalai.auto_parallel.solver._utils import exception_handler
from enum import Enum
from .strategy_generator import StrategyGenerator, IntermediateStrategy
from typing import List
__all__ = ['DotHandler']
@ -414,6 +416,7 @@ class DotHandler(OperatorHandler):
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
return compute_cost
@exception_handler
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@ -452,6 +455,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
@ -488,6 +492,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@ -521,6 +526,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
@ -554,6 +560,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
@ -587,6 +594,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
@ -620,6 +628,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
@ -653,6 +662,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

View File

@ -20,6 +20,7 @@ class ReshapeHandler(OperatorHandler):
return super()._generate_compute_cost(*args, **kwargs)
def register_strategy(self):
# TODO: add strategies with more output sharding specs other than only fully replicated.
input_node = self.strategies_vector.predecessor_nodes[0]
# For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in

View File

@ -70,6 +70,9 @@ class StrategiesVector(list):
# we could merge element-wise ops, because the output sharding spec is always the same as the input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
# we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
merge_label = True
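# e.g. `conv1 + 1`: the scalar 1 is not an fx Node, so the add node has a single
# predecessor and its sharding behaves exactly like an element-wise op.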
# we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True

View File

@ -1,4 +1,5 @@
from torch.fx import Graph, Node
from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@ -52,6 +53,7 @@ class StrategiesConstructor:
def build_strategies_and_cost(self):
for node in self.nodes:
strategies_vector = StrategiesVector(node)
input_nodes_len = len(strategies_vector.predecessor_nodes)
# placeholder node
if node.op == 'placeholder':
# For placeholder nodes, if solver_options.fast is True, we just let them in
@ -165,6 +167,9 @@ class StrategiesConstructor:
# MaxPool module
elif submod_type in POOL_MODULE_OP:
# TODO: add sharding constraints on image dimension
# e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
# create sharding strategy for element-wise module
assert len(strategies_vector.predecessor_nodes
) == 1, f'Temporarily, we only support single-input element-wise ops.'
@ -230,7 +235,7 @@ class StrategiesConstructor:
reshape_handler.register_strategy()
# element-wise function
elif target in ELEMENTWISE_FUNC_OP:
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
# TODO: integrate element-wise func and module together
# create sharding strategy for element-wise function
assert len(strategies_vector.predecessor_nodes
@ -271,6 +276,11 @@ class StrategiesConstructor:
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# bcast op
elif target in BCAST_FUNC_OP:
bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
bcast_op_handler.register_strategy()
# torch.var_mean
elif target == torch.var_mean:
dim = node.kwargs['dim']
@ -383,9 +393,8 @@ class StrategiesConstructor:
# clear the resharding cost for the output node
# TODO: we may remove this in final version
if True:
for prev_node, resharding_cost_list in resharding_costs.items():
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
for prev_node, resharding_cost_list in resharding_costs.items():
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec,

View File

@ -0,0 +1,71 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2)
def forward(self, x):
x1 = self.conv1(x)
x2 = x1 + 1
x1 = torch.reshape(x1, [1, -1, 64, 1])
x3 = self.conv2(x1)
x3 = torch.reshape(x3, [4, 1, 64, -1])
x = x1 + x3
return x
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv1 : [#users=2] = call_module[target=conv1](args = (%x,), kwargs = {})
# %add : [#users=0] = call_function[target=operator.add](args = (%conv1, 1), kwargs = {})
# %reshape : [#users=2] = call_function[target=torch.reshape](args = (%conv1, [1, -1, 64, 1]), kwargs = {})
# %conv2 : [#users=1] = call_module[target=conv2](args = (%reshape,), kwargs = {})
# %reshape_1 : [#users=1] = call_function[target=torch.reshape](args = (%conv2, [4, 1, 64, -1]), kwargs = {})
# %add_1 : [#users=1] = call_function[target=operator.add](args = (%reshape, %reshape_1), kwargs = {})
# return add_1
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
# [x, conv1, add, reshape, conv2, reshape_1, add_1, output]
nodes = [node for node in gm.graph.nodes]
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
strategy_map = strategies_constructor.strategy_map
# check a tensor add with a scalar case
conv1_strategies = strategy_map[nodes[1]]
add_strategies = strategy_map[nodes[2]]
add_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in add_strategies]
for strategy in conv1_strategies:
assert strategy.output_sharding_spec.sharding_sequence in add_strategies_cover_list
# check two tensors element-wise add case
add_1_strategies = strategy_map[nodes[6]]
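# 25 strategies are expected: for a 4-dim output on a 2x2 mesh, _enumerate_all_possible_output
# produces C(4,2) * 2 = 12 two-dim shardings, 4 * 3 = 12 one-dim shardings
# (S0, S1 and the flattened S01 per dim), plus 1 fully replicated spec.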
assert len(add_1_strategies) == 25
if __name__ == '__main__':
test_conv_handler()