mirror of https://github.com/fail2ban/fail2ban
RF: revertably mock out exit call while testing new client/servers
parent
1417cc99ef
commit
fcda7c9ac7
|
@ -42,6 +42,7 @@ PRODUCTION = True
|
||||||
|
|
||||||
MAX_WAITTIME = 30
|
MAX_WAITTIME = 30
|
||||||
|
|
||||||
|
|
||||||
class Fail2banCmdLine():
|
class Fail2banCmdLine():
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -256,14 +257,23 @@ class Fail2banCmdLine():
|
||||||
output(c)
|
output(c)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
#
|
||||||
|
# _exit is made to ease mocking out of the behaviour in tests,
|
||||||
|
# since method is also exposed in API via globally bound variable
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def exit(code=0): # pragma: no cover - can't test
|
def _exit(code=0):
|
||||||
logSys.debug("Exit with code %s", code)
|
|
||||||
if hasattr(os, '_exit') and os._exit:
|
if hasattr(os, '_exit') and os._exit:
|
||||||
os._exit(code)
|
os._exit(code)
|
||||||
else:
|
else:
|
||||||
sys.exit(code)
|
sys.exit(code)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def exit(code=0):
|
||||||
|
logSys.debug("Exit with code %s", code)
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
|
Fail2banCmdLine._exit(code)
|
||||||
|
|
||||||
|
|
||||||
# global exit handler:
|
# global exit handler:
|
||||||
exit = Fail2banCmdLine.exit
|
exit = Fail2banCmdLine.exit
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,7 @@ from functools import wraps
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
from ..client import fail2banclient, fail2banserver, fail2bancmdline
|
from ..client import fail2banclient, fail2banserver, fail2bancmdline
|
||||||
|
from ..client.fail2bancmdline import Fail2banCmdLine
|
||||||
from ..client.fail2banclient import exec_command_line as _exec_client, VisualWait
|
from ..client.fail2banclient import exec_command_line as _exec_client, VisualWait
|
||||||
from ..client.fail2banserver import Fail2banServer, exec_command_line as _exec_server
|
from ..client.fail2banserver import Fail2banServer, exec_command_line as _exec_server
|
||||||
from .. import protocol
|
from .. import protocol
|
||||||
|
@ -92,17 +93,6 @@ class FailExitException(fail2bancmdline.ExitException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _test_exit(code=0):
|
|
||||||
logSys.debug("Exit with code %s", code)
|
|
||||||
if code == 0:
|
|
||||||
raise ExitException()
|
|
||||||
else:
|
|
||||||
raise FailExitException()
|
|
||||||
|
|
||||||
fail2bancmdline.exit = \
|
|
||||||
fail2banclient.exit = \
|
|
||||||
fail2banserver.exit = _test_exit
|
|
||||||
|
|
||||||
INTERACT = []
|
INTERACT = []
|
||||||
|
|
||||||
|
|
||||||
|
@ -256,11 +246,22 @@ class Fail2banClientServerBase(LogCaptureTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Call before every test case."""
|
"""Call before every test case."""
|
||||||
LogCaptureTestCase.setUp(self)
|
LogCaptureTestCase.setUp(self)
|
||||||
|
Fail2banCmdLine._exit = staticmethod(self._test_exit)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Call after every test case."""
|
"""Call after every test case."""
|
||||||
|
Fail2banCmdLine._exit = self._orig_exit
|
||||||
LogCaptureTestCase.tearDown(self)
|
LogCaptureTestCase.tearDown(self)
|
||||||
|
|
||||||
|
_orig_exit = Fail2banCmdLine._exit
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _test_exit(code=0):
|
||||||
|
if code == 0:
|
||||||
|
raise ExitException()
|
||||||
|
else:
|
||||||
|
raise FailExitException()
|
||||||
|
|
||||||
def _wait_for_srv(self, tmp, ready=True, startparams=None):
|
def _wait_for_srv(self, tmp, ready=True, startparams=None):
|
||||||
try:
|
try:
|
||||||
sock = pjoin(tmp, "f2b.sock")
|
sock = pjoin(tmp, "f2b.sock")
|
||||||
|
|
Loading…
Reference in New Issue