mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add numerical test for handlers (#1769)
parent
b0f7c8bde8
commit
a4d1f59c78
|
@ -1,11 +1,20 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
class AddBMMTensorMethodModule(nn.Module):
|
||||
|
@ -20,11 +29,30 @@ class AddBMMTorchFunctionModule(nn.Module):
|
|||
return torch.addbmm(bias, x1, x2)
|
||||
|
||||
|
||||
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
|
||||
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
|
||||
def test_2d_device_mesh(module, bias_shape):
|
||||
|
||||
model = module()
|
||||
def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = module().cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
x1 = torch.rand(4, 8, 16).cuda()
|
||||
x2 = torch.rand(4, 16, 8).cuda()
|
||||
bias = torch.rand(bias_shape).cuda()
|
||||
# the index of addbmm node in computation graph
|
||||
node_index = 3
|
||||
# strategy number of addbmm node on 2d device mesh
|
||||
strategy_number = 7
|
||||
# construct input args
|
||||
input_args = [bias, x1, x2]
|
||||
# construct meta arg names
|
||||
meta_arg_names = ['bias', 'x1', 'x2']
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
|
@ -32,12 +60,8 @@ def test_2d_device_mesh(module, bias_shape):
|
|||
"x1": torch.rand(4, 8, 16).to('meta'),
|
||||
'x2': torch.rand(4, 16, 8).to('meta')
|
||||
})
|
||||
print(graph)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
linear_mod_node = list(graph.nodes)[3]
|
||||
strategies_vector = StrategiesVector(linear_mod_node)
|
||||
|
||||
|
@ -78,7 +102,6 @@ def test_2d_device_mesh(module, bias_shape):
|
|||
|
||||
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
||||
strategy_name_list = [val.name for val in strategies_vector]
|
||||
|
||||
# one batch dim
|
||||
assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list
|
||||
|
||||
|
@ -110,10 +133,31 @@ def test_2d_device_mesh(module, bias_shape):
|
|||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
|
||||
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
|
||||
def test_1d_device_mesh(module, bias_shape):
|
||||
model = module()
|
||||
def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (1, 4)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
model = module().cuda()
|
||||
x1 = torch.rand(4, 8, 16).cuda()
|
||||
x2 = torch.rand(4, 16, 8).cuda()
|
||||
bias = torch.rand(bias_shape).cuda()
|
||||
# the index of addbmm node in computation graph
|
||||
node_index = 3
|
||||
# strategy number of addbmm node on 2d device mesh
|
||||
strategy_number = 1
|
||||
# construct input args
|
||||
input_args = [bias, x1, x2]
|
||||
# construct meta arg names
|
||||
meta_arg_names = ['bias', 'x1', 'x2']
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
|
@ -121,12 +165,7 @@ def test_1d_device_mesh(module, bias_shape):
|
|||
"x1": torch.rand(4, 8, 16).to('meta'),
|
||||
'x2': torch.rand(4, 16, 8).to('meta')
|
||||
})
|
||||
print(graph)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (1, 4)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
linear_mod_node = list(graph.nodes)[3]
|
||||
strategies_vector = StrategiesVector(linear_mod_node)
|
||||
|
||||
|
@ -184,6 +223,38 @@ def test_1d_device_mesh(module, bias_shape):
|
|||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
@pytest.mark.skip("skip due to bias cases not ready")
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
|
||||
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_2d_device_mesh(module, bias_shape):
|
||||
world_size = 4
|
||||
run_func = partial(check_2d_device_mesh,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
@pytest.mark.skip("skip due to bias cases not ready")
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
|
||||
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_1d_device_mesh(module, bias_shape):
|
||||
world_size = 4
|
||||
run_func = partial(check_1d_device_mesh,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_1d_device_mesh()
|
||||
# test_2d_device_mesh()
|
||||
test_2d_device_mesh()
|
||||
|
|
|
@ -1,18 +1,43 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import \
|
||||
BatchNormModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import pytest
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
@pytest.mark.skip("skip due to passes not ready")
|
||||
def test_bn_module_handler():
|
||||
model = nn.Sequential(nn.BatchNorm2d(16).to('meta'))
|
||||
def check_bn_module_handler(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = nn.Sequential(nn.BatchNorm2d(16)).cuda()
|
||||
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
input = torch.rand(4, 16, 64, 64).cuda()
|
||||
# the index of bn node in computation graph
|
||||
node_index = 1
|
||||
# the total number of bn strategies without sync bn mode
|
||||
# TODO: add sync bn stategies after related passes ready
|
||||
strategy_number = 4
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=[input],
|
||||
meta_arg_names=['input'])
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
|
@ -20,10 +45,6 @@ def test_bn_module_handler():
|
|||
# return _0
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
bn_mod_node = list(graph.nodes)[1]
|
||||
strategies_vector = StrategiesVector(bn_mod_node)
|
||||
|
||||
|
@ -40,25 +61,21 @@ def test_bn_module_handler():
|
|||
assert op_data.data is not None
|
||||
|
||||
assert mapping['input'].name == "input_1"
|
||||
assert mapping['input'].data.is_meta
|
||||
assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64])
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64])
|
||||
|
||||
assert mapping['other'].name == "weight"
|
||||
assert mapping['other'].data.is_meta
|
||||
assert mapping['other'].data.shape == torch.Size([16])
|
||||
assert mapping['other'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([16])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
|
||||
assert mapping['output'].name == "_0"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
|
@ -75,16 +92,27 @@ def test_bn_module_handler():
|
|||
# RS01 = RS01 x S01
|
||||
assert 'RS01 = RS01 x S01' in strategy_name_list
|
||||
|
||||
# temporarily skip the sync bn test
|
||||
# TODO: test sync bn after the implicit runtime pass completed
|
||||
# SR = SR x R WITH SYNC_BN
|
||||
assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
|
||||
assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
|
||||
# assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
|
||||
# assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
|
||||
|
||||
# SS = SS x S WITH SYNC_BN
|
||||
assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
|
||||
assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
|
||||
# assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
|
||||
# assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
|
||||
|
||||
# S01R = S01R x R WITH SYNC_BN
|
||||
assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
|
||||
# assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bn_module_handler():
|
||||
world_size = 4
|
||||
run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,16 +1,25 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
@parameterize('op', [torch.add])
|
||||
@parameterize('other_dim', [1, 2])
|
||||
def test_binary_elementwise_handler_with_tensor(op, other_dim):
|
||||
def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
class BinaryElementwiseOpModel(nn.Module):
|
||||
|
||||
|
@ -22,16 +31,32 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
|
|||
out = self.op(x1, x2)
|
||||
return out
|
||||
|
||||
model = BinaryElementwiseOpModel(op)
|
||||
tracer = ColoTracer()
|
||||
|
||||
meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')}
|
||||
graph = tracer.trace(model, meta_args=meta_args)
|
||||
print(graph)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
model = BinaryElementwiseOpModel(op).cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
x1 = torch.rand(4, 4).cuda()
|
||||
x2 = torch.rand([4] * other_dim).cuda()
|
||||
# the index of binary-elementwise node in computation graph
|
||||
node_index = 2
|
||||
# strategy number of binary-elementwise node
|
||||
strategy_number = 9
|
||||
# construct input args
|
||||
input_args = [x1, x2]
|
||||
# construct meta arg names
|
||||
meta_arg_names = ['x1', 'x2']
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
|
||||
tracer = ColoTracer()
|
||||
meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')}
|
||||
graph = tracer.trace(model, meta_args=meta_args)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
||||
op_node = list(graph.nodes)[2]
|
||||
strategies_vector = StrategiesVector(op_node)
|
||||
|
||||
|
@ -97,9 +122,9 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
|
|||
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
@parameterize('op', [torch.add])
|
||||
@parameterize('other', [1, 2])
|
||||
def test_binary_elementwise_handler_with_int(op, other):
|
||||
def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
class BinaryElementwiseOpModel(nn.Module):
|
||||
|
||||
|
@ -112,16 +137,30 @@ def test_binary_elementwise_handler_with_int(op, other):
|
|||
out = self.op(x1, self.const)
|
||||
return out
|
||||
|
||||
model = BinaryElementwiseOpModel(op, other)
|
||||
tracer = ColoTracer()
|
||||
|
||||
meta_args = {'x1': torch.rand(4, 4).to('meta')}
|
||||
graph = tracer.trace(model, meta_args=meta_args)
|
||||
print(graph)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
model = BinaryElementwiseOpModel(op, other_dim).cuda()
|
||||
x1 = torch.rand(4, 4).cuda()
|
||||
# the index of binary-elementwise node in computation graph
|
||||
node_index = 1
|
||||
# strategy number of binary-elementwise node
|
||||
strategy_number = 9
|
||||
# construct input args
|
||||
input_args = [x1]
|
||||
# construct meta arg names
|
||||
meta_arg_names = ['x1']
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
tracer = ColoTracer()
|
||||
meta_args = {'x1': torch.rand(4, 4).to('meta')}
|
||||
graph = tracer.trace(model, meta_args=meta_args)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
||||
op_node = list(graph.nodes)[1]
|
||||
strategies_vector = StrategiesVector(op_node)
|
||||
|
||||
|
@ -168,6 +207,26 @@ def test_binary_elementwise_handler_with_int(op, other):
|
|||
assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence
|
||||
|
||||
|
||||
@parameterize('op', [torch.add])
|
||||
@parameterize('other_dim', [1, 2])
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_binary_elementwise_handler(op, other_dim):
|
||||
world_size = 4
|
||||
run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
|
||||
op=op,
|
||||
other_dim=other_dim,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_tensor, nprocs=world_size)
|
||||
run_func_int = partial(check_binary_elementwise_handler_with_int,
|
||||
op=op,
|
||||
other_dim=other_dim,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_int, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_binary_elementwise_handler_with_tensor()
|
||||
test_binary_elementwise_handler_with_int()
|
||||
test_binary_elementwise_handler()
|
||||
|
|
|
@ -1,12 +1,20 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
class BMMTensorMethodModule(nn.Module):
|
||||
|
@ -21,22 +29,37 @@ class BMMTorchFunctionModule(nn.Module):
|
|||
return torch.bmm(x1, x2)
|
||||
|
||||
|
||||
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
|
||||
def test_2d_device_mesh(module):
|
||||
|
||||
model = module()
|
||||
def check_2d_device_mesh(rank, module, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = module().cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
x1 = torch.rand(4, 8, 16).cuda()
|
||||
x2 = torch.rand(4, 16, 8).cuda()
|
||||
# the index of bmm node in computation graph
|
||||
node_index = 2
|
||||
# strategy number of bmm node on 2d device mesh
|
||||
strategy_number = 7
|
||||
# construct input args
|
||||
input_args = [x1, x2]
|
||||
# construct meta arg names
|
||||
meta_arg_names = ['x1', 'x2']
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"x1": torch.rand(4, 8, 16).to('meta'),
|
||||
'x2': torch.rand(4, 16, 8).to('meta')
|
||||
})
|
||||
print(graph)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
linear_mod_node = list(graph.nodes)[2]
|
||||
strategies_vector = StrategiesVector(linear_mod_node)
|
||||
|
||||
|
@ -96,27 +119,41 @@ def test_2d_device_mesh(module):
|
|||
output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
|
||||
|
||||
# make sure the sharding matches across different operation data
|
||||
print(input_sharding_spec.sharding_sequence, output_sharding_spec.sharding_sequence)
|
||||
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
|
||||
assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
|
||||
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
|
||||
def test_1d_device_mesh(module):
|
||||
model = module()
|
||||
def check_1d_device_mesh(rank, module, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = module().cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (1, 4)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
x1 = torch.rand(4, 8, 16).cuda()
|
||||
x2 = torch.rand(4, 16, 8).cuda()
|
||||
# the index of bmm node in computation graph
|
||||
node_index = 2
|
||||
# strategy number of bmm node on 1d device mesh
|
||||
strategy_number = 1
|
||||
# construct input args
|
||||
input_args = [x1, x2]
|
||||
# construct meta arg names
|
||||
meta_arg_names = ['x1', 'x2']
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"x1": torch.rand(4, 8, 16).to('meta'),
|
||||
'x2': torch.rand(4, 16, 8).to('meta')
|
||||
})
|
||||
print(graph)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (1, 4)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
linear_mod_node = list(graph.nodes)[2]
|
||||
strategies_vector = StrategiesVector(linear_mod_node)
|
||||
|
||||
|
@ -166,6 +203,17 @@ def test_1d_device_mesh(module):
|
|||
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bmm_handler(module):
|
||||
world_size = 4
|
||||
run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_2d, nprocs=world_size)
|
||||
run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_1d, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_1d_device_mesh()
|
||||
test_2d_device_mesh()
|
||||
test_bmm_handler()
|
||||
|
|
|
@ -31,11 +31,16 @@ def check_conv_module_handler(rank, bias, world_size, port):
|
|||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
# index of conv node in this graph
|
||||
# index of conv node in computation graph
|
||||
node_index = 1
|
||||
# total number of conv strategies
|
||||
strategy_number = 16
|
||||
numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, [input], ['input'])
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=[input],
|
||||
meta_arg_names=['input'])
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
@ -165,8 +170,13 @@ def check_conv_function_handler(rank, bias, world_size, port):
|
|||
bias_tensor = torch.rand(16).cuda()
|
||||
input_kwargs['bias'] = bias_tensor
|
||||
node_index += 1
|
||||
numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, input_args, meta_arg_names,
|
||||
input_kwargs)
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names,
|
||||
input_kwargs=input_kwargs)
|
||||
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
|
@ -280,21 +290,27 @@ def check_conv_function_handler(rank, bias, world_size, port):
|
|||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
|
||||
|
||||
|
||||
@pytest.mark.skip("some cases need to be fixed")
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@parameterize('bias', [True, False])
|
||||
# We temporarily ban the bias option before doing bias add
|
||||
# before all reduce communication may encounter correctness issue.
|
||||
# @parameterize('bias', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_conv_module_handler(bias):
|
||||
def test_conv_module_handler(bias=False):
|
||||
world_size = 4
|
||||
run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
@pytest.mark.skip("some cases need to be fixed")
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@parameterize('bias', [True, False])
|
||||
# We temporarily ban the bias option before doing bias add
|
||||
# before all reduce communication may encounter correctness issue.
|
||||
# @parameterize('bias', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_conv_function_handler(bias):
|
||||
def test_conv_function_handler(bias=False):
|
||||
world_size = 4
|
||||
run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
|
|
@ -1,16 +1,45 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import \
|
||||
LayerNormModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
def test_ln_module_handler():
|
||||
model = nn.Sequential(nn.LayerNorm(16).to('meta'))
|
||||
def check_ln_module_handler(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = nn.Sequential(nn.LayerNorm(16)).cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
input = torch.rand(4, 16).cuda()
|
||||
# the index of bn node in computation graph
|
||||
node_index = 1
|
||||
# the total number of ln strategies
|
||||
strategy_number = 4
|
||||
# construct input args
|
||||
input_args = [input]
|
||||
# construct meta arg names
|
||||
meta_arg_names = ['input']
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
|
@ -18,10 +47,7 @@ def test_ln_module_handler():
|
|||
# return _0
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
ln_mod_node = list(graph.nodes)[1]
|
||||
strategies_vector = StrategiesVector(ln_mod_node)
|
||||
|
||||
|
@ -38,25 +64,21 @@ def test_ln_module_handler():
|
|||
assert op_data.data is not None
|
||||
|
||||
assert mapping['input'].name == "input_1"
|
||||
assert mapping['input'].data.is_meta
|
||||
assert mapping['input'].data.shape == torch.Size([4, 16])
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([4, 16])
|
||||
|
||||
assert mapping['other'].name == "weight"
|
||||
assert mapping['other'].data.is_meta
|
||||
assert mapping['other'].data.shape == torch.Size([16])
|
||||
assert mapping['other'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([16])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
|
||||
assert mapping['output'].name == "_0"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([4, 16])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
|
@ -74,5 +96,14 @@ def test_ln_module_handler():
|
|||
assert '[S01, R] = [S01, R] x [R]' in strategy_name_list
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ln_module_handler():
|
||||
world_size = 4
|
||||
run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ln_module_handler()
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
from faulthandler import disable
|
||||
from functools import partial
|
||||
from xml.dom import WrongDocumentErr
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from typing_extensions import Self
|
||||
|
||||
|
@ -11,22 +17,42 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
|||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
@parameterize('bias', [True, False])
|
||||
def test_linear_module_handler(bias):
|
||||
model = nn.Sequential(nn.Linear(16, 32, bias=bias).to('meta'))
|
||||
def check_linear_module_handler(rank, bias, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
input = torch.rand(2, 2, 4, 16).cuda()
|
||||
# the index of linear node in computation graph
|
||||
node_index = 1
|
||||
# strategy number of linear node
|
||||
strategy_number = 10
|
||||
# construct input args
|
||||
input_args = [input]
|
||||
# construct meta arg names
|
||||
meta_arg_names = ['input']
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
print(graph)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
linear_mod_node = list(graph.nodes)[1]
|
||||
strategies_vector = StrategiesVector(linear_mod_node)
|
||||
|
||||
|
@ -43,26 +69,22 @@ def test_linear_module_handler(bias):
|
|||
assert op_data.data is not None
|
||||
|
||||
assert mapping['input'].name == "input_1"
|
||||
assert mapping['input'].data.is_meta
|
||||
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([16, 16])
|
||||
|
||||
assert mapping['other'].name == "weight"
|
||||
assert mapping['other'].data.is_meta
|
||||
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||
assert mapping['other'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||
|
||||
if bias:
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([32])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['bias'].logical_shape == torch.Size([32])
|
||||
|
||||
assert mapping['output'].name == "_0"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
assert mapping['output'].logical_shape == torch.Size([16, 32])
|
||||
|
@ -110,19 +132,49 @@ def test_linear_module_handler(bias):
|
|||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@parameterize('bias', [True, False])
|
||||
def test_linear_function_handler(bias):
|
||||
model = nn.Linear(16, 32, bias=bias).to('meta')
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
||||
class LinearModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input, others, bias=None):
|
||||
x = nn.functional.linear(input, others, bias=bias)
|
||||
return x
|
||||
|
||||
|
||||
def check_linear_function_handler(rank, bias, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = LinearModel().cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
print(graph)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
input = torch.rand(2, 2, 4, 16).cuda()
|
||||
other = torch.rand(32, 16).cuda()
|
||||
# the index of linear node in computation graph
|
||||
node_index = 2
|
||||
# strategy number of linear node
|
||||
strategy_number = 10
|
||||
# construct input args
|
||||
input_args = [input, other]
|
||||
# construct meta arg names
|
||||
meta_arg_names = ['input', 'others']
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(2, 2, 4, 16).to('meta'),
|
||||
'others': torch.rand(32, 16).to('meta')
|
||||
})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
if bias:
|
||||
linear_func_node = list(graph.nodes)[3]
|
||||
else:
|
||||
|
@ -136,26 +188,22 @@ def test_linear_function_handler(bias):
|
|||
mapping = handler.get_operation_data_mapping()
|
||||
|
||||
assert mapping['input'].name == "input_1"
|
||||
assert mapping['input'].data.is_meta
|
||||
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([16, 16])
|
||||
|
||||
assert mapping['other'].name == "weight"
|
||||
assert mapping['other'].data.is_meta
|
||||
assert mapping['other'].name == "others"
|
||||
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||
assert mapping['other'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].type == OperationDataType.ARG
|
||||
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||
|
||||
if bias:
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([32])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['bias'].type == OperationDataType.ARG
|
||||
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||
|
||||
assert mapping['output'].name == "linear"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
|
@ -187,7 +235,7 @@ def test_linear_function_handler(bias):
|
|||
for strategy in strategies_vector:
|
||||
strategy: ShardingStrategy
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
||||
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
|
||||
weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
|
||||
|
||||
if bias:
|
||||
|
@ -202,6 +250,17 @@ def test_linear_function_handler(bias):
|
|||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
# @parameterize('bias', [True, False])
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_handler(bias=False):
|
||||
world_size = 4
|
||||
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
run_func_function = partial(check_linear_function_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_function, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_linear_module_handler()
|
||||
test_linear_function_handler()
|
||||
test_linear_handler()
|
||||
|
|
|
@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, Strategi
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.shape_consistency import to_global
|
||||
from colossalai.testing.comparison import assert_close
|
||||
from colossalai.testing.comparison import assert_close, assert_close_loose
|
||||
|
||||
|
||||
def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor],
|
||||
|
@ -31,7 +31,6 @@ def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tenso
|
|||
arg_to_compare = copy.deepcopy(input_tensor)
|
||||
arg_to_compare.requires_grad = True
|
||||
wrapper(arg_to_compare, arg_index)
|
||||
# arg_to_compare.register_hook(hook_fn)
|
||||
args_to_compare.append(arg_to_compare)
|
||||
|
||||
for name, input_kwarg in input_kwargs.items():
|
||||
|
@ -68,8 +67,6 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
|||
model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs,
|
||||
grad_to_shard_dict)
|
||||
|
||||
zero_tensor = torch.Tensor(0).cuda()
|
||||
|
||||
tracer = ColoTracer()
|
||||
input_sample = {}
|
||||
for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
|
||||
|
@ -98,10 +95,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
|||
origin_node_sharding_spec_dict=origin_spec_dict,
|
||||
comm_actions_dict=comm_actions_dict,
|
||||
**kwargs_to_shard)
|
||||
# except:
|
||||
# print(gm)
|
||||
output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare)
|
||||
assert_close((output - output_to_compare).sum(), zero_tensor)
|
||||
assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type='forward output')
|
||||
|
||||
# backward result compare
|
||||
loss = output.sum()
|
||||
|
@ -111,7 +106,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
|||
for key in grad_to_shard_dict.keys():
|
||||
grad_to_shard = grad_to_shard_dict[key]
|
||||
grad_to_compare = grad_to_compare_dict[key]
|
||||
assert_close((grad_to_shard - grad_to_compare).sum(), zero_tensor)
|
||||
assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad')
|
||||
|
||||
# extract the strategy used in this iter
|
||||
strategy_in_use = target_node.strategies_vector[strategy_index]
|
||||
|
@ -123,4 +118,20 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
|||
grad_sharded = param_to_shard_dict[name].grad
|
||||
grad_to_compare = param_to_compare_dict[name].grad
|
||||
global_grad = to_global(grad_sharded, param_sharding_spec)
|
||||
assert_close((global_grad - grad_to_compare).sum(), zero_tensor)
|
||||
assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type='param grad')
|
||||
|
||||
|
||||
def assert_close_helper(first: torch.Tensor,
|
||||
second: torch.Tensor,
|
||||
rtol: float = 1e-2,
|
||||
atol: float = 1e-2,
|
||||
strategy_index: int = -1,
|
||||
type: str = 'not defined'):
|
||||
"""
|
||||
This method is used to check whether the average difference between two tensors is as close as expected.
|
||||
"""
|
||||
# average_diff_tensor = ((first - second)/(second+0.1)).sum()/second.numel()
|
||||
try:
|
||||
assert_close(first, second, rtol=rtol, atol=atol)
|
||||
except:
|
||||
print(f'strategy index {strategy_index} encounter assert_close error on {type}')
|
||||
|
|
Loading…
Reference in New Issue