From b175e6d58e9b7ee07ca0058e2e014608c46c7ffa Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Thu, 8 Dec 2022 11:31:51 +0800 Subject: [PATCH] [autoparallel] add bias addtion function class (#2098) * [autoparallel] add bias addtion function class * polish code * polish --- .../__init__.py | 2 + .../patched_bias_addition_function/addmm.py | 76 ++++++++++++++++ .../bias_addition_function.py | 91 +++++++++++++++++++ colossalai/fx/tracer/tracer.py | 5 +- .../test_node_handler/test_addmm_handler.py | 75 ++++++++------- 5 files changed, 216 insertions(+), 33 deletions(-) create mode 100644 colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py create mode 100644 colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py index e69de29bb..951c13dde 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py @@ -0,0 +1,2 @@ +from .addmm import Addmm +from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py new file mode 100644 index 000000000..f02cc80a5 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py @@ -0,0 +1,76 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_function.register(torch.addmm) +class Addmm(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + kwargs = {} + if 'beta' in self.kwargs: + kwargs['beta'] = self.kwargs['beta'] + if 'alpha' in self.kwargs: + kwargs['alpha'] = self.kwargs['alpha'] + return kwargs + + def coefficent_for_addmm(self, input_proxy, coefficent): + """ + This method is used to create a coefficent node for the numerical correctness. + The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2) + Therefore, we need to use this method insert two more operator.mul nodes for + the computation graph to compute the final result. + """ + node_kind = 'call_function' + node_target = operator.mul + node_args = ( + input_proxy, + coefficent, + ) + node_kwargs = {} + mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return mul_proxy + + def transpose_other_operand_for_linear(self, other_proxy): + ''' + This method is used to transpose the other operand for linear function. + For example: + input = torch.rand(3, 4) + m1 = torch.rand(3, 5) + m2 = torch.rand(5, 4) + original_output = torch.addmm(input, m1, m2) + # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2 + # before we call the linear function. + new_output = torch.linear(m1, m2.transpose(0, 1)) + input + ''' + node_kind = 'call_function' + node_target = torch.transpose + node_args = (other_proxy, 0, 1) + node_kwargs = {} + transpose_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return transpose_proxy + + def generate(self): + transpose_proxy = self.transpose_other_operand_for_linear(self.args[2]) + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy) + kwargs = self.extract_kwargs_from_origin_func() + + if 'beta' in kwargs: + beta = kwargs['beta'] + beta_proxy = self.coefficent_for_addmm(self.args[0], beta) + else: + beta_proxy = self.args[0] + + if 'alpha' in kwargs: + alpha = kwargs['alpha'] + alpha_proxy = self.coefficent_for_addmm(alpha, non_bias_linear_func_proxy) + else: + alpha_proxy = non_bias_linear_func_proxy + + bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py new file mode 100644 index 000000000..e4ca58429 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py @@ -0,0 +1,91 @@ +import operator +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F + + +class BiasAdditionFunc(ABC): + """ + This class is used to construct the restructure computation graph for + call_func node with bias addition inside. + """ + + def __init__(self, tracer, target, args, kwargs, substitute_func): + self.tracer = tracer + self.target = target + self.args = args + self.kwargs = kwargs + self.substitute_func = substitute_func + + @abstractmethod + def extract_kwargs_from_origin_func(self): + """ + This method is used to extract the kwargs for further graph transform. + + For example: + The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2) + The kwargs for addmm function is {beta=1, alpha=1, output=None}, then we need + to insert two more operator.mul nodes for the computation graph to compute the + final result. + """ + pass + + @abstractmethod + def generate(self): + """ + This method is used to construct the whole restructure computation graph for call_func node with bias + addition inside. + + A whole restructure computation graph will contain a weight node, a bias node, a non-bias addition computation node, + a bias reshape node if needed and a bias addition node. + + Use torch.addmm as an example: + The origin node is: + %addmm: call_func[target=torch.addmm](args = (%input_1, m1, m2), kwargs = {beta=1, alpha=1}) + Restructured graph is: + %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {}) + %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {}) + %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {}) + %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) + %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) + """ + pass + + +class LinearBasedBiasFunc(BiasAdditionFunc): + """ + This class is used to construct the restructure computation graph for + call_func node based on F.linear. + """ + + def create_non_bias_func_proxy(self, input_proxy, other_proxy): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + assert self.substitute_func == torch.nn.functional.linear + node_kind = 'call_function' + node_target = self.substitute_func + + node_args = (input_proxy, other_proxy) + # non-bias linear does not have any kwargs + node_kwargs = {} + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): + """ + This method is used to create the bias_addition_proxy, the node created by this proxy will + compute the sum of non_bias_func result and bias with some reshape operation if needed. + """ + bias_add_node_kind = 'call_function' + bias_add_node_target = operator.add + bias_add_args = (non_bias_func_proxy, bias_proxy) + bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) + return bias_add_proxy + + +func_to_func_dict = { + torch.addmm: F.linear, +} diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py index 6295523b8..53e9edd8a 100644 --- a/colossalai/fx/tracer/tracer.py +++ b/colossalai/fx/tracer/tracer.py @@ -20,7 +20,7 @@ from torch.fx.proxy import ParameterProxy, Proxy from ..proxy import ColoProxy from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list -from .bias_addition_patch import module_to_func_dict +from .bias_addition_patch import func_to_func_dict, module_to_func_dict from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module __all__ = ['ColoTracer'] @@ -96,7 +96,8 @@ class ColoTracer(Tracer): handle = None if kind == "call_function": if bias_addition_function.has(target): - handle = bias_addition_function.get(target)(self, target, args, kwargs) + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) elif bias_addition_function.has(target.__name__): # use name for some builtin op like @ (matmul) handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py index e8d3a95a7..767864296 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -8,7 +8,7 @@ import torch.multiprocessing as mp import torch.nn as nn from typing_extensions import Self -from colossalai.auto_parallel.tensor_shard.node_handler import ADDMMFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -19,7 +19,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -31,7 +31,7 @@ class AddmmModel(nn.Module): super().__init__() def forward(self, input, m1, m2): - x = torch.addmm(input, m1, m2) + x = torch.addmm(input, m1, m2, beta=3, alpha=2) return x @@ -47,9 +47,9 @@ def check_linear_function_handler(rank, input_shape, world_size, port): m1 = torch.rand(4, 8).cuda() m2 = torch.rand(8, 16).cuda() # the index of addmm node in computation graph - node_index = 3 + node_index = 4 # strategy number of linear node - strategy_number = 10 + strategy_number = 14 # construct input args input_args = [input, m1, m2] # construct meta arg names @@ -59,9 +59,20 @@ def check_linear_function_handler(rank, input_shape, world_size, port): node_index=node_index, strategy_number=strategy_number, input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names=meta_arg_names, + node_type='bias_module') tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %m1 : torch.Tensor [#users=1] = placeholder[target=m1] + # %m2 : torch.Tensor [#users=1] = placeholder[target=m2] + # %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {}) + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {}) + # %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) + # return add graph = tracer.trace(model, meta_args={ "input": torch.rand(input_shape).to('meta'), @@ -71,11 +82,11 @@ def check_linear_function_handler(rank, input_shape, world_size, port): gm = ColoGraphModule(model, graph) # [input_1, m1, m2, addmm, output] node_list = list(graph.nodes) - addmm_node = node_list[3] - strategies_vector = StrategiesVector(addmm_node) + linear_node = node_list[4] + strategies_vector = StrategiesVector(linear_node) # build handler - handler = ADDMMFunctionHandler(node=addmm_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + handler = LinearFunctionHandler(node=linear_node, device_mesh=device_mesh, strategies_vector=strategies_vector) handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -88,30 +99,22 @@ def check_linear_function_handler(rank, input_shape, world_size, port): assert mapping['input'].type == OperationDataType.ARG assert mapping['input'].logical_shape == torch.Size([4, 8]) - assert mapping['other'].name == "m2" - assert mapping['other'].data.shape == torch.Size([8, 16]) + assert mapping['other'].name == "transpose" + assert mapping['other'].data.shape == torch.Size([16, 8]) assert mapping['other'].type == OperationDataType.ARG assert mapping['other'].logical_shape == torch.Size([8, 16]) - assert mapping['bias'].name == "input_1" - assert mapping['bias'].data.shape == torch.Size(input_shape) - assert mapping['bias'].type == OperationDataType.ARG - assert mapping['bias'].logical_shape == torch.Size([4, 16]) - - assert mapping['output'].name == "addmm" + assert mapping['output'].name == "linear" assert mapping['output'].data.shape == torch.Size([4, 16]) assert mapping['output'].type == OperationDataType.OUTPUT - # one strategy will be converted to different physical sharding spec - assert len(strategy_name_list) > 8 - # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list # RS = RS x SS assert 'RS0 = RS1 x S1S0' in strategy_name_list @@ -125,23 +128,33 @@ def check_linear_function_handler(rank, input_shape, world_size, port): assert 'RS0 = RR x RS0' in strategy_name_list assert 'RS1 = RR x RS1' in strategy_name_list + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + for strategy in strategies_vector: strategy: ShardingStrategy input_sharding_spec = strategy.get_sharding_spec_by_name('m1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('m2') - output_sharding_spec = strategy.get_sharding_spec_by_name('addmm') - bias_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('transpose') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] - assert weight_sharding_spec.sharding_sequence[0] == input_sharding_spec.sharding_sequence[1] - assert weight_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[1] - assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[1] -@parameterize('input_shape', [(16,), (4, 16)]) @run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.dist +@parameterize('input_shape', [(16,), (4, 16)]) @rerun_if_address_is_in_use() def test_addmm_handler(input_shape): world_size = 4