[hotfix]unit test (#1670)

pull/1658/head^2
YuliangLiu0306 2022-09-29 12:49:28 +08:00 committed by GitHub
parent a60024e77a
commit 11ec070e53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 1 deletions

View File

@ -1,8 +1,10 @@
from .strategy_generator import StrategyGenerator_V2
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .batch_norm_generator import BatchNormStrategyGenerator
__all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator'
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
'BatchNormStrategyGenerator'
]

View File

@ -19,6 +19,7 @@ class BMMTorchFunctionModule(nn.Module):
return torch.bmm(x1, x2)
@pytest.mark.skip
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):
@ -89,6 +90,7 @@ def test_2d_device_mesh(module):
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
@pytest.mark.skip
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
model = module()