[autoparallel] support addbmm computation (#2102)

YuliangLiu0306 2022-12-08 21:15:11 +08:00 committed by GitHub
parent d3d4630495
commit 0fecbb9e20
7 changed files with 179 additions and 65 deletions

View File

@@ -1,2 +1,3 @@
from .addbmm import Addbmm
from .addmm import Addmm
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict

View File

@@ -0,0 +1,75 @@
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
kwargs = {}
if 'beta' in self.kwargs:
kwargs['beta'] = self.kwargs['beta']
if 'alpha' in self.kwargs:
kwargs['alpha'] = self.kwargs['alpha']
return kwargs
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
"""
This method is used to create the non_bias_func proxy; the node created by this proxy
performs the main computation (torch.bmm here) with the bias addition removed.
"""
assert self.substitute_func == torch.bmm
node_kind = 'call_function'
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
# torch.bmm takes no kwargs that we need to forward
node_kwargs = {}
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return non_bias_func_proxy
def insert_sum_node(self, input_proxy, sum_dims=0):
'''
This method is used to sum input_proxy along the dimensions given by sum_dims.
'''
node_kind = 'call_function'
node_target = torch.sum
node_args = (input_proxy, sum_dims)
node_kwargs = {}
sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return sum_proxy
def generate(self):
# The formula for addbmm is output = beta * input + alpha * torch.sum(torch.bmm(b1, b2), dim=0)
# do the non-bias computation (temp_0 = torch.bmm(b1, b2))
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])
# sum over the batch dimension (temp_1 = torch.sum(temp_0, 0))
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs:
beta = kwargs['beta']
# multiply by beta if it is given (temp_2 = beta * input)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
# multiply by alpha if it is given (temp_3 = alpha * temp_1)
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
else:
alpha_proxy = sum_proxy
# do the final addition (temp_4 = temp_2 + temp_3)
bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
return bias_addition_proxy
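As a quick sanity check on the decomposition that generate() builds (bmm, a sum over the batch dimension, the optional alpha/beta multiplications, and the final add), here is a minimal plain-PyTorch sketch; the shapes follow the test case later in this diff and are only illustrative:

import torch

bias = torch.rand(8, 8)
b1 = torch.rand(4, 8, 16)
b2 = torch.rand(4, 16, 8)
alpha, beta = 2, 3

# reference result from the fused op
reference = torch.addbmm(bias, b1, b2, alpha=alpha, beta=beta)

# the same computation spelled out the way the traced graph expresses it:
# temp_0 = torch.bmm(b1, b2); temp_1 = torch.sum(temp_0, 0)
# temp_3 = alpha * temp_1;    temp_2 = beta * bias;    out = temp_3 + temp_2
decomposed = alpha * torch.sum(torch.bmm(b1, b2), 0) + beta * bias

assert torch.allclose(reference, decomposed, atol=1e-5)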

View File

@@ -3,10 +3,11 @@ import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):
@@ -18,23 +19,6 @@ class Addmm(LinearBasedBiasFunc):
kwargs['alpha'] = self.kwargs['alpha']
return kwargs
def coefficent_for_addmm(self, input_proxy, coefficent):
"""
This method is used to create a coefficent node for the numerical correctness.
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
Therefore, we need to use this method insert two more operator.mul nodes for
the computation graph to compute the final result.
"""
node_kind = 'call_function'
node_target = operator.mul
node_args = (
input_proxy,
coefficent,
)
node_kwargs = {}
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return mul_proxy
def transpose_other_operand_for_linear(self, other_proxy):
'''
This method is used to transpose the other operand for linear function.
@@ -61,13 +45,13 @@ class Addmm(LinearBasedBiasFunc):
if 'beta' in kwargs:
beta = kwargs['beta']
beta_proxy = self.coefficent_for_addmm(self.args[0], beta)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
alpha_proxy = self.coefficent_for_addmm(alpha, non_bias_linear_func_proxy)
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
else:
alpha_proxy = non_bias_linear_func_proxy

View File

@@ -52,6 +52,23 @@ class BiasAdditionFunc(ABC):
"""
pass
def create_mul_node(self, input_proxy, coefficent):
"""
This method is used to create a coefficient node for numerical correctness.
The formula for torch.addmm, for example, is out = beta * input + alpha * (m1 @ m2).
Therefore, we need this method to insert two extra operator.mul nodes into
the computation graph to compute the final result.
"""
node_kind = 'call_function'
node_target = operator.mul
node_args = (
input_proxy,
coefficent,
)
node_kwargs = {}
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return mul_proxy
class LinearBasedBiasFunc(BiasAdditionFunc):
"""
@@ -88,4 +105,10 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
func_to_func_dict = {
torch.addmm: F.linear,
torch.addbmm: torch.bmm,
}
method_to_func_dict = {
torch.Tensor.addmm: F.linear,
torch.Tensor.addbmm: torch.bmm,
}
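For context, a minimal sketch of how these lookup tables are consulted (assuming the two dicts defined above are imported); the getattr step mirrors how the tracer below recovers the unbound method from the class of the first argument. This is illustrative and not part of the patch:

import torch

# function call site, e.g. torch.addbmm(bias, x1, x2)
assert func_to_func_dict[torch.addbmm] is torch.bmm

# method call site, e.g. bias.addbmm(x1, x2); the tracer resolves the target name
# on the argument's class before doing the lookup
method = getattr(torch.Tensor, 'addbmm')
assert method_to_func_dict[method] is torch.bmm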

View File

@@ -25,3 +25,4 @@ meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')

View File

@@ -20,8 +20,14 @@ from torch.fx.proxy import ParameterProxy, Proxy
from ..proxy import ColoProxy
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .bias_addition_patch import func_to_func_dict, module_to_func_dict
from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from .registry import (
bias_addition_function,
bias_addition_method,
bias_addition_module,
meta_patched_function,
meta_patched_module,
)
__all__ = ['ColoTracer']
@@ -100,12 +106,14 @@ class ColoTracer(Tracer):
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_function.has(method):
handle = bias_addition_function.get(method)(self, target, args, kwargs)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):
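To see the new dispatch end to end, here is a hedged usage sketch (assuming colossalai.fx is importable; AddbmmModule is just an illustrative module name): after this patch, tracing a module that calls torch.addbmm should yield the bmm/sum nodes shown in the annotated test graphs below, plus operator.mul nodes when alpha/beta are supplied, rather than a single addbmm node.

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer

class AddbmmModule(nn.Module):

    def forward(self, bias, x1, x2):
        return torch.addbmm(bias, x1, x2, alpha=2, beta=3)

tracer = ColoTracer()
graph = tracer.trace(AddbmmModule(),
                     meta_args={
                         'bias': torch.rand(8, 8).to('meta'),
                         'x1': torch.rand(4, 8, 16).to('meta'),
                         'x2': torch.rand(4, 16, 8).to('meta')
                     })
# expect call_function nodes targeting torch.bmm, torch.sum, operator.mul and operator.add
print(graph)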

View File

@@ -5,7 +5,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
@@ -19,20 +19,36 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
class AddBMMTensorMethodModule(nn.Module):
def __init__(self, using_kwargs):
super().__init__()
self.using_kwargs = using_kwargs
def forward(self, bias, x1, x2):
return bias.addbmm(x1, x2)
if self.using_kwargs:
output = bias.addbmm(x1, x2, alpha=2, beta=3)
else:
output = bias.addbmm(x1, x2)
return output
class AddBMMTorchFunctionModule(nn.Module):
def __init__(self, using_kwargs):
super().__init__()
self.using_kwargs = using_kwargs
def forward(self, bias, x1, x2):
return torch.addbmm(bias, x1, x2)
if self.using_kwargs:
output = torch.addbmm(bias, x1, x2, alpha=2, beta=3)
else:
output = torch.addbmm(bias, x1, x2)
return output
def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = module().cuda()
model = module(using_kwargs).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -54,6 +70,14 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
input_args=input_args,
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
# graph():
# %bias : torch.Tensor [#users=1] = placeholder[target=bias]
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
# %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
# return add
graph = tracer.trace(model,
meta_args={
'bias': torch.rand(*bias_shape).to('meta'),
@@ -62,11 +86,11 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
})
gm = ColoGraphModule(model, graph)
linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node)
bmm_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(bmm_mod_node)
# build handler
handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
@@ -89,19 +113,15 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size(bias_shape)
assert mapping['bias'].type == OperationDataType.ARG
assert mapping['bias'].logical_shape == torch.Size([8, 8])
assert mapping['output'].name == "addbmm"
assert mapping['output'].name == "bmm"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([8, 8])
assert mapping['output'].data.shape == torch.Size([4, 8, 8])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
for name in strategy_name_list:
print(name)
# one batch dim
assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list
@@ -123,23 +143,21 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (1, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
model = module().cuda()
model = module(using_kwargs).cuda()
x1 = torch.rand(4, 8, 16).cuda()
x2 = torch.rand(4, 16, 8).cuda()
bias = torch.rand(bias_shape).cuda()
@@ -159,6 +177,14 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
# graph():
# %bias : torch.Tensor [#users=1] = placeholder[target=bias]
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
# %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
# return add
graph = tracer.trace(model,
meta_args={
'bias': torch.rand(*bias_shape).to('meta'),
@@ -166,11 +192,11 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
'x2': torch.rand(4, 16, 8).to('meta')
})
gm = ColoGraphModule(model, graph)
linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node)
bmm_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(bmm_mod_node)
# build handler
handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
@@ -193,15 +219,9 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size(bias_shape)
assert mapping['bias'].type == OperationDataType.ARG
assert mapping['bias'].logical_shape == torch.Size([8, 8])
assert mapping['output'].name == "addbmm"
assert mapping['output'].name == "bmm"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([8, 8])
assert mapping['output'].data.shape == torch.Size([4, 8, 8])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@@ -213,14 +233,12 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
@pytest.mark.skip("skip due to bias cases not ready")
@@ -228,13 +246,15 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
@pytest.mark.dist
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_2d_device_mesh(module, bias_shape):
def test_2d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_2d_device_mesh,
module=module,
bias_shape=bias_shape,
world_size=world_size,
using_kwargs=using_kwargs,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
@@ -244,12 +264,14 @@ def test_2d_device_mesh(module, bias_shape):
@pytest.mark.dist
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_1d_device_mesh(module, bias_shape):
def test_1d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_1d_device_mesh,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)