mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] support addbmm computation (#2102)
parent
d3d4630495
commit
0fecbb9e20
|
@ -1,2 +1,3 @@
|
|||
from .addbmm import Addbmm
|
||||
from .addmm import Addmm
|
||||
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict
|
||||
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
import operator
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...registry import bias_addition_function, bias_addition_method
|
||||
from .bias_addition_function import LinearBasedBiasFunc
|
||||
|
||||
|
||||
@bias_addition_method.register(torch.Tensor.addbmm)
|
||||
@bias_addition_function.register(torch.addbmm)
|
||||
class Addbmm(LinearBasedBiasFunc):
|
||||
|
||||
def extract_kwargs_from_origin_func(self):
|
||||
kwargs = {}
|
||||
if 'beta' in self.kwargs:
|
||||
kwargs['beta'] = self.kwargs['beta']
|
||||
if 'alpha' in self.kwargs:
|
||||
kwargs['alpha'] = self.kwargs['alpha']
|
||||
return kwargs
|
||||
|
||||
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
|
||||
"""
|
||||
This method is used to create the non_bias_func proxy, the node created by this proxy will
|
||||
compute the main computation, such as convolution, with bias option banned.
|
||||
"""
|
||||
assert self.substitute_func == torch.bmm
|
||||
node_kind = 'call_function'
|
||||
node_target = self.substitute_func
|
||||
|
||||
node_args = (input_proxy, other_proxy)
|
||||
# torch.bmm does not have any kwargs
|
||||
node_kwargs = {}
|
||||
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
|
||||
return non_bias_func_proxy
|
||||
|
||||
def insert_sum_node(self, input_proxy, sum_dims=0):
|
||||
'''
|
||||
This method is used to sum the input_proxy through the sum_dims.
|
||||
'''
|
||||
node_kind = 'call_function'
|
||||
node_target = torch.sum
|
||||
node_args = (input_proxy, sum_dims)
|
||||
node_kwargs = {}
|
||||
sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
|
||||
return sum_proxy
|
||||
|
||||
def generate(self):
|
||||
# The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2))
|
||||
|
||||
# doing the non-bias computation(temp_0 = torch.bmm(b1, b2))
|
||||
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])
|
||||
|
||||
# doing sum on the batch dimension(temp_1 = torch.sum(temp_0, 0))
|
||||
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
|
||||
kwargs = self.extract_kwargs_from_origin_func()
|
||||
|
||||
if 'beta' in kwargs:
|
||||
beta = kwargs['beta']
|
||||
# doing the multiplication with beta if it exists(temp_2 = beta * input)
|
||||
beta_proxy = self.create_mul_node(self.args[0], beta)
|
||||
else:
|
||||
beta_proxy = self.args[0]
|
||||
|
||||
if 'alpha' in kwargs:
|
||||
alpha = kwargs['alpha']
|
||||
# doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
|
||||
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
|
||||
else:
|
||||
alpha_proxy = sum_proxy
|
||||
|
||||
# doing the addition(temp_4 = temp_2 + temp_3)
|
||||
bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
|
||||
|
||||
return bias_addition_proxy
|
|
@ -3,10 +3,11 @@ import operator
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...registry import bias_addition_function
|
||||
from ...registry import bias_addition_function, bias_addition_method
|
||||
from .bias_addition_function import LinearBasedBiasFunc
|
||||
|
||||
|
||||
@bias_addition_method.register(torch.Tensor.addmm)
|
||||
@bias_addition_function.register(torch.addmm)
|
||||
class Addmm(LinearBasedBiasFunc):
|
||||
|
||||
|
@ -18,23 +19,6 @@ class Addmm(LinearBasedBiasFunc):
|
|||
kwargs['alpha'] = self.kwargs['alpha']
|
||||
return kwargs
|
||||
|
||||
def coefficent_for_addmm(self, input_proxy, coefficent):
|
||||
"""
|
||||
This method is used to create a coefficent node for the numerical correctness.
|
||||
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
|
||||
Therefore, we need to use this method insert two more operator.mul nodes for
|
||||
the computation graph to compute the final result.
|
||||
"""
|
||||
node_kind = 'call_function'
|
||||
node_target = operator.mul
|
||||
node_args = (
|
||||
input_proxy,
|
||||
coefficent,
|
||||
)
|
||||
node_kwargs = {}
|
||||
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
|
||||
return mul_proxy
|
||||
|
||||
def transpose_other_operand_for_linear(self, other_proxy):
|
||||
'''
|
||||
This method is used to transpose the other operand for linear function.
|
||||
|
@ -61,13 +45,13 @@ class Addmm(LinearBasedBiasFunc):
|
|||
|
||||
if 'beta' in kwargs:
|
||||
beta = kwargs['beta']
|
||||
beta_proxy = self.coefficent_for_addmm(self.args[0], beta)
|
||||
beta_proxy = self.create_mul_node(self.args[0], beta)
|
||||
else:
|
||||
beta_proxy = self.args[0]
|
||||
|
||||
if 'alpha' in kwargs:
|
||||
alpha = kwargs['alpha']
|
||||
alpha_proxy = self.coefficent_for_addmm(alpha, non_bias_linear_func_proxy)
|
||||
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
|
||||
else:
|
||||
alpha_proxy = non_bias_linear_func_proxy
|
||||
|
||||
|
|
|
@ -52,6 +52,23 @@ class BiasAdditionFunc(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def create_mul_node(self, input_proxy, coefficent):
|
||||
"""
|
||||
This method is used to create a coefficent node for the numerical correctness.
|
||||
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
|
||||
Therefore, we need to use this method insert two more operator.mul nodes for
|
||||
the computation graph to compute the final result.
|
||||
"""
|
||||
node_kind = 'call_function'
|
||||
node_target = operator.mul
|
||||
node_args = (
|
||||
input_proxy,
|
||||
coefficent,
|
||||
)
|
||||
node_kwargs = {}
|
||||
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
|
||||
return mul_proxy
|
||||
|
||||
|
||||
class LinearBasedBiasFunc(BiasAdditionFunc):
|
||||
"""
|
||||
|
@ -88,4 +105,10 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
|
|||
|
||||
func_to_func_dict = {
|
||||
torch.addmm: F.linear,
|
||||
torch.addbmm: torch.bmm,
|
||||
}
|
||||
|
||||
method_to_func_dict = {
|
||||
torch.Tensor.addmm: F.linear,
|
||||
torch.Tensor.addbmm: torch.bmm,
|
||||
}
|
||||
|
|
|
@ -25,3 +25,4 @@ meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution
|
|||
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
|
||||
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
|
||||
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
|
||||
bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')
|
||||
|
|
|
@ -20,8 +20,14 @@ from torch.fx.proxy import ParameterProxy, Proxy
|
|||
|
||||
from ..proxy import ColoProxy
|
||||
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
|
||||
from .bias_addition_patch import func_to_func_dict, module_to_func_dict
|
||||
from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
|
||||
from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
|
||||
from .registry import (
|
||||
bias_addition_function,
|
||||
bias_addition_method,
|
||||
bias_addition_module,
|
||||
meta_patched_function,
|
||||
meta_patched_module,
|
||||
)
|
||||
|
||||
__all__ = ['ColoTracer']
|
||||
|
||||
|
@ -100,12 +106,14 @@ class ColoTracer(Tracer):
|
|||
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
|
||||
elif bias_addition_function.has(target.__name__):
|
||||
# use name for some builtin op like @ (matmul)
|
||||
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)
|
||||
|
||||
elif kind == "call_method":
|
||||
method = getattr(args_metas[0].__class__, target)
|
||||
if bias_addition_function.has(method):
|
||||
handle = bias_addition_function.get(method)(self, target, args, kwargs)
|
||||
if bias_addition_method.has(method):
|
||||
function_to_substitute = method_to_func_dict[method]
|
||||
handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)
|
||||
|
||||
elif kind == "call_module":
|
||||
if not hasattr(self, "orig_forward"):
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
|
@ -19,20 +19,36 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
|
|||
|
||||
class AddBMMTensorMethodModule(nn.Module):
|
||||
|
||||
def __init__(self, using_kwargs):
|
||||
super().__init__()
|
||||
self.using_kwargs = using_kwargs
|
||||
|
||||
def forward(self, bias, x1, x2):
|
||||
return bias.addbmm(x1, x2)
|
||||
if self.using_kwargs:
|
||||
output = bias.addbmm(x1, x2, alpha=2, beta=3)
|
||||
else:
|
||||
output = bias.addbmm(x1, x2)
|
||||
return output
|
||||
|
||||
|
||||
class AddBMMTorchFunctionModule(nn.Module):
|
||||
|
||||
def __init__(self, using_kwargs):
|
||||
super().__init__()
|
||||
self.using_kwargs = using_kwargs
|
||||
|
||||
def forward(self, bias, x1, x2):
|
||||
return torch.addbmm(bias, x1, x2)
|
||||
if self.using_kwargs:
|
||||
output = torch.addbmm(bias, x1, x2, alpha=2, beta=3)
|
||||
else:
|
||||
output = torch.addbmm(bias, x1, x2)
|
||||
return output
|
||||
|
||||
|
||||
def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
|
||||
def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = module().cuda()
|
||||
model = module(using_kwargs).cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
@ -54,6 +70,14 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
|
|||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %bias : torch.Tensor [#users=1] = placeholder[target=bias]
|
||||
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
|
||||
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
|
||||
# %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
|
||||
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
|
||||
# %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
|
||||
# return add
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
'bias': torch.rand(*bias_shape).to('meta'),
|
||||
|
@ -62,11 +86,11 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
|
|||
})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
||||
linear_mod_node = list(graph.nodes)[3]
|
||||
strategies_vector = StrategiesVector(linear_mod_node)
|
||||
bmm_mod_node = list(graph.nodes)[3]
|
||||
strategies_vector = StrategiesVector(bmm_mod_node)
|
||||
|
||||
# build handler
|
||||
handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
|
||||
handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
|
||||
|
||||
# check operation data mapping
|
||||
mapping = handler.get_operation_data_mapping()
|
||||
|
@ -89,19 +113,15 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
|
|||
assert mapping['other'].type == OperationDataType.ARG
|
||||
assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size(bias_shape)
|
||||
assert mapping['bias'].type == OperationDataType.ARG
|
||||
assert mapping['bias'].logical_shape == torch.Size([8, 8])
|
||||
|
||||
assert mapping['output'].name == "addbmm"
|
||||
assert mapping['output'].name == "bmm"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([8, 8])
|
||||
assert mapping['output'].data.shape == torch.Size([4, 8, 8])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
||||
strategy_name_list = [val.name for val in strategies_vector]
|
||||
for name in strategy_name_list:
|
||||
print(name)
|
||||
# one batch dim
|
||||
assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list
|
||||
|
||||
|
@ -123,23 +143,21 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
|
|||
for strategy in strategies_vector:
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
|
||||
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
|
||||
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
|
||||
|
||||
# make sure the sharding matches across different operation data
|
||||
assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
|
||||
assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
|
||||
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
|
||||
def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (1, 4)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
model = module().cuda()
|
||||
model = module(using_kwargs).cuda()
|
||||
x1 = torch.rand(4, 8, 16).cuda()
|
||||
x2 = torch.rand(4, 16, 8).cuda()
|
||||
bias = torch.rand(bias_shape).cuda()
|
||||
|
@ -159,6 +177,14 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
|
|||
meta_arg_names=meta_arg_names)
|
||||
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %bias : torch.Tensor [#users=1] = placeholder[target=bias]
|
||||
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
|
||||
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
|
||||
# %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
|
||||
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
|
||||
# %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
|
||||
# return add
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
'bias': torch.rand(*bias_shape).to('meta'),
|
||||
|
@ -166,11 +192,11 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
|
|||
'x2': torch.rand(4, 16, 8).to('meta')
|
||||
})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
linear_mod_node = list(graph.nodes)[3]
|
||||
strategies_vector = StrategiesVector(linear_mod_node)
|
||||
bmm_mod_node = list(graph.nodes)[3]
|
||||
strategies_vector = StrategiesVector(bmm_mod_node)
|
||||
|
||||
# build handler
|
||||
handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
|
||||
handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
|
||||
|
||||
# check operation data mapping
|
||||
mapping = handler.get_operation_data_mapping()
|
||||
|
@ -193,15 +219,9 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
|
|||
assert mapping['other'].type == OperationDataType.ARG
|
||||
assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size(bias_shape)
|
||||
assert mapping['bias'].type == OperationDataType.ARG
|
||||
assert mapping['bias'].logical_shape == torch.Size([8, 8])
|
||||
|
||||
assert mapping['output'].name == "addbmm"
|
||||
assert mapping['output'].name == "bmm"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([8, 8])
|
||||
assert mapping['output'].data.shape == torch.Size([4, 8, 8])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
||||
|
@ -213,14 +233,12 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
|
|||
for strategy in strategies_vector:
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
|
||||
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
|
||||
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
|
||||
|
||||
# make sure the sharding matches across different operation data
|
||||
assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
|
||||
assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
|
||||
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
@pytest.mark.skip("skip due to bias cases not ready")
|
||||
|
@ -228,13 +246,15 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
|
||||
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
|
||||
@parameterize('using_kwargs', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_2d_device_mesh(module, bias_shape):
|
||||
def test_2d_device_mesh(module, bias_shape, using_kwargs):
|
||||
world_size = 4
|
||||
run_func = partial(check_2d_device_mesh,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
world_size=world_size,
|
||||
using_kwargs=using_kwargs,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
@ -244,12 +264,14 @@ def test_2d_device_mesh(module, bias_shape):
|
|||
@pytest.mark.dist
|
||||
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
|
||||
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
|
||||
@parameterize('using_kwargs', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_1d_device_mesh(module, bias_shape):
|
||||
def test_1d_device_mesh(module, bias_shape, using_kwargs):
|
||||
world_size = 4
|
||||
run_func = partial(check_1d_device_mesh,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
using_kwargs=using_kwargs,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
|
Loading…
Reference in New Issue