diff --git a/colossalai/auto_parallel/solver/constants.py b/colossalai/auto_parallel/solver/constants.py
index 773a5a566..3360f9425 100644
--- a/colossalai/auto_parallel/solver/constants.py
+++ b/colossalai/auto_parallel/solver/constants.py
@@ -2,16 +2,17 @@ import torch
 import operator
 
 __all__ = [
-    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
-    'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP'
+    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
+    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP'
 ]
 
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
 # TODO: flatten should not be added into this group
 ELEMENTWISE_FUNC_OP = [
     torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
-    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten
+    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
 ]
+RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
 CONV_MODULE_OP = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
     torch.nn.ConvTranspose3d
@@ -23,5 +24,6 @@ LINEAR_MODULE_OP = [torch.nn.Linear]
 LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
 BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
 POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
+NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
 
 INFINITY_COST = 1e13
diff --git a/colossalai/auto_parallel/solver/op_handler/__init__.py b/colossalai/auto_parallel/solver/op_handler/__init__.py
index 012acffe4..1f31ca45a 100644
--- a/colossalai/auto_parallel/solver/op_handler/__init__.py
+++ b/colossalai/auto_parallel/solver/op_handler/__init__.py
@@ -2,5 +2,6 @@ from .operator_handler import OperatorHandler
 from .dot_handler import DotHandler
 from .conv_handler import ConvHandler
 from .batch_norm_handler import BatchNormHandler
+from .reshape_handler import ReshapeHandler
 
-__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler']
\ No newline at end of file
+__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler']
\ No newline at end of file
diff --git a/colossalai/auto_parallel/solver/op_handler/operator_handler.py b/colossalai/auto_parallel/solver/op_handler/operator_handler.py
index 44b4d8217..8db91ffef 100644
--- a/colossalai/auto_parallel/solver/op_handler/operator_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/operator_handler.py
@@ -8,6 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 from .._utils import generate_resharding_costs, generate_sharding_spec
+from colossalai.auto_parallel.solver.constants import *
 
 from ..sharding_strategy import StrategiesVector
 
@@ -44,7 +45,7 @@ class OperatorHandler(ABC):
             named_parameters = list(module.named_parameters(recurse=False))
             # convert named parameters from list to dict
             named_parameters = {k: v for k, v in named_parameters}
-        elif self.node.op == 'call_function':
+        elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
             module = None
             parameters = list(self.node.args)[1]
             named_parameters = {'weight': parameters._meta_data}
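Aside (not part of the patch): the `NON_PARAM_FUNC_OP` guard above is needed because parameter-free `call_function` nodes such as `torch.flatten` or `torch.nn.functional.relu` carry only the activation tensor in `node.args`, so the `list(self.node.args)[1]` weight lookup would fail for them. A minimal FX sketch with a hypothetical `Toy` module illustrates this:

```python
import torch
import torch.fx as fx


class Toy(torch.nn.Module):

    def forward(self, x):
        # relu is element-wise, flatten is a reshape op; both land in NON_PARAM_FUNC_OP
        return torch.flatten(torch.nn.functional.relu(x))


graph = fx.symbolic_trace(Toy()).graph
for node in graph.nodes:
    if node.op == 'call_function':
        # prints e.g. "relu (x,)" and "flatten (relu,)": no weight sits at args[1]
        print(node.target.__name__, node.args)
```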
diff --git a/colossalai/auto_parallel/solver/op_handler/reshape_handler.py b/colossalai/auto_parallel/solver/op_handler/reshape_handler.py
new file mode 100644
index 000000000..cb9e4e6ea
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/reshape_handler.py
@@ -0,0 +1,66 @@
+from .operator_handler import OperatorHandler
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from copy import deepcopy
+import math
+
+
+class ReshapeHandler(OperatorHandler):
+    """
+    An OperatorHandler which deals with the sharding strategies of reshape operators, such as torch.reshape and torch.flatten.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.output_data = self.node._meta_data
+
+    def _generate_compute_cost(self, *args, **kwargs):
+        return super()._generate_compute_cost(*args, **kwargs)
+
+    def register_strategy(self):
+        input_node = self.strategies_vector.predecessor_nodes[0]
+        # For a reshape function, we keep the input sharding spec fully replicated to
+        # preserve correctness. The output is also kept replicated, and the successor
+        # node decides how to reshard it. As a result, different strategies of the
+        # input node that share the same output sharding spec generate the same
+        # strategy for the reshape function.
+        sharding_spec_checklist = []
+        for strategy in input_node.strategies_vector:
+            # note: the input of the node being processed is the output of input_node,
+            # so we read strategy.output_sharding_spec here.
+            input_sharding_spec = strategy.output_sharding_spec
+            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensors.'
+            if input_sharding_spec in sharding_spec_checklist:
+                continue
+            sharding_spec_checklist.append(input_sharding_spec)
+            dim_partition_dict_for_output = {}
+            output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+            name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
+            # TODO: use meta_info_prop to profile memory cost and compute cost
+            compute_cost = 0
+            memory_cost = self.node._meta_data.numel()
+
+            # compute the communication cost; for a reshape op, communication happens when casting the input sharding spec to fully replicated.
+            dim_partition_dict_for_replicate_input = {}
+            replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data,
+                                                                         dim_partition_dict_for_replicate_input)
+            # shape consistency manager is a singleton class
+            shape_consistency_manager = ShapeConsistencyManager()
+            _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
+                                                                                   replicate_input_sharding_spec)
+
+            # generate resharding cost
+            resharding_costs = self._generate_resharding_costs([input_sharding_spec])
+
+            # to prevent resharding of the input, set every non-zero resharding cost to inf.
+            resharding_costs[input_node] = [0 if cost == 0 else math.inf for cost in resharding_costs[input_node]]
+            sharding_strategy = ShardingStrategy(name,
+                                                 output_sharding_spec,
+                                                 compute_cost=compute_cost,
+                                                 communication_cost=communication_cost,
+                                                 memory_cost=memory_cost,
+                                                 resharding_costs=resharding_costs,
+                                                 input_shardings=[input_sharding_spec])
+            self.strategies_vector.append(sharding_strategy)
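Reviewer's note (not part of the patch): the specs involved can be sketched on the 2x2 mesh used in the tests below. The input may arrive sharded, while `ReshapeHandler` always emits a fully replicated output (an empty `dim_partition_dict`). A rough sketch, assuming `ShardingSpec(device_mesh, entire_shape, dim_partition_dict=...)` is the constructor signature in `colossalai.tensor.sharding_spec`:

```python
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))
entire_shape = torch.Size((4, 16, 64, 64))

# an input sharded on its first two dims across the two mesh dims -> [S0, S1, R, R]
sharded_input = ShardingSpec(device_mesh, entire_shape, dim_partition_dict={0: [0], 1: [1]})
# what ReshapeHandler emits: no dim is partitioned -> fully replicated [R, R, R, R]
replicated_output = ShardingSpec(device_mesh, entire_shape, dim_partition_dict={})

print(f'{sharded_input.sharding_sequence} -> {replicated_output.sharding_sequence}')
```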
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index 9e30bb753..725ef6892 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -61,12 +61,17 @@ class StrategiesVector(list):
             root_module = self.node.graph.owning_module
             submod = root_module.get_submodule(target)
             submod_type = type(submod)
-            # merge elementwise module node into following nodes
+            # merge element-wise module nodes into their source nodes
+            # we can merge an element-wise op because its output sharding spec is always the same as its input sharding spec.
             if submod_type in ELEMENTWISE_MODULE_OP:
                 merge_label = True
 
         if self.node.op == 'call_function':
+            # we can merge an element-wise op because its output sharding spec is always the same as its input sharding spec.
             if self.node.target in ELEMENTWISE_FUNC_OP:
                 merge_label = True
+            # we can merge a reshape op because its output sharding spec is always fully replicated.
+            if self.node.target in RESHAPE_FUNC_OP:
+                merge_label = True
 
         return merge_label
diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py
index 101be664e..3291790da 100644
--- a/colossalai/auto_parallel/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/solver/strategies_constructor.py
@@ -157,8 +157,7 @@ class StrategiesConstructor:
                 # print(node, node.op, node.target, node.args)
                 # create sharding strategy for element-wise module
                 # input_node = strategies_vector.predecessor_nodes[0]
-                norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
-                                                self.shape_consistency_manager)
+                norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
                 norm_handler.register_strategy()
                 # for strategy in norm_handler.strategies_vector:
                 #     print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
@@ -214,18 +213,22 @@
                 if target in CONV_FUNC_OP:
                     # use ConvHandler to create sharding strategies for conv node
                     # TODO: the operator_handler does NOT support function node processing now.
-                    conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
-                                               self.shape_consistency_manager)
+                    conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
                     conv_handler.register_strategy()
 
                 # linear function
                 elif target in LINEAR_FUNC_OP:
                     # use DotHandler to create sharding strategies for linear node
                     # TODO: the operator_handler does NOT support function node processing now.
-                    linear_handler = DotHandler(node, self.device_mesh, strategies_vector,
-                                                self.shape_consistency_manager)
+                    linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
                     linear_handler.register_strategy()
 
+                # reshape function
+                elif target in RESHAPE_FUNC_OP:
+                    # use ReshapeHandler to create sharding strategies for reshape node
+                    reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
+                    reshape_handler.register_strategy()
+
                 # element-wise function
                 elif target in ELEMENTWISE_FUNC_OP:
                     # TODO: integrate element-wise func and module together
diff --git a/tests/test_auto_parallel/test_reshape_handler.py b/tests/test_auto_parallel/test_reshape_handler.py
new file mode 100644
index 000000000..ac9cfad6d
--- /dev/null
+++ b/tests/test_auto_parallel/test_reshape_handler.py
@@ -0,0 +1,55 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+
+from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class ConvModel(nn.Module):
+
+    def __init__(self, c_in, c_out):
+        super().__init__()
+        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = torch.flatten(x)
+        return x
+
+
+def test_reshape_handler():
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+
+    tracer = ColoTracer()
+    model = ConvModel(16, 32)
+    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %conv : [#users=1] = call_module[target=conv](args = (%x,), kwargs = {})
+    #     return flatten
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    # [x, conv, flatten, output]
+    nodes = [node for node in gm.graph.nodes]
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+
+    strategies_constructor.build_strategies_and_cost()
+    strategy_map = strategies_constructor.strategy_map
+    conv_strategies = strategy_map[nodes[1]]
+    flatten_strategies = strategy_map[nodes[2]]
+    flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies]
+    for strategy in conv_strategies:
+        assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list
+
+
+if __name__ == '__main__':
+    test_reshape_handler()
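One more reviewer's note (hypothetical, not in the test): since `ReshapeHandler` names every strategy `... -> FULLY REPLICATED` and emits an empty `dim_partition_dict`, a check along these lines could be appended to the test body above:

```python
    # hypothetical extra assertions at the end of test_reshape_handler above
    for strategy in flatten_strategies:
        assert strategy.name.endswith('-> FULLY REPLICATED')
        assert strategy.output_sharding_spec.dim_partition_dict == {}
```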
diff --git a/tests/test_auto_parallel/test_solver_with_resnet.py b/tests/test_auto_parallel/test_solver_with_resnet.py
index 61541b945..8d133886a 100644
--- a/tests/test_auto_parallel/test_solver_with_resnet.py
+++ b/tests/test_auto_parallel/test_solver_with_resnet.py
@@ -14,6 +14,7 @@ from colossalai.auto_parallel.solver import Solver
 from torchvision.models import resnet34, resnet50
 from colossalai.auto_parallel.solver.constants import *
 from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.solver.options import SolverOptions
 
 
 class ConvModel(nn.Module):
@@ -81,8 +82,8 @@ def test_cost_graph():
     liveness_list = graph_analyser.liveness_analysis()
     # print(len(liveness_dict[0].unique_live_vars))
     # assert False
-    solver_options = {'fast_mode': True}
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
 
     strategies_constructor.build_strategies_and_cost()
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)