From d164449d004bca01f0e7bbe21de1e32adfd577e6 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Tue, 13 Sep 2022 18:05:05 +0800 Subject: [PATCH] [autoparallel] add resnet autoparallel unit test and add backward weight communication cost (#1589) --- .../solver/op_handler/conv_handler.py | 72 ++++++---- .../test_solver_with_resnet.py | 125 ++++++++++++++++++ 2 files changed, 168 insertions(+), 29 deletions(-) create mode 100644 tests/test_auto_parallel/test_solver_with_resnet.py diff --git a/colossalai/auto_parallel/solver/op_handler/conv_handler.py b/colossalai/auto_parallel/solver/op_handler/conv_handler.py index 6c1b92d4a..d41817652 100644 --- a/colossalai/auto_parallel/solver/op_handler/conv_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/conv_handler.py @@ -103,7 +103,7 @@ class ConvHandler(OperatorHandler): # memory_cost pair memory_cost = (memory_cost_forward, memory_cost_backward) - return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation + return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' @@ -132,15 +132,18 @@ class ConvHandler(OperatorHandler): sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] sharding_size_weight = self.device_mesh.shape[mesh_dim_1] - memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( + memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost( sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) # This strategy do not need to do all_reduce operation during forward communication_cost_forward = 0 - # compute the backward communication cost of this strategy - communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1) + # compute the backward communication cost to all reduce the input activation grad + communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, + mesh_dim_1) + # compute the backward communication cost to all reduce the weight due to data parallel + communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0) # total communication cost - communication_cost = communication_cost_forward + communication_cost_backward + communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight sharding_strategies = ShardingStrategy(name, output_sharding_spec=sharding_spec_for_output, @@ -178,11 +181,16 @@ class ConvHandler(OperatorHandler): sharding_size_forward = self.device_mesh.shape[mesh_dim_0] sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] sharding_size_weight = 1 - memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_weight) + memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) - # This strategy do not need to do all_reduce operation in both forward and backward phase. - communication_cost = 0 + # This strategy do not need to do all_reduce operation in forward phase. + communication_cost_forward = 0 + # compute the backward communication cost to all reduce the weight due to data parallel + communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0) + # compute the total cost + communication_cost = communication_cost_forward + communication_cost_backward_weight sharding_strategies = ShardingStrategy(name, output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, @@ -220,15 +228,17 @@ class ConvHandler(OperatorHandler): sharding_size_forward = self.device_mesh.shape[mesh_dim_0] sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] sharding_size_weight = self.device_mesh.shape[mesh_dim_1] - memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_weight) + memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) # compute the communication cost of this strategy during forward phase communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1) - # This strategy do not need to do all_reduce operation during backward phase - communication_cost_backward = 0 - communication_cost = communication_cost_forward + communication_cost_backward + # This strategy do not need to do all_reduce operation to compute the input activation grad + communication_cost_backward_activation = 0 + # compute the backward communication cost to all reduce the weight due to data parallel + communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0) + # compute total cost + communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight sharding_strategies = ShardingStrategy(name, output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, @@ -265,7 +275,7 @@ class ConvHandler(OperatorHandler): sharding_size_forward = self.device_mesh.shape[mesh_dim_1] sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - memory_cost, memory_cost_forward_activation, memory_cost_backward_activation = self._generate_memory_cost( + memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost( sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) # compute the communication cost of this strategy during forward phase @@ -309,9 +319,8 @@ class ConvHandler(OperatorHandler): sharding_size_forward = 1 sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] sharding_size_weight = self.device_mesh.shape[mesh_dim_0] - memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_weight) + memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) # compute the communication cost of this strategy during forward phase communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) @@ -354,7 +363,7 @@ class ConvHandler(OperatorHandler): sharding_size_forward = self.device_mesh.shape[mesh_dim_0] sharding_size_backward_activation = 1 sharding_size_weight = self.device_mesh.shape[mesh_dim_0] - memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( + memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost( sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) # This strategy do not need to do all_reduce during forward phase @@ -398,8 +407,8 @@ class ConvHandler(OperatorHandler): sharding_size_forward = 1 sharding_size_backward_activation = 1 sharding_size_weight = 1 - memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_weight) + memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, + sharding_size_weight) # This strategy do not need to do all_reduce in both forward and backward phase communication_cost = 0 @@ -441,11 +450,17 @@ class ConvHandler(OperatorHandler): sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[ mesh_dim_1] sharding_size_weight = 1 - memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_weight) + memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) - # This strategy do not need to do all_reduce in both forward and backward phase - communication_cost = 0 + # This strategy do not need to do all_reduce in forward phase + communication_cost_forward = 0 + # compute the backward communication cost to all reduce the weight due to data parallel + communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost( + memory_cost_backward_weight, 0) + # compute the total communication cost + communication_cost = communication_cost_backward_weight + communication_cost_forward sharding_strategies = ShardingStrategy(name, output_sharding_spec=sharding_spec_for_output, @@ -485,9 +500,8 @@ class ConvHandler(OperatorHandler): sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[ mesh_dim_1] sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1] - memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_weight) + memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) # compute communication cost during forward phase communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost( diff --git a/tests/test_auto_parallel/test_solver_with_resnet.py b/tests/test_auto_parallel/test_solver_with_resnet.py new file mode 100644 index 000000000..61541b945 --- /dev/null +++ b/tests/test_auto_parallel/test_solver_with_resnet.py @@ -0,0 +1,125 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.solver.cost_graph import CostGraph +from copy import deepcopy +from colossalai.auto_parallel.solver import Solver +from torchvision.models import resnet34, resnet50 +from colossalai.auto_parallel.solver.constants import * +from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3) + self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3) + self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3) + self.relu = nn.ReLU() + + def forward(self, x): + x = x * 2 + x = self.conv1(x) + x = self.conv2(x) + x = x / 2 + x = self.conv3(x) + x = self.relu(x) + return x + + +@pytest.mark.skip("for higher testing speed") +def test_cost_graph(): + physical_mesh_id = torch.arange(0, 8) + mesh_shape = (2, 4) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + # model = ConvModel(16, 32) + # input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + model = resnet50(num_classes=100000) + input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')} + + graph = tracer.trace(root=model, meta_args=input_sample) + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {}) + # %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {}) + # %maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {}) + # %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {}) + # %layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {}) + # %layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {}) + # %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {}) + # %layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {}) + # %layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {}) + # %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {}) + # %layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {}) + # %layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {}) + # %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {}) + # %layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {}) + # %add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {}) + # ... + # %avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_2_relu_1,), kwargs = {}) + # %flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {}) + # %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {}) + # return fc + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + # print(len(liveness_dict[0].unique_live_vars)) + # assert False + solver_options = {'fast_mode': True} + strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0) + # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + + ret = solver.call_solver_serialized_args() + print(ret) + strategies_list = list(ret[0]) + print(strategies_list) + computation_cost = 0 + communication_cost = 0 + communication_cost_bn = 0 + memory_cost = 0 + for index, node in enumerate(graph.nodes): + if node.op == 'call_module': + submod = node.graph.owning_module.get_submodule(node.target) + if type(submod) in ELEMENTWISE_MODULE_OP: + input_spec = node.args[0].strategies_vector[strategies_list[index]].output_sharding_spec + print(node.name, input_spec) + continue + if type(submod) in BATCHNORM_MODULE_OP: + communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + print(f'bn communication cost is {communication_cost_bn}') + + +if __name__ == '__main__': + test_cost_graph()