mirror of https://github.com/fail2ban/fail2ban
several bug fixed: fork in client-server test cases prohibited, all worker threads daemonized (to prevent hanging on exit).
parent
afa1cdc3ae
commit
2fcb6358ff
|
@ -34,7 +34,7 @@ from threading import Thread
|
|||
from ..version import version
|
||||
from .csocket import CSocket
|
||||
from .beautifier import Beautifier
|
||||
from .fail2bancmdline import Fail2banCmdLine, ExitException, logSys, exit, output
|
||||
from .fail2bancmdline import Fail2banCmdLine, ExitException, PRODUCTION, logSys, exit, output
|
||||
|
||||
MAX_WAITTIME = 30
|
||||
|
||||
|
@ -108,8 +108,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
logSys.error(e)
|
||||
return False
|
||||
finally:
|
||||
# prevent errors by close during shutdown (on exit command):
|
||||
if client:
|
||||
try :
|
||||
client.close()
|
||||
except Exception as e:
|
||||
if showRet or self._conf["verbose"] > 1:
|
||||
logSys.debug(e)
|
||||
if showRet or c[0] == 'echo':
|
||||
sys.stdout.flush()
|
||||
return streamRet
|
||||
|
@ -184,7 +189,9 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
return False
|
||||
else:
|
||||
# In foreground mode we should make server/client communication in different threads:
|
||||
Thread(target=Fail2banClient.__processStartStreamAfterWait, args=(self, stream, False)).start()
|
||||
th = Thread(target=Fail2banClient.__processStartStreamAfterWait, args=(self, stream, False))
|
||||
th.daemon = True
|
||||
th.start()
|
||||
# Mark current (main) thread as daemon:
|
||||
self.setDaemon(True)
|
||||
# Start server direct here in main thread (not fork):
|
||||
|
@ -197,8 +204,6 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
logSys.error("Exception while starting server " + ("background" if background else "foreground"))
|
||||
logSys.error(e)
|
||||
return False
|
||||
finally:
|
||||
self._alive = False
|
||||
|
||||
return True
|
||||
|
||||
|
@ -206,7 +211,9 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
def configureServer(self, async=True, phase=None):
|
||||
# if asynchron start this operation in the new thread:
|
||||
if async:
|
||||
return Thread(target=Fail2banClient.configureServer, args=(self, False, phase)).start()
|
||||
th = Thread(target=Fail2banClient.configureServer, args=(self, False, phase))
|
||||
th.daemon = True
|
||||
return th.start()
|
||||
# prepare: read config, check configuration is valid, etc.:
|
||||
if phase is not None:
|
||||
phase['start'] = True
|
||||
|
@ -290,7 +297,6 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
"server, adding the -x option will do it")
|
||||
if self._server:
|
||||
self._server.quit()
|
||||
exit(-1)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -299,10 +305,12 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
maxtime = MAX_WAITTIME
|
||||
# Wait for the server to start (the server has 30 seconds to answer ping)
|
||||
starttime = time.time()
|
||||
logSys.debug("__waitOnServer: %r", (alive, maxtime))
|
||||
with VisualWait(self._conf["verbose"]) as vis:
|
||||
while self._alive and not self.__ping() == alive or (
|
||||
while self._alive and (
|
||||
not self.__ping() == alive or (
|
||||
not alive and os.path.exists(self._conf["socket"])
|
||||
):
|
||||
)):
|
||||
now = time.time()
|
||||
# Wonderful visual :)
|
||||
if now > starttime + 1:
|
||||
|
@ -365,6 +373,7 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
return False
|
||||
return self.__processCommand(args)
|
||||
finally:
|
||||
self._alive = False
|
||||
for s, sh in _prev_signals.iteritems():
|
||||
signal.signal(s, sh)
|
||||
|
||||
|
|
|
@ -25,15 +25,14 @@ import os
|
|||
import sys
|
||||
|
||||
from ..version import version
|
||||
from ..server.server import Server, ServerDaemonize
|
||||
from ..server.server import Server
|
||||
from ..server.utils import Utils
|
||||
from .fail2bancmdline import Fail2banCmdLine, logSys, exit
|
||||
from .fail2bancmdline import Fail2banCmdLine, logSys, PRODUCTION, exit
|
||||
|
||||
MAX_WAITTIME = 30
|
||||
|
||||
SERVER = "fail2ban-server"
|
||||
|
||||
|
||||
##
|
||||
# \mainpage Fail2Ban
|
||||
#
|
||||
|
@ -51,6 +50,7 @@ class Fail2banServer(Fail2banCmdLine):
|
|||
|
||||
@staticmethod
|
||||
def startServerDirect(conf, daemon=True):
|
||||
logSys.debug("-- direct starting of server in %s, deamon: %s", os.getpid(), daemon)
|
||||
server = None
|
||||
try:
|
||||
# Start it in foreground (current thread, not new process),
|
||||
|
@ -59,8 +59,6 @@ class Fail2banServer(Fail2banCmdLine):
|
|||
server.start(conf["socket"],
|
||||
conf["pidfile"], conf["force"],
|
||||
conf=conf)
|
||||
except ServerDaemonize:
|
||||
pass
|
||||
except Exception, e:
|
||||
logSys.exception(e)
|
||||
if server:
|
||||
|
@ -82,9 +80,10 @@ class Fail2banServer(Fail2banCmdLine):
|
|||
startdir = os.path.dirname(sys.argv[0])
|
||||
# Forks the current process, don't fork if async specified (ex: test cases)
|
||||
pid = 0
|
||||
frk = not conf["async"]
|
||||
frk = not conf["async"] and PRODUCTION
|
||||
if frk:
|
||||
pid = os.fork()
|
||||
logSys.debug("-- async starting of server in %s, fork: %s - %s", os.getpid(), frk, pid)
|
||||
if pid == 0:
|
||||
args = list()
|
||||
args.append(SERVER)
|
||||
|
|
|
@ -51,6 +51,8 @@ class JailThread(Thread):
|
|||
|
||||
def __init__(self, name=None):
|
||||
super(JailThread, self).__init__(name=name)
|
||||
## Should going with main thread also:
|
||||
self.daemon = True
|
||||
## Control the state of the thread.
|
||||
self.active = False
|
||||
## Control the idle state of the thread.
|
||||
|
|
|
@ -351,6 +351,9 @@ class Server:
|
|||
def getBanTime(self, name):
|
||||
return self.__jails[name].actions.getBanTime()
|
||||
|
||||
def isStarted(self):
|
||||
self.__asyncServer.isActive()
|
||||
|
||||
def isAlive(self, jailnum=None):
|
||||
if jailnum is not None and len(self.__jails) != jailnum:
|
||||
return 0
|
||||
|
@ -643,6 +646,3 @@ class Server:
|
|||
|
||||
class ServerInitializationError(Exception):
|
||||
pass
|
||||
|
||||
class ServerDaemonize(Exception):
|
||||
pass
|
|
@ -65,6 +65,7 @@ class SMTPActionTest(unittest.TestCase):
|
|||
self._active = True
|
||||
self._loop_thread = threading.Thread(
|
||||
target=asyncserver.loop, kwargs={'active': lambda: self._active})
|
||||
self._loop_thread.daemon = True
|
||||
self._loop_thread.start()
|
||||
|
||||
def tearDown(self):
|
||||
|
|
|
@ -98,7 +98,9 @@ def _test_raw_input(*args):
|
|||
fail2banclient.raw_input = _test_raw_input
|
||||
|
||||
# prevents change logging params, log capturing, etc:
|
||||
fail2bancmdline.PRODUCTION = False
|
||||
fail2bancmdline.PRODUCTION = \
|
||||
fail2banclient.PRODUCTION = \
|
||||
fail2banserver.PRODUCTION = False
|
||||
|
||||
|
||||
class ExitException(fail2bancmdline.ExitException):
|
||||
|
@ -117,9 +119,9 @@ def _out_file(fn): # pragma: no cover
|
|||
def _start_params(tmp, use_stock=False, logtarget="/dev/null"):
|
||||
cfg = tmp+"/config"
|
||||
if use_stock and STOCK:
|
||||
# copy config:
|
||||
# copy config (sub-directories as alias):
|
||||
def ig_dirs(dir, files):
|
||||
return [f for f in files if not os.path.isfile(os.path.join(dir, f))]
|
||||
return [f for f in files if os.path.isdir(os.path.join(dir, f))]
|
||||
shutil.copytree(STOCK_CONF_DIR, cfg, ignore=ig_dirs)
|
||||
os.symlink(STOCK_CONF_DIR+"/action.d", cfg+"/action.d")
|
||||
os.symlink(STOCK_CONF_DIR+"/filter.d", cfg+"/filter.d")
|
||||
|
@ -169,6 +171,47 @@ def _start_params(tmp, use_stock=False, logtarget="/dev/null"):
|
|||
"--logtarget", logtarget, "--loglevel", "DEBUG", "--syslogsocket", "auto",
|
||||
"-s", tmp+"/f2b.sock", "-p", tmp+"/f2b.pid")
|
||||
|
||||
def _kill_srv(pidfile): # pragma: no cover
|
||||
def _pid_exists(pid):
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
logSys.debug("-- cleanup: %r", (pidfile, os.path.isdir(pidfile)))
|
||||
if os.path.isdir(pidfile):
|
||||
piddir = pidfile
|
||||
pidfile = piddir + "/f2b.pid"
|
||||
if not os.path.isfile(pidfile):
|
||||
pidfile = piddir + "/fail2ban.pid"
|
||||
if not os.path.isfile(pidfile):
|
||||
logSys.debug("--- cleanup: no pidfile for %r", piddir)
|
||||
return True
|
||||
f = pid = None
|
||||
try:
|
||||
logSys.debug("--- cleanup pidfile: %r", pidfile)
|
||||
f = open(pidfile)
|
||||
pid = f.read().split()[1]
|
||||
pid = int(pid)
|
||||
logSys.debug("--- cleanup pid: %r", pid)
|
||||
if pid <= 0:
|
||||
raise ValueError('pid %s of %s is invalid' % (pid, pidfile))
|
||||
if not _pid_exists(pid):
|
||||
return True
|
||||
## try to preper stop (have signal handler):
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
## check still exists after small timeout:
|
||||
if not Utils.wait_for(lambda: not _pid_exists(pid), MAX_WAITTIME / 3):
|
||||
## try to kill hereafter:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
return not _pid_exists(pid)
|
||||
except Exception as e:
|
||||
sysLog.debug(e)
|
||||
finally:
|
||||
if f is not None:
|
||||
f.close()
|
||||
return True
|
||||
|
||||
|
||||
class Fail2banClientTest(LogCaptureTestCase):
|
||||
|
||||
|
@ -188,6 +231,7 @@ class Fail2banClientTest(LogCaptureTestCase):
|
|||
|
||||
@withtmpdir
|
||||
def testClientStartBackgroundInside(self, tmp):
|
||||
try:
|
||||
# always add "--async" by start inside, should don't fork by async (not replace client with server, just start in new process)
|
||||
# (we can't fork the test cases process):
|
||||
startparams = _start_params(tmp, True)
|
||||
|
@ -208,9 +252,12 @@ class Fail2banClientTest(LogCaptureTestCase):
|
|||
(CLIENT,) + startparams + ("stop",))
|
||||
self.assertLogged("Shutdown successful")
|
||||
self.assertLogged("Exit with code 0")
|
||||
finally:
|
||||
_kill_srv(tmp)
|
||||
|
||||
@withtmpdir
|
||||
def testClientStartBackgroundCall(self, tmp):
|
||||
try:
|
||||
global INTERACT
|
||||
startparams = _start_params(tmp)
|
||||
# start (without async in new process):
|
||||
|
@ -258,23 +305,31 @@ class Fail2banClientTest(LogCaptureTestCase):
|
|||
(CLIENT,) + startparams + ("stop",))
|
||||
self.assertLogged("Shutdown successful")
|
||||
self.assertLogged("Exit with code 0")
|
||||
finally:
|
||||
_kill_srv(tmp)
|
||||
|
||||
def _testClientStartForeground(self, tmp, startparams, phase):
|
||||
# start and wait to end (foreground):
|
||||
logSys.debug("-- start of test worker")
|
||||
phase['start'] = True
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT, "-f") + startparams + ("start",))
|
||||
# end :
|
||||
phase['end'] = True
|
||||
logSys.debug("-- end of test worker")
|
||||
|
||||
@withtmpdir
|
||||
def testClientStartForeground(self, tmp):
|
||||
th = None
|
||||
try:
|
||||
# started directly here, so prevent overwrite test cases logger with "INHERITED"
|
||||
startparams = _start_params(tmp, logtarget="INHERITED")
|
||||
# because foreground block execution - start it in thread:
|
||||
phase = dict()
|
||||
Thread(name="_TestCaseWorker",
|
||||
target=Fail2banClientTest._testClientStartForeground, args=(self, tmp, startparams, phase)).start()
|
||||
th = Thread(name="_TestCaseWorker",
|
||||
target=Fail2banClientTest._testClientStartForeground, args=(self, tmp, startparams, phase))
|
||||
th.daemon = True
|
||||
th.start()
|
||||
try:
|
||||
# wait for start thread:
|
||||
Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME)
|
||||
|
@ -297,10 +352,14 @@ class Fail2banClientTest(LogCaptureTestCase):
|
|||
Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME)
|
||||
self.assertTrue(phase.get('end', None))
|
||||
self.assertLogged("Shutdown successful", "Exiting Fail2ban")
|
||||
self.assertLogged("Exit with code 0")
|
||||
finally:
|
||||
_kill_srv(tmp)
|
||||
if th:
|
||||
th.join()
|
||||
|
||||
@withtmpdir
|
||||
def testClientFailStart(self, tmp):
|
||||
try:
|
||||
self.assertRaises(FailExitException, _exec_client,
|
||||
(CLIENT, "--async", "-c", tmp+"/miss", "start",))
|
||||
self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist")
|
||||
|
@ -308,6 +367,8 @@ class Fail2banClientTest(LogCaptureTestCase):
|
|||
self.assertRaises(FailExitException, _exec_client,
|
||||
(CLIENT, "--async", "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock", "start",))
|
||||
self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file")
|
||||
finally:
|
||||
_kill_srv(tmp)
|
||||
|
||||
def testVisualWait(self):
|
||||
sleeptime = 0.035
|
||||
|
@ -339,6 +400,7 @@ class Fail2banServerTest(LogCaptureTestCase):
|
|||
|
||||
@withtmpdir
|
||||
def testServerStartBackground(self, tmp):
|
||||
try:
|
||||
# don't add "--async" by start, because if will fork current process by daemonize
|
||||
# (we can't fork the test cases process),
|
||||
# because server started internal communication in new thread use INHERITED as logtarget here:
|
||||
|
@ -360,22 +422,31 @@ class Fail2banServerTest(LogCaptureTestCase):
|
|||
(SERVER,) + startparams + ("stop",))
|
||||
self.assertLogged("Shutdown successful")
|
||||
self.assertLogged("Exit with code 0")
|
||||
finally:
|
||||
_kill_srv(tmp)
|
||||
|
||||
def _testServerStartForeground(self, tmp, startparams, phase):
|
||||
# start and wait to end (foreground):
|
||||
logSys.debug("-- start of test worker")
|
||||
phase['start'] = True
|
||||
self.assertRaises(ExitException, _exec_server,
|
||||
(SERVER, "-f") + startparams + ("start",))
|
||||
# end :
|
||||
phase['end'] = True
|
||||
logSys.debug("-- end of test worker")
|
||||
|
||||
@withtmpdir
|
||||
def testServerStartForeground(self, tmp):
|
||||
th = None
|
||||
try:
|
||||
# started directly here, so prevent overwrite test cases logger with "INHERITED"
|
||||
startparams = _start_params(tmp, logtarget="INHERITED")
|
||||
# because foreground block execution - start it in thread:
|
||||
phase = dict()
|
||||
Thread(name="_TestCaseWorker",
|
||||
target=Fail2banServerTest._testServerStartForeground, args=(self, tmp, startparams, phase)).start()
|
||||
th = Thread(name="_TestCaseWorker",
|
||||
target=Fail2banServerTest._testServerStartForeground, args=(self, tmp, startparams, phase))
|
||||
th.daemon = True
|
||||
th.start()
|
||||
try:
|
||||
# wait for start thread:
|
||||
Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME)
|
||||
|
@ -398,10 +469,14 @@ class Fail2banServerTest(LogCaptureTestCase):
|
|||
Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME)
|
||||
self.assertTrue(phase.get('end', None))
|
||||
self.assertLogged("Shutdown successful", "Exiting Fail2ban")
|
||||
self.assertLogged("Exit with code 0")
|
||||
finally:
|
||||
_kill_srv(tmp)
|
||||
if th:
|
||||
th.join()
|
||||
|
||||
@withtmpdir
|
||||
def testServerFailStart(self, tmp):
|
||||
try:
|
||||
self.assertRaises(FailExitException, _exec_server,
|
||||
(SERVER, "-c", tmp+"/miss",))
|
||||
self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist")
|
||||
|
@ -409,3 +484,5 @@ class Fail2banServerTest(LogCaptureTestCase):
|
|||
self.assertRaises(FailExitException, _exec_server,
|
||||
(SERVER, "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock",))
|
||||
self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file")
|
||||
finally:
|
||||
_kill_srv(tmp)
|
||||
|
|
|
@ -340,9 +340,8 @@ class LogCaptureTestCase(unittest.TestCase):
|
|||
# Let's log everything into a string
|
||||
self._log = StringIO()
|
||||
logSys.handlers = [logging.StreamHandler(self._log)]
|
||||
if self._old_level <= logging.DEBUG:
|
||||
if self._old_level <= logging.DEBUG: # so if DEBUG etc -- show them (and log it in travis)!
|
||||
print("")
|
||||
if self._old_level < logging.DEBUG: # so if HEAVYDEBUG etc -- show them!
|
||||
logSys.handlers += self._old_handlers
|
||||
logSys.debug('--'*40)
|
||||
logSys.setLevel(getattr(logging, 'DEBUG'))
|
||||
|
|
Loading…
Reference in New Issue