mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] support addbmm computation (#2102)
parent d3d4630495
commit 0fecbb9e20
@@ -1,2 +1,3 @@
+from .addbmm import Addbmm
 from .addmm import Addmm
-from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict
+from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict

@@ -0,0 +1,75 @@
+import operator
+
+import torch
+import torch.nn.functional as F
+
+from ...registry import bias_addition_function, bias_addition_method
+from .bias_addition_function import LinearBasedBiasFunc
+
+
+@bias_addition_method.register(torch.Tensor.addbmm)
+@bias_addition_function.register(torch.addbmm)
+class Addbmm(LinearBasedBiasFunc):
+
+    def extract_kwargs_from_origin_func(self):
+        kwargs = {}
+        if 'beta' in self.kwargs:
+            kwargs['beta'] = self.kwargs['beta']
+        if 'alpha' in self.kwargs:
+            kwargs['alpha'] = self.kwargs['alpha']
+        return kwargs
+
+    def create_non_bias_func_proxy(self, input_proxy, other_proxy):
+        """
+        This method is used to create the non_bias_func proxy; the node created by this proxy
+        performs the main computation (torch.bmm here) with the bias addition stripped out.
+        """
+        assert self.substitute_func == torch.bmm
+        node_kind = 'call_function'
+        node_target = self.substitute_func
+
+        node_args = (input_proxy, other_proxy)
+        # torch.bmm does not take any kwargs
+        node_kwargs = {}
+        non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
+        return non_bias_func_proxy
+
+    def insert_sum_node(self, input_proxy, sum_dims=0):
+        '''
+        This method is used to sum the input_proxy over sum_dims.
+        '''
+        node_kind = 'call_function'
+        node_target = torch.sum
+        node_args = (input_proxy, sum_dims)
+        node_kwargs = {}
+        sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
+        return sum_proxy
+
+    def generate(self):
+        # The formula for addbmm is output = beta * input + alpha * torch.sum(torch.bmm(b1, b2), dim=0)
+
+        # do the non-bias computation (temp_0 = torch.bmm(b1, b2))
+        non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])
+
+        # sum over the batch dimension (temp_1 = torch.sum(temp_0, 0))
+        sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
+        kwargs = self.extract_kwargs_from_origin_func()
+
+        if 'beta' in kwargs:
+            beta = kwargs['beta']
+            # multiply by beta if it was given (temp_2 = beta * input)
+            beta_proxy = self.create_mul_node(self.args[0], beta)
+        else:
+            beta_proxy = self.args[0]
+
+        if 'alpha' in kwargs:
+            alpha = kwargs['alpha']
+            # multiply by alpha if it was given (temp_3 = alpha * temp_1)
+            alpha_proxy = self.create_mul_node(alpha, sum_proxy)
+        else:
+            alpha_proxy = sum_proxy
+
+        # do the final addition (temp_4 = temp_2 + temp_3)
+        bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
+
+        return bias_addition_proxy

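Note (illustrative sketch, not part of this commit): the graph built by Addbmm.generate() computes beta * input + alpha * torch.sum(torch.bmm(b1, b2), dim=0), which is exactly what torch.addbmm computes. A quick eager-mode check of that identity, with shapes borrowed from the test further down:

import torch

bias = torch.rand(8, 8)
b1 = torch.rand(4, 8, 16)
b2 = torch.rand(4, 16, 8)

# the fused op being decomposed
reference = torch.addbmm(bias, b1, b2, beta=3, alpha=2)
# the bmm -> sum -> mul -> add sequence emitted by generate()
decomposed = 3 * bias + 2 * torch.sum(torch.bmm(b1, b2), dim=0)
assert torch.allclose(reference, decomposed, atol=1e-5)
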
@@ -3,10 +3,11 @@ import operator
 import torch
 import torch.nn.functional as F
 
-from ...registry import bias_addition_function
+from ...registry import bias_addition_function, bias_addition_method
 from .bias_addition_function import LinearBasedBiasFunc
 
 
+@bias_addition_method.register(torch.Tensor.addmm)
 @bias_addition_function.register(torch.addmm)
 class Addmm(LinearBasedBiasFunc):
 
@@ -18,23 +19,6 @@ class Addmm(LinearBasedBiasFunc):
             kwargs['alpha'] = self.kwargs['alpha']
         return kwargs
 
-    def coefficent_for_addmm(self, input_proxy, coefficent):
-        """
-        This method is used to create a coefficent node for the numerical correctness.
-        The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
-        Therefore, we need to use this method insert two more operator.mul nodes for
-        the computation graph to compute the final result.
-        """
-        node_kind = 'call_function'
-        node_target = operator.mul
-        node_args = (
-            input_proxy,
-            coefficent,
-        )
-        node_kwargs = {}
-        mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
-        return mul_proxy
-
     def transpose_other_operand_for_linear(self, other_proxy):
         '''
         This method is used to transpose the other operand for linear function.
@@ -61,13 +45,13 @@ class Addmm(LinearBasedBiasFunc):
 
         if 'beta' in kwargs:
             beta = kwargs['beta']
-            beta_proxy = self.coefficent_for_addmm(self.args[0], beta)
+            beta_proxy = self.create_mul_node(self.args[0], beta)
         else:
            beta_proxy = self.args[0]
 
         if 'alpha' in kwargs:
             alpha = kwargs['alpha']
-            alpha_proxy = self.coefficent_for_addmm(alpha, non_bias_linear_func_proxy)
+            alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
         else:
             alpha_proxy = non_bias_linear_func_proxy

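Note (illustrative sketch, not part of this commit): the coefficent_for_addmm helper removed above is replaced by the shared create_mul_node added to the base class in the next hunk; the Addmm decomposition itself is numerically unchanged. The identity it relies on, namely that F.linear applied to the transposed second operand reproduces m1 @ m2, can be checked in eager mode (variable names below are illustrative):

import torch
import torch.nn.functional as F

bias = torch.rand(8)
m1 = torch.rand(4, 16)
m2 = torch.rand(16, 8)

reference = torch.addmm(bias, m1, m2, beta=3, alpha=2)
decomposed = 3 * bias + 2 * F.linear(m1, m2.transpose(0, 1))
assert torch.allclose(reference, decomposed, atol=1e-5)
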
@@ -52,6 +52,23 @@ class BiasAdditionFunc(ABC):
         """
         pass
 
+    def create_mul_node(self, input_proxy, coefficent):
+        """
+        This method is used to create a coefficient node for numerical correctness.
+        The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2).
+        Therefore, we need this method to insert two more operator.mul nodes into
+        the computation graph to compute the final result.
+        """
+        node_kind = 'call_function'
+        node_target = operator.mul
+        node_args = (
+            input_proxy,
+            coefficent,
+        )
+        node_kwargs = {}
+        mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
+        return mul_proxy
+
 
 class LinearBasedBiasFunc(BiasAdditionFunc):
     """
@@ -88,4 +105,10 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
 
 func_to_func_dict = {
     torch.addmm: F.linear,
+    torch.addbmm: torch.bmm,
+}
+
+method_to_func_dict = {
+    torch.Tensor.addmm: F.linear,
+    torch.Tensor.addbmm: torch.bmm,
 }

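Note (illustrative sketch, not part of this commit): the two tables above drive the substitution during tracing; both the torch-level function and the Tensor method map to the bias-free kernel. A self-contained mirror of the lookup (dict contents copied from this hunk, key resolution mirroring the ColoTracer change below):

import torch
import torch.nn.functional as F

func_to_func_dict = {torch.addmm: F.linear, torch.addbmm: torch.bmm}
method_to_func_dict = {torch.Tensor.addmm: F.linear, torch.Tensor.addbmm: torch.bmm}

# call_function path: the traced target is the torch function itself
assert func_to_func_dict[torch.addbmm] is torch.bmm

# call_method path: the tracer resolves the method object from the receiver's class by name
target = 'addbmm'
method = getattr(torch.Tensor, target)
assert method_to_func_dict[method] is torch.bmm
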
@@ -25,3 +25,4 @@ meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution
 meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
 bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
 bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
+bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')

@@ -20,8 +20,14 @@ from torch.fx.proxy import ParameterProxy, Proxy
 
 from ..proxy import ColoProxy
 from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
-from .bias_addition_patch import func_to_func_dict, module_to_func_dict
-from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
+from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
+from .registry import (
+    bias_addition_function,
+    bias_addition_method,
+    bias_addition_module,
+    meta_patched_function,
+    meta_patched_module,
+)
 
 __all__ = ['ColoTracer']
 
@@ -100,12 +106,14 @@ class ColoTracer(Tracer):
                 handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
             elif bias_addition_function.has(target.__name__):
                 # use name for some builtin op like @ (matmul)
-                handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
+                function_to_substitute = func_to_func_dict[target]
+                handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)
 
         elif kind == "call_method":
             method = getattr(args_metas[0].__class__, target)
-            if bias_addition_function.has(method):
-                handle = bias_addition_function.get(method)(self, target, args, kwargs)
+            if bias_addition_method.has(method):
+                function_to_substitute = method_to_func_dict[method]
+                handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)
 
         elif kind == "call_module":
             if not hasattr(self, "orig_forward"):

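Note (usage sketch, not part of this commit; assumes colossalai with this patch installed): with the call_method branch wired to the new registry, a module calling torch.addbmm or Tensor.addbmm should trace into the decomposed bmm -> sum -> add graph. The module name below is illustrative; the tracing call and shapes follow the test file:

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer

class AddBMMModule(nn.Module):

    def forward(self, bias, x1, x2):
        return torch.addbmm(bias, x1, x2)

tracer = ColoTracer()
graph = tracer.trace(AddBMMModule(),
                     meta_args={
                         'bias': torch.rand(8, 8).to('meta'),
                         'x1': torch.rand(4, 8, 16).to('meta'),
                         'x2': torch.rand(4, 16, 8).to('meta'),
                     })
# expected node targets include torch.bmm, torch.sum and operator.add
print(graph)
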
@@ -5,7 +5,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -19,20 +19,36 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
 
 
 class AddBMMTensorMethodModule(nn.Module):
 
+    def __init__(self, using_kwargs):
+        super().__init__()
+        self.using_kwargs = using_kwargs
+
     def forward(self, bias, x1, x2):
-        return bias.addbmm(x1, x2)
+        if self.using_kwargs:
+            output = bias.addbmm(x1, x2, alpha=2, beta=3)
+        else:
+            output = bias.addbmm(x1, x2)
+        return output
 
 
 class AddBMMTorchFunctionModule(nn.Module):
 
+    def __init__(self, using_kwargs):
+        super().__init__()
+        self.using_kwargs = using_kwargs
+
     def forward(self, bias, x1, x2):
-        return torch.addbmm(bias, x1, x2)
+        if self.using_kwargs:
+            output = torch.addbmm(bias, x1, x2, alpha=2, beta=3)
+        else:
+            output = torch.addbmm(bias, x1, x2)
+        return output
 
 
-def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
+def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = module().cuda()
+    model = module(using_kwargs).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -54,6 +70,14 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
                                      input_args=input_args,
                                      meta_arg_names=meta_arg_names)
     tracer = ColoTracer()
+    # graph():
+    # %bias : torch.Tensor [#users=1] = placeholder[target=bias]
+    # %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
+    # %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
+    # %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
+    # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
+    # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
+    # return add
     graph = tracer.trace(model,
                          meta_args={
                              'bias': torch.rand(*bias_shape).to('meta'),
@@ -62,11 +86,11 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
                          })
     gm = ColoGraphModule(model, graph)
 
-    linear_mod_node = list(graph.nodes)[3]
-    strategies_vector = StrategiesVector(linear_mod_node)
+    bmm_mod_node = list(graph.nodes)[3]
+    strategies_vector = StrategiesVector(bmm_mod_node)
 
     # build handler
-    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+    handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
 
     # check operation data mapping
     mapping = handler.get_operation_data_mapping()
@@ -89,19 +113,15 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size(bias_shape)
-    assert mapping['bias'].type == OperationDataType.ARG
-    assert mapping['bias'].logical_shape == torch.Size([8, 8])
-
-    assert mapping['output'].name == "addbmm"
+    assert mapping['output'].name == "bmm"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([8, 8])
+    assert mapping['output'].data.shape == torch.Size([4, 8, 8])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
+    for name in strategy_name_list:
+        print(name)
     # one batch dim
     assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list
@@ -123,23 +143,21 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
     for strategy in strategies_vector:
         input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
         other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
-        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
-        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
 
         # make sure the sharding matches across different operation data
-        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
         assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
-        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
-def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
+def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (1, 4)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    model = module().cuda()
+    model = module(using_kwargs).cuda()
     x1 = torch.rand(4, 8, 16).cuda()
     x2 = torch.rand(4, 16, 8).cuda()
     bias = torch.rand(bias_shape).cuda()
@@ -159,6 +177,14 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
                                      meta_arg_names=meta_arg_names)
 
     tracer = ColoTracer()
+    # graph():
+    # %bias : torch.Tensor [#users=1] = placeholder[target=bias]
+    # %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
+    # %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
+    # %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
+    # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
+    # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
+    # return add
     graph = tracer.trace(model,
                          meta_args={
                              'bias': torch.rand(*bias_shape).to('meta'),
@@ -166,11 +192,11 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
                              'x2': torch.rand(4, 16, 8).to('meta')
                          })
     gm = ColoGraphModule(model, graph)
-    linear_mod_node = list(graph.nodes)[3]
-    strategies_vector = StrategiesVector(linear_mod_node)
+    bmm_mod_node = list(graph.nodes)[3]
+    strategies_vector = StrategiesVector(bmm_mod_node)
 
     # build handler
-    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+    handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
 
     # check operation data mapping
     mapping = handler.get_operation_data_mapping()
@@ -193,15 +219,9 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size(bias_shape)
-    assert mapping['bias'].type == OperationDataType.ARG
-    assert mapping['bias'].logical_shape == torch.Size([8, 8])
-
-    assert mapping['output'].name == "addbmm"
+    assert mapping['output'].name == "bmm"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([8, 8])
+    assert mapping['output'].data.shape == torch.Size([4, 8, 8])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@@ -213,14 +233,12 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
     for strategy in strategies_vector:
         input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
         other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
-        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
-        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
 
         # make sure the sharding matches across different operation data
-        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
         assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
-        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
 @pytest.mark.skip("skip due to bias cases not ready")
@@ -228,13 +246,15 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
 @pytest.mark.dist
 @parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
 @parameterize('bias_shape', [[8], [1, 8], [8, 8]])
+@parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
-def test_2d_device_mesh(module, bias_shape):
+def test_2d_device_mesh(module, bias_shape, using_kwargs):
     world_size = 4
     run_func = partial(check_2d_device_mesh,
                        module=module,
                        bias_shape=bias_shape,
                        world_size=world_size,
+                       using_kwargs=using_kwargs,
                        port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -244,12 +264,14 @@ def test_2d_device_mesh(module, bias_shape):
 @pytest.mark.dist
 @parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
 @parameterize('bias_shape', [[8], [1, 8], [8, 8]])
+@parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
-def test_1d_device_mesh(module, bias_shape):
+def test_1d_device_mesh(module, bias_shape, using_kwargs):
     world_size = 4
     run_func = partial(check_1d_device_mesh,
                        module=module,
                        bias_shape=bias_shape,
+                       using_kwargs=using_kwargs,
                        world_size=world_size,
                        port=free_port())
     mp.spawn(run_func, nprocs=world_size)

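Note (illustrative sketch, not part of this commit): the kwargs variant exercised by the tests (alpha=2, beta=3) can also be checked in eager mode without the distributed harness; the torch.Size([4, 8, 8]) asserted for mapping['output'] above is the shape of the bmm intermediate, while the final addbmm result stays (8, 8):

import torch

bias = torch.rand(8, 8)
x1 = torch.rand(4, 8, 16)
x2 = torch.rand(4, 16, 8)

out = bias.addbmm(x1, x2, alpha=2, beta=3)                  # tensor-method form used by AddBMMTensorMethodModule
expected = 3 * bias + 2 * torch.sum(torch.bmm(x1, x2), dim=0)
assert torch.allclose(out, expected, atol=1e-5)
assert torch.bmm(x1, x2).shape == torch.Size([4, 8, 8])     # the intermediate the handler sees
assert out.shape == torch.Size([8, 8])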