From b2b2a4af9872bef8d0aed67512160bcb0c893173 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Mon, 26 Sep 2022 15:26:14 +0800
Subject: [PATCH] [autoparallel] adapt solver with mlp (#1638)

---
 .../solver/op_handler/__init__.py         |  3 +-
 .../solver/op_handler/dot_handler.py      | 89 +++++++++++-------
 .../solver/op_handler/operator_handler.py | 19 +++-
 .../test_solver_with_mlp.py               | 93 +++++++++++++++++++
 4 files changed, 165 insertions(+), 39 deletions(-)
 create mode 100644 tests/test_auto_parallel/test_solver_with_mlp.py

diff --git a/colossalai/auto_parallel/solver/op_handler/__init__.py b/colossalai/auto_parallel/solver/op_handler/__init__.py
index 9d7315dd1..a0d570325 100644
--- a/colossalai/auto_parallel/solver/op_handler/__init__.py
+++ b/colossalai/auto_parallel/solver/op_handler/__init__.py
@@ -4,9 +4,10 @@ from .conv_handler import ConvHandler
 from .batch_norm_handler import BatchNormHandler
 from .reshape_handler import ReshapeHandler
 from .bcast_op_handler import BcastOpHandler
+from .embedding_handler import EmbeddingHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler
 
 __all__ = [
     'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
-    'UnaryElementwiseHandler'
+    'UnaryElementwiseHandler', 'EmbeddingHandler'
 ]
diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler.py b/colossalai/auto_parallel/solver/op_handler/dot_handler.py
index 28f6f29a3..29beb116c 100644
--- a/colossalai/auto_parallel/solver/op_handler/dot_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/dot_handler.py
@@ -410,9 +410,9 @@ class DotHandler(OperatorHandler):
         self.weight = self.module_named_parameters['weight']
         self.output_data = self.node._meta_data
 
-    def _generate_compute_cost(self, input_shape, weight_shape):
+    def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size):
         # TODO: consider bias addition
-        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
+        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
         return compute_cost
 
     @exception_handler
@@ -434,15 +434,17 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute computation cost
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost
-        # no all-reduce required for this case
-        communication_cost = 0
+        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
+        communication_cost = communication_cost_activation_backward + communication_cost_weight_backward
 
         # create and register strategy
         sharding_strategies = ShardingStrategy(name,
@@ -474,14 +476,17 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
+        communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -508,14 +513,18 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0)
+        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1)
+        communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
+
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -542,11 +551,12 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
         communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
@@ -576,14 +586,16 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
+        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
+        communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -610,14 +622,16 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = 0
+        communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
+        communication_cost = communication_cost_weight_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -644,14 +658,17 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(activation_memory_cost, 0)
+        communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+            activation_memory_cost, 0)
+        communication_cost = communication_cost_forward_activation
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -678,14 +695,16 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
-
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
         # compute the communication cost of this strategy
-        communication_cost = 0
+        communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+            input_grad_memory_cost, 0)
+        communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
diff --git a/colossalai/auto_parallel/solver/op_handler/operator_handler.py b/colossalai/auto_parallel/solver/op_handler/operator_handler.py
index a0b70bd6f..1e3234a56 100644
--- a/colossalai/auto_parallel/solver/op_handler/operator_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/operator_handler.py
@@ -64,7 +64,8 @@ class OperatorHandler(ABC):
         """
         pass
 
-    def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight):
+    def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight,
+                              sharding_spec_for_input):
         '''
         Compute the memory cost per device with this specific strategy.
 
@@ -102,9 +103,21 @@ class OperatorHandler(ABC):
             weight_sharding_size *= self.device_mesh.shape[mesh_dim]
         weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes
 
-        total_memory_cost = activation_memory_cost + weight_memory_cost
+        # compute the memory cost of input grad
+        input_grad_numel = self.input_data.numel()
+        input_grad_sharding_size = 1
+        input_grad_mesh_dims = []
+        for sharding_dim, mesh_dims in sharding_spec_for_input.items():
+            input_grad_mesh_dims.extend(mesh_dims)
+        for mesh_dim in input_grad_mesh_dims:
+            input_grad_sharding_size *= self.device_mesh.shape[mesh_dim]
+        input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes
 
-        return total_memory_cost, activation_memory_cost, weight_memory_cost
+        memory_cost_forward = activation_memory_cost + weight_memory_cost
+        memory_cost_backward = input_grad_memory_cost + weight_memory_cost
+
+        return (memory_cost_forward,
+                memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost
 
     def _generate_resharding_costs(self, sharding_specs):
         # The resharding_cost of weight is counted due to sharing weight cases.
diff --git a/tests/test_auto_parallel/test_solver_with_mlp.py b/tests/test_auto_parallel/test_solver_with_mlp.py
new file mode 100644
index 000000000..5a850eee7
--- /dev/null
+++ b/tests/test_auto_parallel/test_solver_with_mlp.py
@@ -0,0 +1,93 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.solver.cost_graph import CostGraph
+from copy import deepcopy
+from colossalai.auto_parallel.solver import Solver
+from torchvision.models import resnet34, resnet50
+from colossalai.auto_parallel.solver.constants import *
+from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.solver.options import SolverOptions
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(dim, dim * 4)
+        self.linear2 = torch.nn.Linear(dim * 4, dim)
+        self.dropout = torch.nn.Dropout(0)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.dropout(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return x
+
+
+@pytest.mark.skip("for higher testing speed")
+def test_cost_graph():
+    physical_mesh_id = torch.arange(0, 8)
+    mesh_shape = (2, 4)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    model = MLP(32)
+
+    input_sample = {'x': torch.rand(16, 32).to('meta')}
+
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
+    #     %dropout : [#users=1] = call_module[target=dropout](args = (%linear1,), kwargs = {})
+    #     %relu : [#users=1] = call_module[target=relu](args = (%dropout,), kwargs = {})
+    #     %linear2 : [#users=1] = call_module[target=linear2](args = (%relu,), kwargs = {})
+    #     return linear2
+    graph = tracer.trace(root=model, meta_args=input_sample)
+
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    # # megatron mode if no memory constraints
+    # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    # all sharding on out feature dim if memory budget is not sufficient for megatron mode
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=5500.0)
+
+    ret = solver.call_solver_serialized_args()
+    strategies_list = list(ret[0])
+    computation_cost = 0
+    communication_cost = 0
+    memory_cost = 0
+    for index, node in enumerate(graph.nodes):
+        print(node.name, node.strategies_vector[strategies_list[index]].name)
+        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
+        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
+        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
+        if isinstance(node_memory_cost, tuple):
+            node_memory_cost = node_memory_cost[0]
+        memory_cost += node_memory_cost
+
+    print(f'computation cost is {computation_cost}')
+    print(f'communication cost is {communication_cost}')
+    print(f'memory cost is {memory_cost}')
+
+
+if __name__ == '__main__':
+    test_cost_graph()
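
Note: the recurring change in this patch is that compute, memory, and communication costs are now estimated per device rather than for the unsharded tensors. The short sketch below mirrors that arithmetic in plain Python, using hypothetical helper names and a tuple/dict in place of ColossalAI's DeviceMesh and sharding-spec objects; it is an illustration of the cost model only, not part of the patch or of the library API.

from functools import reduce
import operator


def per_device_compute_cost(input_shape, weight_shape, total_sharding_size):
    # FLOPs of a linear layer (2 * prod(input_shape) * out_features),
    # split evenly across the devices the computation is sharded over.
    return reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size


def per_device_memory_cost(numel, dim_partition_dict, mesh_shape, size_per_elem_bytes=4):
    # dim_partition_dict maps a tensor dim to the mesh dims it is sharded over,
    # e.g. {0: [0], 1: [1]} for a 2D tensor sharded along both mesh axes.
    sharding_size = 1
    for mesh_dims in dim_partition_dict.values():
        for mesh_dim in mesh_dims:
            sharding_size *= mesh_shape[mesh_dim]
    return numel / sharding_size * size_per_elem_bytes


mesh_shape = (2, 4)    # 8 devices arranged as a 2 x 4 mesh, as in the test above
batch, in_features, out_features = 16, 32, 128

total_sharding_size = mesh_shape[0] * mesh_shape[1]
print(per_device_compute_cost((batch, in_features), (out_features, in_features), total_sharding_size))

# activation sharded over both mesh dims, weight over mesh dim 1, input grad over mesh dim 0
activation_cost = per_device_memory_cost(batch * out_features, {0: [0], 1: [1]}, mesh_shape)
weight_cost = per_device_memory_cost(out_features * in_features, {0: [1]}, mesh_shape)
input_grad_cost = per_device_memory_cost(batch * in_features, {0: [0]}, mesh_shape)

# the patched _generate_memory_cost returns a (forward, backward) pair plus its components,
# which is why the test above takes memory_cost[0] when the cost is a tuple
memory_cost = (activation_cost + weight_cost, input_grad_cost + weight_cost)
print(memory_cost)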