diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py
index 951c13dde..ef15f0214 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py
@@ -1,2 +1,3 @@
+from .addbmm import Addbmm
 from .addmm import Addmm
-from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict
+from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
new file mode 100644
index 000000000..859a19bf6
--- /dev/null
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
@@ -0,0 +1,75 @@
+import operator
+
+import torch
+import torch.nn.functional as F
+
+from ...registry import bias_addition_function, bias_addition_method
+from .bias_addition_function import LinearBasedBiasFunc
+
+
+@bias_addition_method.register(torch.Tensor.addbmm)
+@bias_addition_function.register(torch.addbmm)
+class Addbmm(LinearBasedBiasFunc):
+
+    def extract_kwargs_from_origin_func(self):
+        kwargs = {}
+        if 'beta' in self.kwargs:
+            kwargs['beta'] = self.kwargs['beta']
+        if 'alpha' in self.kwargs:
+            kwargs['alpha'] = self.kwargs['alpha']
+        return kwargs
+
+    def create_non_bias_func_proxy(self, input_proxy, other_proxy):
+        """
+        This method is used to create the non_bias_func proxy. The node created by this proxy
+        performs the main computation, here a batched matrix multiplication, with the bias term left out.
+        """
+        assert self.substitute_func == torch.bmm
+        node_kind = 'call_function'
+        node_target = self.substitute_func
+
+        node_args = (input_proxy, other_proxy)
+        # torch.bmm does not take any kwargs
+        node_kwargs = {}
+        non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
+        return non_bias_func_proxy
+
+    def insert_sum_node(self, input_proxy, sum_dims=0):
+        '''
+        This method is used to sum input_proxy along the dimensions given by sum_dims.
+        '''
+        node_kind = 'call_function'
+        node_target = torch.sum
+        node_args = (input_proxy, sum_dims)
+        node_kwargs = {}
+        sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
+        return sum_proxy
+
+    def generate(self):
+        # The formula for addbmm is: output = beta * input + alpha * torch.sum(torch.bmm(b1, b2), dim=0)
+
+        # do the non-bias computation (temp_0 = torch.bmm(b1, b2))
+        non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])
+
+        # sum over the batch dimension (temp_1 = torch.sum(temp_0, 0))
+        sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
+        kwargs = self.extract_kwargs_from_origin_func()
+
+        if 'beta' in kwargs:
+            beta = kwargs['beta']
+            # multiply by beta if it is given (temp_2 = beta * input)
+            beta_proxy = self.create_mul_node(self.args[0], beta)
+        else:
+            beta_proxy = self.args[0]
+
+        if 'alpha' in kwargs:
+            alpha = kwargs['alpha']
+            # multiply by alpha if it is given (temp_3 = alpha * temp_1)
+            alpha_proxy = self.create_mul_node(alpha, sum_proxy)
+        else:
+            alpha_proxy = sum_proxy
+
+        # do the addition (temp_4 = temp_2 + temp_3)
+        bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
+
+        return bias_addition_proxy
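
Note: the decomposition that Addbmm.generate emits can be checked numerically outside the tracer. Below is a minimal sketch, not part of the patch; the shapes are arbitrary and chosen to match the tests at the end of this diff:

import torch

b1 = torch.rand(4, 8, 16)
b2 = torch.rand(4, 16, 8)
bias = torch.rand(8, 8)

reference = torch.addbmm(bias, b1, b2, beta=3, alpha=2)

# the node sequence the patch emits: bmm -> sum -> mul -> mul -> add
temp_0 = torch.bmm(b1, b2)       # non-bias computation
temp_1 = torch.sum(temp_0, 0)    # reduce the batch dimension
temp_2 = 3 * bias                # beta * input
temp_3 = 2 * temp_1              # alpha * temp_1
assert torch.allclose(reference, temp_2 + temp_3)
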
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
index f02cc80a5..fe7d8d07a 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
@@ -3,10 +3,11 @@ import operator
 import torch
 import torch.nn.functional as F
 
-from ...registry import bias_addition_function
+from ...registry import bias_addition_function, bias_addition_method
 from .bias_addition_function import LinearBasedBiasFunc
 
 
+@bias_addition_method.register(torch.Tensor.addmm)
 @bias_addition_function.register(torch.addmm)
 class Addmm(LinearBasedBiasFunc):
 
@@ -18,23 +19,6 @@ class Addmm(LinearBasedBiasFunc):
             kwargs['alpha'] = self.kwargs['alpha']
         return kwargs
 
-    def coefficent_for_addmm(self, input_proxy, coefficent):
-        """
-        This method is used to create a coefficent node for the numerical correctness.
-        The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
-        Therefore, we need to use this method insert two more operator.mul nodes for
-        the computation graph to compute the final result.
-        """
-        node_kind = 'call_function'
-        node_target = operator.mul
-        node_args = (
-            input_proxy,
-            coefficent,
-        )
-        node_kwargs = {}
-        mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
-        return mul_proxy
-
     def transpose_other_operand_for_linear(self, other_proxy):
         '''
         This method is used to transpose the other operand for linear function.
@@ -61,13 +45,13 @@ class Addmm(LinearBasedBiasFunc):
 
         if 'beta' in kwargs:
             beta = kwargs['beta']
-            beta_proxy = self.coefficent_for_addmm(self.args[0], beta)
+            beta_proxy = self.create_mul_node(self.args[0], beta)
         else:
             beta_proxy = self.args[0]
 
         if 'alpha' in kwargs:
             alpha = kwargs['alpha']
-            alpha_proxy = self.coefficent_for_addmm(alpha, non_bias_linear_func_proxy)
+            alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
         else:
             alpha_proxy = non_bias_linear_func_proxy
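
Note: Addmm follows the same pattern, with F.linear as the substitute function. Since F.linear(x, w) computes x @ w.t(), the class transposes the second operand first (transpose_other_operand_for_linear above). A minimal equivalence sketch, not part of the patch:

import torch
import torch.nn.functional as F

m1, m2, bias = torch.rand(8, 16), torch.rand(16, 8), torch.rand(8)

reference = torch.addmm(bias, m1, m2, beta=3, alpha=2)

# linear on the transposed operand reproduces m1 @ m2;
# beta and alpha become operator.mul nodes around it
decomposed = 3 * bias + 2 * F.linear(m1, m2.transpose(0, 1))
assert torch.allclose(reference, decomposed)
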
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
index e4ca58429..e53c5fe69 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
@@ -52,6 +52,23 @@ class BiasAdditionFunc(ABC):
         """
         pass
 
+    def create_mul_node(self, input_proxy, coefficent):
+        """
+        This method is used to create a coefficient node for numerical correctness.
+        The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2).
+        Therefore, this method is used to insert up to two operator.mul nodes into
+        the computation graph to compute the final result.
+        """
+        node_kind = 'call_function'
+        node_target = operator.mul
+        node_args = (
+            input_proxy,
+            coefficent,
+        )
+        node_kwargs = {}
+        mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
+        return mul_proxy
+
 
 class LinearBasedBiasFunc(BiasAdditionFunc):
     """
@@ -88,4 +105,10 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
 
 func_to_func_dict = {
     torch.addmm: F.linear,
+    torch.addbmm: torch.bmm,
+}
+
+method_to_func_dict = {
+    torch.Tensor.addmm: F.linear,
+    torch.Tensor.addbmm: torch.bmm,
 }
diff --git a/colossalai/fx/tracer/registry.py b/colossalai/fx/tracer/registry.py
index 01912dd6c..12fc6de73 100644
--- a/colossalai/fx/tracer/registry.py
+++ b/colossalai/fx/tracer/registry.py
@@ -25,3 +25,4 @@ meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
 meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
 bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
 bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
+bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index 53e9edd8a..8a4c361b6 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -20,8 +20,14 @@ from torch.fx.proxy import ParameterProxy, Proxy
 from ..proxy import ColoProxy
 from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
-from .bias_addition_patch import func_to_func_dict, module_to_func_dict
-from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
+from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
+from .registry import (
+    bias_addition_function,
+    bias_addition_method,
+    bias_addition_module,
+    meta_patched_function,
+    meta_patched_module,
+)
 
 __all__ = ['ColoTracer']
 
@@ -100,12 +106,14 @@ class ColoTracer(Tracer):
             handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
         elif bias_addition_function.has(target.__name__):
             # use name for some builtin op like @ (matmul)
-            handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
+            function_to_substitute = func_to_func_dict[target]
+            handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)
 
         elif kind == "call_method":
             method = getattr(args_metas[0].__class__, target)
-            if bias_addition_function.has(method):
-                handle = bias_addition_function.get(method)(self, target, args, kwargs)
+            if bias_addition_method.has(method):
+                function_to_substitute = method_to_func_dict[method]
+                handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)
 
         elif kind == "call_module":
             if not hasattr(self, "orig_forward"):
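
Note: with the new bias_addition_method registry wired into the tracer, a method call such as bias.addbmm(x1, x2) is dispatched the same way as the torch.addbmm function, with torch.bmm looked up as the substitute target. A minimal sketch of the resulting trace (the toy module here is hypothetical; the expected graph layout matches the comments in the tests below):

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer

class ToyAddbmm(nn.Module):
    def forward(self, bias, x1, x2):
        return torch.addbmm(bias, x1, x2)

tracer = ColoTracer()
graph = tracer.trace(ToyAddbmm(),
                     meta_args={
                         'bias': torch.rand(8, 8).to('meta'),
                         'x1': torch.rand(4, 8, 16).to('meta'),
                         'x2': torch.rand(4, 16, 8).to('meta')
                     })
# addbmm is rewritten into bmm -> sum -> add nodes
# (plus operator.mul nodes when alpha/beta are passed)
graph.print_tabular()
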
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
index e96de4603..ffc15e403 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
@@ -5,7 +5,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -19,20 +19,36 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
 
 class AddBMMTensorMethodModule(nn.Module):
 
+    def __init__(self, using_kwargs):
+        super().__init__()
+        self.using_kwargs = using_kwargs
+
     def forward(self, bias, x1, x2):
-        return bias.addbmm(x1, x2)
+        if self.using_kwargs:
+            output = bias.addbmm(x1, x2, alpha=2, beta=3)
+        else:
+            output = bias.addbmm(x1, x2)
+        return output
 
 
 class AddBMMTorchFunctionModule(nn.Module):
 
+    def __init__(self, using_kwargs):
+        super().__init__()
+        self.using_kwargs = using_kwargs
+
     def forward(self, bias, x1, x2):
-        return torch.addbmm(bias, x1, x2)
+        if self.using_kwargs:
+            output = torch.addbmm(bias, x1, x2, alpha=2, beta=3)
+        else:
+            output = torch.addbmm(bias, x1, x2)
+        return output
 
 
-def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
+def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = module().cuda()
+    model = module(using_kwargs).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -54,6 +70,14 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
                                      input_args=input_args,
                                      meta_arg_names=meta_arg_names)
     tracer = ColoTracer()
+    # graph():
+    #     %bias : torch.Tensor [#users=1] = placeholder[target=bias]
+    #     %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
+    #     %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
+    #     %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
+    #     %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
+    #     %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
+    #     return add
     graph = tracer.trace(model,
                          meta_args={
                              'bias': torch.rand(*bias_shape).to('meta'),
@@ -62,11 +86,11 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
                              'x1': torch.rand(4, 8, 16).to('meta'),
                              'x2': torch.rand(4, 16, 8).to('meta')
                          })
     gm = ColoGraphModule(model, graph)
 
-    linear_mod_node = list(graph.nodes)[3]
-    strategies_vector = StrategiesVector(linear_mod_node)
+    bmm_mod_node = list(graph.nodes)[3]
+    strategies_vector = StrategiesVector(bmm_mod_node)
 
     # build handler
-    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+    handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
 
     # check operation data mapping
     mapping = handler.get_operation_data_mapping()
@@ -89,19 +113,15 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size(bias_shape)
-    assert mapping['bias'].type == OperationDataType.ARG
-    assert mapping['bias'].logical_shape == torch.Size([8, 8])
-
-    assert mapping['output'].name == "addbmm"
+    assert mapping['output'].name == "bmm"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([8, 8])
+    assert mapping['output'].data.shape == torch.Size([4, 8, 8])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
+    for name in strategy_name_list:
+        print(name)
 
     # one batch dim
     assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list
@@ -123,23 +143,21 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
     for strategy in strategies_vector:
         input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
         other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
-        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
-        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
 
         # make sure the sharding matches across different operation data
-        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
         assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
-        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
-def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
+def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (1, 4)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    model = module().cuda()
+    model = module(using_kwargs).cuda()
     x1 = torch.rand(4, 8, 16).cuda()
     x2 = torch.rand(4, 16, 8).cuda()
     bias = torch.rand(bias_shape).cuda()
@@ -159,6 +177,14 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
                                      meta_arg_names=meta_arg_names)
     tracer = ColoTracer()
+    # graph():
+    #     %bias : torch.Tensor [#users=1] = placeholder[target=bias]
+    #     %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
+    #     %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
+    #     %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
+    #     %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
+    #     %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
+    #     return add
     graph = tracer.trace(model,
                          meta_args={
                              'bias': torch.rand(*bias_shape).to('meta'),
                              'x1': torch.rand(4, 8, 16).to('meta'),
@@ -166,11 +192,11 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
                              'x2': torch.rand(4, 16, 8).to('meta')
                          })
     gm = ColoGraphModule(model, graph)
-    linear_mod_node = list(graph.nodes)[3]
-    strategies_vector = StrategiesVector(linear_mod_node)
+    bmm_mod_node = list(graph.nodes)[3]
+    strategies_vector = StrategiesVector(bmm_mod_node)
 
     # build handler
-    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+    handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
 
     # check operation data mapping
     mapping = handler.get_operation_data_mapping()
@@ -193,15 +219,9 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size(bias_shape)
-    assert mapping['bias'].type == OperationDataType.ARG
-    assert mapping['bias'].logical_shape == torch.Size([8, 8])
-
-    assert mapping['output'].name == "addbmm"
+    assert mapping['output'].name == "bmm"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([8, 8])
+    assert mapping['output'].data.shape == torch.Size([4, 8, 8])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@@ -213,14 +233,12 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
     for strategy in strategies_vector:
         input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
         other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
-        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
-        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
 
         # make sure the sharding matches across different operation data
-        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
         assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
-        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
 @pytest.mark.skip("skip due to bias cases not ready")
@@ -228,13 +246,15 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
 @pytest.mark.dist
 @parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
 @parameterize('bias_shape', [[8], [1, 8], [8, 8]])
+@parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
-def test_2d_device_mesh(module, bias_shape):
+def test_2d_device_mesh(module, bias_shape, using_kwargs):
     world_size = 4
     run_func = partial(check_2d_device_mesh,
                        module=module,
                        bias_shape=bias_shape,
                        world_size=world_size,
+                       using_kwargs=using_kwargs,
                        port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
@@ -244,12 +264,14 @@ def test_2d_device_mesh(module, bias_shape):
 @pytest.mark.dist
 @parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
 @parameterize('bias_shape', [[8], [1, 8], [8, 8]])
+@parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
-def test_1d_device_mesh(module, bias_shape):
+def test_1d_device_mesh(module, bias_shape, using_kwargs):
     world_size = 4
     run_func = partial(check_1d_device_mesh,
                        module=module,
                        bias_shape=bias_shape,
+                       using_kwargs=using_kwargs,
                        world_size=world_size,
                        port=free_port())
     mp.spawn(run_func, nprocs=world_size)
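
Note: as a quick sanity check of the kwargs path exercised above, the recompiled graph can be compared against eager addbmm. A sketch reusing AddBMMTorchFunctionModule from this test file (not part of the patch; CPU-only, no device mesh needed):

import torch
from colossalai.fx import ColoGraphModule, ColoTracer

model = AddBMMTorchFunctionModule(using_kwargs=True)
bias, x1, x2 = torch.rand(8, 8), torch.rand(4, 8, 16), torch.rand(4, 16, 8)

tracer = ColoTracer()
graph = tracer.trace(model,
                     meta_args={
                         'bias': bias.to('meta'),
                         'x1': x1.to('meta'),
                         'x2': x2.to('meta')
                     })
gm = ColoGraphModule(model, graph)

# the rewritten bmm/sum/mul/add graph should reproduce torch.addbmm(..., alpha=2, beta=3)
assert torch.allclose(gm(bias, x1, x2), model(bias, x1, x2))
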