diff --git a/fail2ban/client/fail2bancmdline.py b/fail2ban/client/fail2bancmdline.py index 9110a2b8..19f596a2 100644 --- a/fail2ban/client/fail2bancmdline.py +++ b/fail2ban/client/fail2bancmdline.py @@ -42,6 +42,7 @@ PRODUCTION = True MAX_WAITTIME = 30 + class Fail2banCmdLine(): def __init__(self): @@ -256,14 +257,23 @@ class Fail2banCmdLine(): output(c) return True + # + # _exit is made to ease mocking out of the behaviour in tests, + # since method is also exposed in API via globally bound variable @staticmethod - def exit(code=0): # pragma: no cover - can't test - logSys.debug("Exit with code %s", code) + def _exit(code=0): if hasattr(os, '_exit') and os._exit: os._exit(code) else: sys.exit(code) + @staticmethod + def exit(code=0): + logSys.debug("Exit with code %s", code) + # import pdb; pdb.set_trace() + Fail2banCmdLine._exit(code) + + # global exit handler: exit = Fail2banCmdLine.exit diff --git a/fail2ban/tests/fail2banclienttestcase.py b/fail2ban/tests/fail2banclienttestcase.py index e69d31e1..fa8a210c 100644 --- a/fail2ban/tests/fail2banclienttestcase.py +++ b/fail2ban/tests/fail2banclienttestcase.py @@ -36,6 +36,7 @@ from functools import wraps from threading import Thread from ..client import fail2banclient, fail2banserver, fail2bancmdline +from ..client.fail2bancmdline import Fail2banCmdLine from ..client.fail2banclient import exec_command_line as _exec_client, VisualWait from ..client.fail2banserver import Fail2banServer, exec_command_line as _exec_server from .. import protocol @@ -92,17 +93,6 @@ class FailExitException(fail2bancmdline.ExitException): pass -def _test_exit(code=0): - logSys.debug("Exit with code %s", code) - if code == 0: - raise ExitException() - else: - raise FailExitException() - -fail2bancmdline.exit = \ -fail2banclient.exit = \ -fail2banserver.exit = _test_exit - INTERACT = [] @@ -256,11 +246,22 @@ class Fail2banClientServerBase(LogCaptureTestCase): def setUp(self): """Call before every test case.""" LogCaptureTestCase.setUp(self) + Fail2banCmdLine._exit = staticmethod(self._test_exit) def tearDown(self): """Call after every test case.""" + Fail2banCmdLine._exit = self._orig_exit LogCaptureTestCase.tearDown(self) + _orig_exit = Fail2banCmdLine._exit + + @staticmethod + def _test_exit(code=0): + if code == 0: + raise ExitException() + else: + raise FailExitException() + def _wait_for_srv(self, tmp, ready=True, startparams=None): try: sock = pjoin(tmp, "f2b.sock")