mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add bcast op handler (#1600)
* [autoparallel] add bcast op handler
* polish code
* add more BCAST FUNC OP
* polish code
* add exception handler
* polish

pull/1606/head
parent 3abf98a633
commit eac1b79371
@ -4,6 +4,10 @@ from torch.fx.node import Node
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Union, Dict, List, Optional
+import warnings
+from functools import reduce
+import functools
+import operator


 def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@ -29,6 +33,11 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
         raise TypeError(
             f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
         )
+    for dim_index, sharding_index_list in dim_partition_dict.items():
+        sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
+        sharding_size = reduce(operator.mul, sharding_list, 1)
+        assert shape[
+            dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
+
     sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
     return sharding_spec
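
For context on the new check: sharding one tensor dimension over several device-mesh dimensions partitions it into the product of those mesh extents, so that dimension must be divisible by the product. A minimal standalone sketch (assuming a hypothetical 2 x 2 mesh, not taken from the diff):

    import operator
    from functools import reduce

    mesh_shape = (2, 2)                  # hypothetical 2 x 2 device mesh
    dim_partition_dict = {0: [0, 1]}     # shard tensor dim 0 across both mesh dims
    shape = (16, 8)

    for dim_index, sharding_index_list in dim_partition_dict.items():
        sharding_list = [mesh_shape[i] for i in sharding_index_list]
        sharding_size = reduce(operator.mul, sharding_list, 1)    # 2 * 2 = 4
        assert shape[dim_index] % sharding_size == 0              # 16 % 4 == 0, legal sharding
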
@ -74,3 +83,18 @@ def generate_resharding_costs(nodes: List[Node],
             resharding_cost = total_resharding_cost * size_per_elem_bytes
             resharding_costs[input_node].append(resharding_cost)
     return resharding_costs
+
+
+def exception_handler(func):
+    """
+    A function wrapper which catches any exception raised by the wrapped function and converts it into a warning.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except Exception as e:
+            warnings.warn(f'{e}')
+
+    return wrapper
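
A small usage sketch of the new decorator (the function name below is hypothetical): any exception raised inside a decorated strategy method is downgraded to a warning, so one illegal sharding no longer aborts the whole strategy registration.

    import warnings
    from colossalai.auto_parallel.solver._utils import exception_handler

    @exception_handler
    def register_illegal_strategy():
        # e.g. the divisibility assert above fails for this sharding
        raise AssertionError('we cannot shard the 1 dimension of tensor into 4 partitions.')

    register_illegal_strategy()    # emits a UserWarning instead of raising
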
@ -3,16 +3,19 @@ import operator

 __all__ = [
     'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
-    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP'
+    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP'
 ]

 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
-# TODO: flatten should not be added into this group
 ELEMENTWISE_FUNC_OP = [
-    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
-    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
+    torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
+    torch.nn.functional.dropout, torch.flatten
 ]
 RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
+BCAST_FUNC_OP = [
+    torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
+    operator.mul, operator.floordiv, operator.truediv
+]
 CONV_MODULE_OP = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
     torch.nn.ConvTranspose3d
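
For intuition on why these ops get their own group: unlike the element-wise ops above, the BCAST_FUNC_OP entries follow PyTorch broadcasting rules and may combine operands of different ranks, which is why the new handler pads shapes before generating sharding specs. A quick illustration:

    import torch

    lhs = torch.rand(4, 1, 64, 64)
    rhs = torch.rand(64, 64)             # lower-rank operand
    out = torch.add(lhs, rhs)            # rhs is broadcast to (4, 1, 64, 64)
    print(out.shape)                     # torch.Size([4, 1, 64, 64])
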
@ -3,5 +3,6 @@ from .dot_handler import DotHandler
 from .conv_handler import ConvHandler
 from .batch_norm_handler import BatchNormHandler
 from .reshape_handler import ReshapeHandler
+from .bcast_op_handler import BcastOpHandler

-__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler']
+__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler']
@ -1,9 +1,9 @@
 import operator
 from functools import reduce
-import warnings
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
+from colossalai.auto_parallel.solver._utils import exception_handler

 __all__ = ['BatchNormHandler']

@ -110,6 +110,7 @@ class BatchNormHandler(OperatorHandler):

         return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation

+    @exception_handler
     def split_input_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

@ -184,6 +185,7 @@ class BatchNormHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'

@ -224,6 +226,7 @@ class BatchNormHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def non_split(self, mesh_dim_0, mesh_dim_1):
         name = f'RR = RR x R'

@ -319,6 +322,7 @@ class BatchNormHandler(OperatorHandler):
             new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
             self.strategies_vector.append(new_sharding_strategy)

+    @exception_handler
     def split_input_batch(self, mesh_dim_0):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'

@ -359,6 +363,7 @@ class BatchNormHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'

@ -399,6 +404,7 @@ class BatchNormHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

@ -0,0 +1,164 @@
+import operator
+from functools import reduce
+import warnings
+import torch
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+from copy import deepcopy
+from typing import Dict, List
+from colossalai.auto_parallel.solver._utils import exception_handler
+
+__all__ = ['BcastOpHandler']
+
+
+class BcastOpHandler(OperatorHandler):
+    """
+    An OperatorHandler which deals with the sharding strategies of broadcast operators (such as operator.add).
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert len(self.predecessor_node) == 2
+        self.lhs_data = self.predecessor_node[0]._meta_data
+        self.rhs_data = self.predecessor_node[1]._meta_data
+        self.lhs = self.predecessor_node[0]
+        self.rhs = self.predecessor_node[1]
+        self.output_data = self.node._meta_data
+
+    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+        shape = list(input_.shape)
+
+        # padding the shape to the same length as output_data
+        while len(shape) < self.output_data.dim():
+            shape.insert(0, 1)
+        shape = torch.Size(shape)
+
+        # if the sharding happens on a size one dimension, we should record it as R.
+        processed_dim_partition_dict = deepcopy(dim_partition_dict)
+        for dim_index, _ in dim_partition_dict.items():
+            if shape[dim_index] == 1:
+                processed_dim_partition_dict.pop(dim_index)
+        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                     entire_shape=shape,
+                                     dim_partition_dict=processed_dim_partition_dict)
+
+        return sharding_spec
+
+    def _generate_resharding_costs(self, sharding_specs):
+        # The resharding_cost of weight is counted due to sharing weight cases.
+        dtype = self.node._meta_data.dtype
+        nodes = self.predecessor_node
+        resharding_costs = {}
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+        # shape consistency manager is a singleton class
+        shape_consistency_manager = ShapeConsistencyManager()
+
+        for input_node, input_spec in zip(nodes, sharding_specs):
+            resharding_costs[input_node] = []
+            for strategy in input_node.strategies_vector:
+                input_sharding_spec = strategy.output_sharding_spec
+                assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
+                # if the input shape is smaller than the target input, we will fill the input to the same length as target.
+                # Then, use the padded input sharding spec to compute the resharding cost.
+                if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
+                    new_entire_shape = list(input_sharding_spec.entire_shape)
+                    while len(new_entire_shape) < len(input_spec.entire_shape):
+                        new_entire_shape.insert(0, 1)
+                    new_entire_shape = torch.Size(new_entire_shape)
+                    new_device_mesh = input_sharding_spec.device_mesh
+                    new_dim_partition_dict = input_sharding_spec.dim_partition_dict
+                    input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
+                                                       entire_shape=new_entire_shape,
+                                                       dim_partition_dict=new_dim_partition_dict)
+
+                # compute the resharding cost during forward phase
+                _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(
+                    input_sharding_spec, input_spec)
+
+                _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
+                    input_spec, input_sharding_spec)
+                total_resharding_cost = resharding_cost_forward + resharding_cost_backward
+
+                # we need to multiply the size of elem dtype to get correct communication cost
+                resharding_cost = total_resharding_cost * size_per_elem_bytes
+                resharding_costs[input_node].append(resharding_cost)
+
+        return resharding_costs
+
+    def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
+        # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scalability.
+
+        output_sharding_spec_list = []
+        output_dim_partition_list = []
+
+        # enumerate all the 2D sharding cases
+        for i in range(self.output_data.dim()):
+            for j in range(i + 1, self.output_data.dim()):
+                dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
+                dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
+                output_dim_partition_list.append(dim_partition_dict_0)
+                output_dim_partition_list.append(dim_partition_dict_1)
+
+        # enumerate all the 1D sharding cases
+        for i in range(self.output_data.dim()):
+            dim_partition_dict_0 = {i: [mesh_dim_0]}
+            dim_partition_dict_1 = {i: [mesh_dim_1]}
+            dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
+            output_dim_partition_list.append(dim_partition_dict_0)
+            output_dim_partition_list.append(dim_partition_dict_1)
+            output_dim_partition_list.append(dim_partition_dict_flatten)
+
+        # add empty dict for fully replicated case
+        output_dim_partition_list.append({})
+        check_duplicated_list = []
+        for output_dim_partition_dict in output_dim_partition_list:
+            output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
+            sharding_seq = output_sharding_spec.sharding_sequence
+            if sharding_seq not in check_duplicated_list:
+                check_duplicated_list.append(sharding_seq)
+                output_sharding_spec_list.append(output_sharding_spec)
+
+        return output_sharding_spec_list
+
+    def _generate_compute_cost(self, *args, **kwargs):
+        return super()._generate_compute_cost(*args, **kwargs)
+
+    @exception_handler
+    def _register_strategy(self, output_sharding_spec):
+        dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
+        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
+        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input)
+
+        name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
+        dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict
+
+        # generate resharding cost for this strategy
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
+
+        # compute the computation cost of this strategy
+        sharding_dims = []
+        for mesh_dims in dim_partition_dict_for_output.values():
+            for mesh_dim in mesh_dims:
+                sharding_dims.append(self.device_mesh.shape[mesh_dim])
+        sharding_size = reduce(operator.mul, sharding_dims, 1)
+        memory_cost = self.output_data.numel() / sharding_size
+        compute_cost = memory_cost
+        communication_cost = 0
+
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=output_sharding_spec,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
+
+        self.strategies_vector.append(sharding_strategies)
+
+    def register_strategy(self) -> StrategiesVector:
+        output_sharding_specs = self._enumerate_all_possible_output(0, 1)
+        for output_sharding_spec in output_sharding_specs:
+            self._register_strategy(output_sharding_spec)

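As a sanity check on _enumerate_all_possible_output (a standalone re-count, not the handler itself): for a 4-dimensional output on a 2-dimensional mesh it proposes 2·C(4,2) = 12 two-dimensional shardings, 3·4 = 12 one-dimensional shardings, and 1 fully replicated spec, i.e. 25 candidates before de-duplication, which matches the 25 strategies asserted for the add_1 node in the new test at the end of this commit.

    ndim, mesh_dim_0, mesh_dim_1 = 4, 0, 1    # 4-dim output tensor, 2 mesh dims
    candidates = []

    # 2D shardings: two different tensor dims, both assignments of the two mesh dims
    for i in range(ndim):
        for j in range(i + 1, ndim):
            candidates.append({i: [mesh_dim_0], j: [mesh_dim_1]})
            candidates.append({i: [mesh_dim_1], j: [mesh_dim_0]})

    # 1D shardings: one tensor dim on mesh dim 0, mesh dim 1, or the flattened 1D mesh
    for i in range(ndim):
        candidates.extend([{i: [mesh_dim_0]}, {i: [mesh_dim_1]}, {i: [mesh_dim_0, mesh_dim_1]}])

    candidates.append({})      # fully replicated
    print(len(candidates))     # 25
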
@ -4,13 +4,14 @@ import warnings
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
+from colossalai.auto_parallel.solver._utils import exception_handler

 __all__ = ['ConvHandler']


 class ConvHandler(OperatorHandler):
     """
-    A OperatorHandler which deals with the sharding strategies of Convolution.
+    An OperatorHandler which deals with the sharding strategies of Convolution.
     """

     def __init__(self, *args, **kwargs):
@ -104,6 +105,7 @@ class ConvHandler(OperatorHandler):

         return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

+    @exception_handler
     def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

@ -151,6 +153,7 @@ class ConvHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_input_batch(self, mesh_dim_0):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

@ -196,6 +199,7 @@ class ConvHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

@ -241,6 +245,7 @@ class ConvHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

@ -283,6 +288,7 @@ class ConvHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
         name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

@ -325,6 +331,7 @@ class ConvHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_weight_out_channel(self, mesh_dim_0):
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

@ -367,6 +374,7 @@ class ConvHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def non_split(self):
         name = f'RR = RR x RR'

@ -407,6 +415,7 @@ class ConvHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

@ -454,6 +463,7 @@ class ConvHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

@ -554,48 +564,24 @@ class ConvHandler(OperatorHandler):
         RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
         '''
         # SS = SR x RS
-        try:
-            self.split_input_batch_weight_out_channel(0, 1)
-        except Exception as e:
-            warnings.warn(f'{e}')
-        try:
-            self.split_input_batch_weight_out_channel(1, 0)
-        except Exception as e:
-            warnings.warn(f'{e}')
+        self.split_input_batch_weight_out_channel(0, 1)
+        self.split_input_batch_weight_out_channel(1, 0)

         # SR = SR x RR
         self.split_input_batch(0)
         self.split_input_batch(1)

         # SR = SS x SR
-        try:
-            self.split_input_both_dim_weight_in_channel(0, 1)
-        except Exception as e:
-            warnings.warn(f'{e}')
-        try:
-            self.split_input_both_dim_weight_in_channel(1, 0)
-        except Exception as e:
-            warnings.warn(f'{e}')
+        self.split_input_both_dim_weight_in_channel(0, 1)
+        self.split_input_both_dim_weight_in_channel(1, 0)

         # RS = RS x SS
-        try:
-            self.split_input_in_channel_weight_both_channel(0, 1)
-        except Exception as e:
-            warnings.warn(f'{e}')
-        try:
-            self.split_input_in_channel_weight_both_channel(1, 0)
-        except Exception as e:
-            warnings.warn(f'{e}')
+        self.split_input_in_channel_weight_both_channel(0, 1)
+        self.split_input_in_channel_weight_both_channel(1, 0)

         # RR = RS x SR
-        try:
-            self.split_input_in_channel_weight_in_channel(0)
-        except Exception as e:
-            warnings.warn(f'{e}')
-        try:
-            self.split_input_in_channel_weight_in_channel(1)
-        except Exception as e:
-            warnings.warn(f'{e}')
+        self.split_input_in_channel_weight_in_channel(0)
+        self.split_input_in_channel_weight_in_channel(1)

         # RS = RR x RS
         self.split_weight_out_channel(0)

@ -608,12 +594,7 @@ class ConvHandler(OperatorHandler):
         self.split_1d_parallel_on_input_batch(0, 1)

         # RR = RS01 x S01R
-        try:
-            self.split_1d_parallel_on_in_channel(0, 1)
-        except Exception as e:
-            warnings.warn(f'{e}')
-
-        # print(f'strategies num is :{len(self.strategies_vector)}')
+        self.split_1d_parallel_on_in_channel(0, 1)

         return self.strategies_vector

@ -6,10 +6,12 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
 from .operator_handler import OperatorHandler
 from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
 from functools import reduce
+from colossalai.auto_parallel.solver._utils import exception_handler
 from enum import Enum
 from .strategy_generator import StrategyGenerator, IntermediateStrategy
 from typing import List


 __all__ = ['DotHandler']

@ -414,6 +416,7 @@ class DotHandler(OperatorHandler):
         compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
         return compute_cost

+    @exception_handler
     def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
         # handle case SS = SR x RS
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

@ -452,6 +455,7 @@ class DotHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
         # handle the case SR = SS x SR
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

@ -488,6 +492,7 @@ class DotHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

@ -521,6 +526,7 @@ class DotHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def recompute_split_both_contract(self, mesh_dim):
         name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

@ -554,6 +560,7 @@ class DotHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_rhs_space_only(self, mesh_dim):
         name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

@ -587,6 +594,7 @@ class DotHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

@ -620,6 +628,7 @@ class DotHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

@ -653,6 +662,7 @@ class DotHandler(OperatorHandler):
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

@ -20,6 +20,7 @@ class ReshapeHandler(OperatorHandler):
         return super()._generate_compute_cost(*args, **kwargs)

     def register_strategy(self):
+        # TODO: add strategies with more output sharding specs other than only fully replicated.
         input_node = self.strategies_vector.predecessor_nodes[0]
         # For reshape function, to keep the computing correctness we keep the sharding
         # spec of input is fully replicated. In addition, we will keep the output in

@ -70,6 +70,9 @@ class StrategiesVector(list):
         # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
         if self.node.target in ELEMENTWISE_FUNC_OP:
             merge_label = True
+        # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
+        if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
+            merge_label = True
         # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
         if self.node.target in RESHAPE_FUNC_OP:
             merge_label = True
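
To make the new condition concrete (a standalone torch.fx sketch; all_input_nodes is the raw fx counterpart of predecessor_nodes here): when the rhs of a broadcast op is a Python scalar, the traced call_function node has a single tensor predecessor, so it degenerates to the element-wise case and can be merged.

    import operator
    import torch
    from torch.fx import symbolic_trace

    class Scale(torch.nn.Module):
        def forward(self, x):
            return x + 1                          # scalar rhs

    gm = symbolic_trace(Scale())
    add_node = [n for n in gm.graph.nodes if n.target is operator.add][0]
    print(len(add_node.all_input_nodes))          # 1 -> falls back to the element-wise case
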
@ -1,4 +1,5 @@
 from torch.fx import Graph, Node
+from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager

@ -52,6 +53,7 @@ class StrategiesConstructor:
     def build_strategies_and_cost(self):
         for node in self.nodes:
             strategies_vector = StrategiesVector(node)
+            input_nodes_len = len(strategies_vector.predecessor_nodes)
             # placeholder node
             if node.op == 'placeholder':
                 # For placeholder nodes, if solver_options.fast is True, we just let them in

@ -165,6 +167,9 @@ class StrategiesConstructor:

                 # MaxPool module
                 elif submod_type in POOL_MODULE_OP:
+                    # TODO: add sharding constraints on image dimension
+                    # e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
+
                     # create sharding strategy for element-wise module
                     assert len(strategies_vector.predecessor_nodes
                                ) == 1, f'Temporally, we just support single input element-wise op.'

@ -230,7 +235,7 @@ class StrategiesConstructor:
                     reshape_handler.register_strategy()

                 # element-wise function
-                elif target in ELEMENTWISE_FUNC_OP:
+                elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
                     # TODO: integrate element-wise func and module together
                     # create sharding strategy for element-wise function
                     assert len(strategies_vector.predecessor_nodes

@ -271,6 +276,11 @@ class StrategiesConstructor:
                                                          input_shardings=[input_sharding_spec])
                     strategies_vector.append(sharding_strategy)

+                # bcast op
+                elif target in BCAST_FUNC_OP:
+                    bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
+                    bcast_op_handler.register_strategy()
+
                 # torch.var_mean
                 elif target == torch.var_mean:
                     dim = node.kwargs['dim']

@ -383,9 +393,8 @@ class StrategiesConstructor:

                 # clear the resharding cost for the output node
                 # TODO: we may remove this in final version
-                if True:
-                    for prev_node, resharding_cost_list in resharding_costs.items():
-                        resharding_costs[prev_node] = [0] * len(resharding_cost_list)
+                for prev_node, resharding_cost_list in resharding_costs.items():
+                    resharding_costs[prev_node] = [0] * len(resharding_cost_list)

                 sharding_strategy_attribute = ShardingStrategy(name,
                                                                output_sharding_spec,
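
Putting the constructor changes together, a simplified dispatch sketch (the stand-in op lists below are trimmed copies of the constants above, and choose_handler is a hypothetical name): scalar-rhs broadcast calls reuse the element-wise path, while genuine two-tensor broadcast calls go to BcastOpHandler.

    import operator
    import torch

    ELEMENTWISE_FUNC_OP = [torch.abs, torch.nn.functional.relu]    # trimmed stand-in
    BCAST_FUNC_OP = [torch.add, operator.add, operator.mul]        # trimmed stand-in

    def choose_handler(target, input_nodes_len):
        if target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
            return 'element-wise strategy branch'
        if target in BCAST_FUNC_OP:
            return 'BcastOpHandler.register_strategy()'
        return 'other handler'

    print(choose_handler(operator.add, 1))    # x2 = x1 + 1  -> element-wise strategy branch
    print(choose_handler(operator.add, 2))    # x = x1 + x3  -> BcastOpHandler.register_strategy()
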
@ -0,0 +1,71 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+
+from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class ConvModel(nn.Module):
+
+    def __init__(self, c_in, c_out):
+        super().__init__()
+        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2)
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = x1 + 1
+        x1 = torch.reshape(x1, [1, -1, 64, 1])
+        x3 = self.conv2(x1)
+        x3 = torch.reshape(x3, [4, 1, 64, -1])
+        x = x1 + x3
+
+        return x
+
+
+def test_conv_handler():
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+
+    tracer = ColoTracer()
+    model = ConvModel(16, 32)
+    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %conv1 : [#users=2] = call_module[target=conv1](args = (%x,), kwargs = {})
+    #     %add : [#users=0] = call_function[target=operator.add](args = (%conv1, 1), kwargs = {})
+    #     %reshape : [#users=2] = call_function[target=torch.reshape](args = (%conv1, [1, -1, 64, 1]), kwargs = {})
+    #     %conv2 : [#users=1] = call_module[target=conv2](args = (%reshape,), kwargs = {})
+    #     %reshape_1 : [#users=1] = call_function[target=torch.reshape](args = (%conv2, [4, 1, 64, -1]), kwargs = {})
+    #     %add_1 : [#users=1] = call_function[target=operator.add](args = (%reshape, %reshape_1), kwargs = {})
+    #     return add_1
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    # [x, conv1, add, reshape, conv2, reshape_1, add_1, output]
+    nodes = [node for node in gm.graph.nodes]
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+
+    strategies_constructor.build_strategies_and_cost()
+    strategy_map = strategies_constructor.strategy_map
+    # check a tensor add with a scalar case
+    conv1_strategies = strategy_map[nodes[1]]
+    add_strategies = strategy_map[nodes[2]]
+    add_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in add_strategies]
+    for strategy in conv1_strategies:
+        assert strategy.output_sharding_spec.sharding_sequence in add_strategies_cover_list
+
+    # check two tensors element-wise add case
+    add_1_strategies = strategy_map[nodes[6]]
+    assert len(add_1_strategies) == 25
+
+
+if __name__ == '__main__':
+    test_conv_handler()
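
Since the new test guards test_conv_handler with if __name__ == '__main__', it can be executed directly with Python; it also imports pytest, so it can equally be collected by a pytest run (the file's path is not shown in this diff).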