mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add bcast op handler (#1600)
* [autoparallel] add bcast op handler
* polish code
* add more BCAST FUNC OP
* polish code
* add exception handler
* polish

pull/1606/head
parent 3abf98a633
commit eac1b79371
@@ -4,6 +4,10 @@ from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List, Optional
import warnings
from functools import reduce
import functools
import operator


def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@@ -29,6 +33,11 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
        raise TypeError(
            f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
        )
    for dim_index, sharding_index_list in dim_partition_dict.items():
        sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
        sharding_size = reduce(operator.mul, sharding_list, 1)
        assert shape[
            dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'

    sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
    return sharding_spec
@@ -74,3 +83,18 @@ def generate_resharding_costs(nodes: List[Node],
            resharding_cost = total_resharding_cost * size_per_elem_bytes
            resharding_costs[input_node].append(resharding_cost)
    return resharding_costs


def exception_handler(func):
    """
    A function wrapper which catches any exception raised by the wrapped function and reports it as a warning
    instead of letting it propagate.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            warnings.warn(f'{e}')

    return wrapper
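A minimal sketch (not part of the patch) of how the decorator above is expected to behave; the function name below is hypothetical. An exception raised inside a wrapped strategy method is downgraded to a warning, so an invalid sharding strategy is simply skipped instead of aborting the whole strategy search.

from colossalai.auto_parallel.solver._utils import exception_handler

@exception_handler
def register_invalid_strategy():
    # hypothetical strategy whose sharding assertion fails
    raise AssertionError('we cannot shard the 1 dimension of tensor into 3 partitions.')

register_invalid_strategy()        # the exception is reported via warnings.warn instead of raising
print('strategy enumeration continues')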
@@ -3,16 +3,19 @@ import operator

__all__ = [
    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
-    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP'
+    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP'
]

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
# TODO: flatten should not be added into this group
ELEMENTWISE_FUNC_OP = [
-    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
-    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
+    torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
+    torch.nn.functional.dropout, torch.flatten
]
RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
BCAST_FUNC_OP = [
    torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
    operator.mul, operator.floordiv, operator.truediv
]
CONV_MODULE_OP = [
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d
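Not part of the patch: the ops moved into BCAST_FUNC_OP are the binary arithmetic ops whose two operands may have different but broadcastable shapes, which is why they can no longer be treated as plain element-wise ops. A small illustration of the shape behaviour the new handler has to account for:

import operator
import torch

lhs = torch.rand(4, 1, 64, 1)
rhs = torch.rand(4, 1, 64, 64)
out = operator.add(lhs, rhs)          # broadcasting expands the size-one dims of lhs
print(out.shape)                      # torch.Size([4, 1, 64, 64])

scalar_case = operator.add(lhs, 1)    # scalar rhs: degenerates to the element-wise case
print(scalar_case.shape)              # torch.Size([4, 1, 64, 1])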
@@ -3,5 +3,6 @@ from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .batch_norm_handler import BatchNormHandler
from .reshape_handler import ReshapeHandler
from .bcast_op_handler import BcastOpHandler

-__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler']
+__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler']
@@ -1,9 +1,9 @@
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.solver._utils import exception_handler

__all__ = ['BatchNormHandler']

@@ -110,6 +110,7 @@ class BatchNormHandler(OperatorHandler):

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation

    @exception_handler
    def split_input_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

@@ -184,6 +185,7 @@ class BatchNormHandler(OperatorHandler):

        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'

@@ -224,6 +226,7 @@ class BatchNormHandler(OperatorHandler):

        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def non_split(self, mesh_dim_0, mesh_dim_1):
        name = f'RR = RR x R'

@@ -319,6 +322,7 @@ class BatchNormHandler(OperatorHandler):
            new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
            self.strategies_vector.append(new_sharding_strategy)

    @exception_handler
    def split_input_batch(self, mesh_dim_0):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'

@@ -359,6 +363,7 @@ class BatchNormHandler(OperatorHandler):

        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'

@@ -399,6 +404,7 @@ class BatchNormHandler(OperatorHandler):

        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
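A note on the strategy names used throughout these handlers, inferred from the sharding specs rather than stated in the patch itself: each position in a name such as 'RS0 = RS0 x S0' describes one tensor dimension, where R means the dimension is replicated across the device mesh and S{i} (or S{i}{j}) means it is sharded along mesh dimension i (or along the flattened mesh dimensions i and j). The string is only a readable label built from the mesh dimensions, for example:

mesh_dim_0, mesh_dim_1 = 0, 1
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
print(name)    # 'S0S1 = S0R x RS1': output sharded on both mesh dims,
               # lhs sharded on mesh dim 0, rhs sharded on mesh dim 1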
@@ -0,0 +1,164 @@
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.solver._utils import exception_handler

__all__ = ['BcastOpHandler']


class BcastOpHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of broadcast operators (such as operator.add).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert len(self.predecessor_node) == 2
        self.lhs_data = self.predecessor_node[0]._meta_data
        self.rhs_data = self.predecessor_node[1]._meta_data
        self.lhs = self.predecessor_node[0]
        self.rhs = self.predecessor_node[1]
        self.output_data = self.node._meta_data

    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
        shape = list(input_.shape)

        # pad the shape to the same length as output_data
        while len(shape) < self.output_data.dim():
            shape.insert(0, 1)
        shape = torch.Size(shape)

        # if the sharding happens on a size-one dimension, we should record it as R.
        processed_dim_partition_dict = deepcopy(dim_partition_dict)
        for dim_index, _ in dim_partition_dict.items():
            if shape[dim_index] == 1:
                processed_dim_partition_dict.pop(dim_index)
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=shape,
                                     dim_partition_dict=processed_dim_partition_dict)

        return sharding_spec

    def _generate_resharding_costs(self, sharding_specs):
        # The resharding_cost of weight is counted due to sharing weight cases.
        dtype = self.node._meta_data.dtype
        nodes = self.predecessor_node
        resharding_costs = {}
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # shape consistency manager is a singleton class
        shape_consistency_manager = ShapeConsistencyManager()

        for input_node, input_spec in zip(nodes, sharding_specs):
            resharding_costs[input_node] = []
            for strategy in input_node.strategies_vector:
                input_sharding_spec = strategy.output_sharding_spec
                assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
                # if the input shape is shorter than the target input, we pad the input to the same length as the target,
                # then use the padded input sharding spec to compute the resharding cost.
                if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
                    new_entire_shape = list(input_sharding_spec.entire_shape)
                    while len(new_entire_shape) < len(input_spec.entire_shape):
                        new_entire_shape.insert(0, 1)
                    new_entire_shape = torch.Size(new_entire_shape)
                    new_device_mesh = input_sharding_spec.device_mesh
                    new_dim_partition_dict = input_sharding_spec.dim_partition_dict
                    input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
                                                       entire_shape=new_entire_shape,
                                                       dim_partition_dict=new_dim_partition_dict)

                # compute the resharding cost during forward phase
                _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)

                _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
                    input_spec, input_sharding_spec)
                total_resharding_cost = resharding_cost_forward + resharding_cost_backward

                # we need to multiply by the element size of the dtype to get the correct communication cost
                resharding_cost = total_resharding_cost * size_per_elem_bytes
                resharding_costs[input_node].append(resharding_cost)

        return resharding_costs

    def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
        # use mesh_dim_0, mesh_dim_1 instead of the constants 0, 1 here for N-D device mesh scalability.

        output_sharding_spec_list = []
        output_dim_partition_list = []

        # enumerate all the 2D sharding cases
        for i in range(self.output_data.dim()):
            for j in range(i + 1, self.output_data.dim()):
                dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
                dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
                output_dim_partition_list.append(dim_partition_dict_0)
                output_dim_partition_list.append(dim_partition_dict_1)

        # enumerate all the 1D sharding cases
        for i in range(self.output_data.dim()):
            dim_partition_dict_0 = {i: [mesh_dim_0]}
            dim_partition_dict_1 = {i: [mesh_dim_1]}
            dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
            output_dim_partition_list.append(dim_partition_dict_0)
            output_dim_partition_list.append(dim_partition_dict_1)
            output_dim_partition_list.append(dim_partition_dict_flatten)

        # add an empty dict for the fully replicated case
        output_dim_partition_list.append({})
        check_duplicated_list = []
        for output_dim_partition_dict in output_dim_partition_list:
            output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
            sharding_seq = output_sharding_spec.sharding_sequence
            if sharding_seq not in check_duplicated_list:
                check_duplicated_list.append(sharding_seq)
                output_sharding_spec_list.append(output_sharding_spec)

        return output_sharding_spec_list

    def _generate_compute_cost(self, *args, **kwargs):
        return super()._generate_compute_cost(*args, **kwargs)

    @exception_handler
    def _register_strategy(self, output_sharding_spec):
        dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input)

        name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
        dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])

        # compute the computation cost of this strategy
        sharding_dims = []
        for mesh_dims in dim_partition_dict_for_output.values():
            for mesh_dim in mesh_dims:
                sharding_dims.append(self.device_mesh.shape[mesh_dim])
        sharding_size = reduce(operator.mul, sharding_dims, 1)
        memory_cost = self.output_data.numel() / sharding_size
        compute_cost = memory_cost
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=output_sharding_spec,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))

        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        output_sharding_specs = self._enumerate_all_possible_output(0, 1)
        for output_sharding_spec in output_sharding_specs:
            self._register_strategy(output_sharding_spec)
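Not part of the patch: a quick sanity check of how many output sharding specs _enumerate_all_possible_output can produce for a 4-dimensional output on a 2-D mesh, assuming no size-one output dimension collapses any of them into duplicates. The count agrees with the len(add_1_strategies) == 25 assertion in the test added further below.

def count_output_specs(ndim):
    # 2-D shardings: two mesh-dim assignments for every unordered pair of tensor dims
    two_d = 2 * (ndim * (ndim - 1) // 2)
    # 1-D shardings: S0, S1 and the flattened S01 for every tensor dim
    one_d = 3 * ndim
    # plus the fully replicated spec
    return two_d + one_d + 1

print(count_output_specs(4))    # 25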
@@ -4,13 +4,14 @@ import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.solver._utils import exception_handler

__all__ = ['ConvHandler']


class ConvHandler(OperatorHandler):
    """
-    A OperatorHandler which deals with the sharding strategies of Convolution.
+    An OperatorHandler which deals with the sharding strategies of Convolution.
    """

    def __init__(self, *args, **kwargs):

@@ -104,6 +105,7 @@ class ConvHandler(OperatorHandler):

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

    @exception_handler
    def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

@@ -151,6 +153,7 @@ class ConvHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_input_batch(self, mesh_dim_0):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

@@ -196,6 +199,7 @@ class ConvHandler(OperatorHandler):

        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

@@ -241,6 +245,7 @@ class ConvHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

@@ -283,6 +288,7 @@ class ConvHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
        name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

@@ -325,6 +331,7 @@ class ConvHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_weight_out_channel(self, mesh_dim_0):
        name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

@@ -367,6 +374,7 @@ class ConvHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def non_split(self):
        name = f'RR = RR x RR'

@@ -407,6 +415,7 @@ class ConvHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

@@ -454,6 +463,7 @@ class ConvHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

@@ -554,48 +564,24 @@ class ConvHandler(OperatorHandler):
        RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
        '''
        # SS = SR x RS
-        try:
-            self.split_input_batch_weight_out_channel(0, 1)
-        except Exception as e:
-            warnings.warn(f'{e}')
-        try:
-            self.split_input_batch_weight_out_channel(1, 0)
-        except Exception as e:
-            warnings.warn(f'{e}')
+        self.split_input_batch_weight_out_channel(0, 1)
+        self.split_input_batch_weight_out_channel(1, 0)

        # SR = SR x RR
        self.split_input_batch(0)
        self.split_input_batch(1)

        # SR = SS x SR
-        try:
-            self.split_input_both_dim_weight_in_channel(0, 1)
-        except Exception as e:
-            warnings.warn(f'{e}')
-        try:
-            self.split_input_both_dim_weight_in_channel(1, 0)
-        except Exception as e:
-            warnings.warn(f'{e}')
+        self.split_input_both_dim_weight_in_channel(0, 1)
+        self.split_input_both_dim_weight_in_channel(1, 0)

        # RS = RS x SS
-        try:
-            self.split_input_in_channel_weight_both_channel(0, 1)
-        except Exception as e:
-            warnings.warn(f'{e}')
-        try:
-            self.split_input_in_channel_weight_both_channel(1, 0)
-        except Exception as e:
-            warnings.warn(f'{e}')
+        self.split_input_in_channel_weight_both_channel(0, 1)
+        self.split_input_in_channel_weight_both_channel(1, 0)

        # RR = RS x SR
-        try:
-            self.split_input_in_channel_weight_in_channel(0)
-        except Exception as e:
-            warnings.warn(f'{e}')
-        try:
-            self.split_input_in_channel_weight_in_channel(1)
-        except Exception as e:
-            warnings.warn(f'{e}')
+        self.split_input_in_channel_weight_in_channel(0)
+        self.split_input_in_channel_weight_in_channel(1)

        # RS = RR x RS
        self.split_weight_out_channel(0)

@@ -608,12 +594,7 @@ class ConvHandler(OperatorHandler):
        self.split_1d_parallel_on_input_batch(0, 1)

        # RR = RS01 x S01R
-        try:
-            self.split_1d_parallel_on_in_channel(0, 1)
-        except Exception as e:
-            warnings.warn(f'{e}')
-
-        # print(f'strategies num is :{len(self.strategies_vector)}')
+        self.split_1d_parallel_on_in_channel(0, 1)

        return self.strategies_vector
@@ -6,10 +6,12 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
from .operator_handler import OperatorHandler
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from functools import reduce
from colossalai.auto_parallel.solver._utils import exception_handler
from enum import Enum
from .strategy_generator import StrategyGenerator, IntermediateStrategy
from typing import List

__all__ = ['DotHandler']

@@ -414,6 +416,7 @@ class DotHandler(OperatorHandler):
        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
        return compute_cost

    @exception_handler
    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        # handle case SS = SR x RS
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

@@ -452,6 +455,7 @@ class DotHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case SR = SS x SR
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

@@ -488,6 +492,7 @@ class DotHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

@@ -521,6 +526,7 @@ class DotHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def recompute_split_both_contract(self, mesh_dim):
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

@@ -554,6 +560,7 @@ class DotHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_rhs_space_only(self, mesh_dim):
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

@@ -587,6 +594,7 @@ class DotHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

@@ -620,6 +628,7 @@ class DotHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

@@ -653,6 +662,7 @@ class DotHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
@@ -20,6 +20,7 @@ class ReshapeHandler(OperatorHandler):
        return super()._generate_compute_cost(*args, **kwargs)

    def register_strategy(self):
        # TODO: add strategies with more output sharding specs other than only fully replicated.
        input_node = self.strategies_vector.predecessor_nodes[0]
        # For reshape function, to keep the computing correctness we keep the sharding
        # spec of input is fully replicated. In addition, we will keep the output in
@@ -70,6 +70,9 @@ class StrategiesVector(list):
        # we could merge element-wise ops, because the output sharding spec is always the same as the input sharding spec.
        if self.node.target in ELEMENTWISE_FUNC_OP:
            merge_label = True
        # we could merge a bcast op if the rhs is a scalar, because it falls back to the element-wise case.
        if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
            merge_label = True
        # we could merge reshape ops, because the output sharding spec of a reshape op is always fully replicated.
        if self.node.target in RESHAPE_FUNC_OP:
            merge_label = True
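Not part of the patch: a sketch of why a broadcast op with a scalar operand has only one predecessor node in the traced graph and can therefore be merged like an element-wise op. The module below is hypothetical; it mirrors the x1 + 1 node in the test added at the end of this commit.

import operator
import torch
from colossalai.fx.tracer.tracer import ColoTracer

class AddScalar(torch.nn.Module):
    def forward(self, x):
        return x + 1    # traced as call_function[target=operator.add] with a constant rhs

graph = ColoTracer().trace(root=AddScalar(), meta_args={'x': torch.rand(4, 8).to('meta')})
add_node = [n for n in graph.nodes if n.target is operator.add][0]
# the scalar 1 is baked into the node's args rather than becoming a graph node,
# so the add node has a single tensor predecessor and behaves like an element-wise op
print(len(add_node.all_input_nodes))    # 1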
@@ -1,4 +1,5 @@
from torch.fx import Graph, Node
from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager

@@ -52,6 +53,7 @@ class StrategiesConstructor:
    def build_strategies_and_cost(self):
        for node in self.nodes:
            strategies_vector = StrategiesVector(node)
            input_nodes_len = len(strategies_vector.predecessor_nodes)
            # placeholder node
            if node.op == 'placeholder':
                # For placeholder nodes, if solver_options.fast is True, we just let them in

@@ -165,6 +167,9 @@ class StrategiesConstructor:

            # MaxPool module
            elif submod_type in POOL_MODULE_OP:
                # TODO: add sharding constraints on the image dimensions
                # e.g.: for a 2D pooling input NCHW, we should ensure no sharding happens on the H and W dimensions

                # create sharding strategy for element-wise module
                assert len(strategies_vector.predecessor_nodes
                           ) == 1, f'Temporarily, we just support single input element-wise op.'

@@ -230,7 +235,7 @@ class StrategiesConstructor:
                reshape_handler.register_strategy()

            # element-wise function
-            elif target in ELEMENTWISE_FUNC_OP:
+            elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
                # TODO: integrate element-wise func and module together
                # create sharding strategy for element-wise function
                assert len(strategies_vector.predecessor_nodes

@@ -271,6 +276,11 @@ class StrategiesConstructor:
                                                     input_shardings=[input_sharding_spec])
                strategies_vector.append(sharding_strategy)

            # bcast op
            elif target in BCAST_FUNC_OP:
                bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
                bcast_op_handler.register_strategy()

            # torch.var_mean
            elif target == torch.var_mean:
                dim = node.kwargs['dim']

@@ -383,9 +393,8 @@ class StrategiesConstructor:

                # clear the resharding cost for the output node
                # TODO: we may remove this in final version
-                if True:
-                    for prev_node, resharding_cost_list in resharding_costs.items():
-                        resharding_costs[prev_node] = [0] * len(resharding_cost_list)
+                for prev_node, resharding_cost_list in resharding_costs.items():
+                    resharding_costs[prev_node] = [0] * len(resharding_cost_list)

                sharding_strategy_attribute = ShardingStrategy(name,
                                                               output_sharding_spec,
@@ -0,0 +1,71 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh


class ConvModel(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = x1 + 1
        x1 = torch.reshape(x1, [1, -1, 64, 1])
        x3 = self.conv2(x1)
        x3 = torch.reshape(x3, [4, 1, 64, -1])
        x = x1 + x3

        return x


def test_conv_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    tracer = ColoTracer()
    model = ConvModel(16, 32)
    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %conv1 : [#users=2] = call_module[target=conv1](args = (%x,), kwargs = {})
    #     %add : [#users=0] = call_function[target=operator.add](args = (%conv1, 1), kwargs = {})
    #     %reshape : [#users=2] = call_function[target=torch.reshape](args = (%conv1, [1, -1, 64, 1]), kwargs = {})
    #     %conv2 : [#users=1] = call_module[target=conv2](args = (%reshape,), kwargs = {})
    #     %reshape_1 : [#users=1] = call_function[target=torch.reshape](args = (%conv2, [4, 1, 64, -1]), kwargs = {})
    #     %add_1 : [#users=1] = call_function[target=operator.add](args = (%reshape, %reshape_1), kwargs = {})
    #     return add_1
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    # [x, conv1, add, reshape, conv2, reshape_1, add_1, output]
    nodes = [node for node in gm.graph.nodes]
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)

    strategies_constructor.build_strategies_and_cost()
    strategy_map = strategies_constructor.strategy_map
    # check a tensor add with a scalar case
    conv1_strategies = strategy_map[nodes[1]]
    add_strategies = strategy_map[nodes[2]]
    add_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in add_strategies]
    for strategy in conv1_strategies:
        assert strategy.output_sharding_spec.sharding_sequence in add_strategies_cover_list

    # check two tensors element-wise add case
    add_1_strategies = strategy_map[nodes[6]]
    assert len(add_1_strategies) == 25


if __name__ == '__main__':
    test_conv_handler()