diff --git a/colossalai/testing/pytest_wrapper.py b/colossalai/testing/pytest_wrapper.py index b264b0090..6a80e1dcc 100644 --- a/colossalai/testing/pytest_wrapper.py +++ b/colossalai/testing/pytest_wrapper.py @@ -1,10 +1,9 @@ """ This file will not be automatically imported by `colossalai.testing` -as this file has a dependency on `pytest`. Therefore, you need to +as this file has a dependency on `pytest`. Therefore, you need to explicitly import this file `from colossalai.testing.pytest_wrapper import `.from """ -import pytest import os @@ -30,6 +29,12 @@ def run_on_environment_flag(name: str): pytest test_for_something.py """ + try: + import pytest + except ImportError: + raise ImportError( + 'This function requires `pytest` to be installed, please do `pip install pytest` and try again.') + assert isinstance(name, str) flag = os.environ.get(name.upper(), '0')