diff --git a/bin/fail2ban-client b/bin/fail2ban-client index 19e76a98..f5ae7946 100755 --- a/bin/fail2ban-client +++ b/bin/fail2ban-client @@ -31,7 +31,7 @@ __author__ = "Fail2Ban Developers" __copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester" __license__ = "GPL" -from fail2ban.client.fail2banclient import exec_command_line +from fail2ban.client.fail2banclient import exec_command_line, sys if __name__ == "__main__": - exec_command_line() + exec_command_line(sys.argv) diff --git a/bin/fail2ban-server b/bin/fail2ban-server index 8e64d865..ffafabe2 100755 --- a/bin/fail2ban-server +++ b/bin/fail2ban-server @@ -31,7 +31,7 @@ __author__ = "Fail2Ban Developers" __copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester" __license__ = "GPL" -from fail2ban.client.fail2banserver import exec_command_line +from fail2ban.client.fail2banserver import exec_command_line, sys if __name__ == "__main__": - exec_command_line() + exec_command_line(sys.argv) diff --git a/fail2ban/client/fail2banclient.py b/fail2ban/client/fail2banclient.py index 7adbab95..736f8fd2 100755 --- a/fail2ban/client/fail2banclient.py +++ b/fail2ban/client/fail2banclient.py @@ -28,12 +28,20 @@ import socket import sys import time +import threading from threading import Thread from ..version import version from .csocket import CSocket from .beautifier import Beautifier -from .fail2bancmdline import Fail2banCmdLine, logSys, exit +from .fail2bancmdline import Fail2banCmdLine, ExitException, logSys, exit, output + +MAX_WAITTIME = 30 + + +def _thread_name(): + return threading.current_thread().__class__.__name__ + ## # @@ -51,13 +59,13 @@ class Fail2banClient(Fail2banCmdLine, Thread): self._beautifier = None def dispInteractive(self): - print "Fail2Ban v" + version + " reads log file that contains password failure report" - print "and bans the corresponding IP addresses using firewall rules." - print + output("Fail2Ban v" + version + " reads log file that contains password failure report") + output("and bans the corresponding IP addresses using firewall rules.") + output("") def __sigTERMhandler(self, signum, frame): # Print a new line because we probably come from wait - print + output("") logSys.warning("Caught signal %d. Exiting" % signum) exit(-1) @@ -85,11 +93,11 @@ class Fail2banClient(Fail2banCmdLine, Thread): if ret[0] == 0: logSys.debug("OK : " + `ret[1]`) if showRet or c[0] == 'echo': - print beautifier.beautify(ret[1]) + output(beautifier.beautify(ret[1])) else: logSys.error("NOK: " + `ret[1].args`) if showRet: - print beautifier.beautifyError(ret[1]) + output(beautifier.beautifyError(ret[1])) streamRet = False except socket.error: if showRet or self._conf["verbose"] > 1: @@ -182,10 +190,13 @@ class Fail2banClient(Fail2banCmdLine, Thread): # Start server direct here in main thread (not fork): self._server = Fail2banServer.startServerDirect(self._conf, False) + except ExitException: + pass except Exception as e: - print - logSys.error("Exception while starting server foreground") + output("") + logSys.error("Exception while starting server " + ("background" if background else "foreground")) logSys.error(e) + return False finally: self._alive = False @@ -229,18 +240,18 @@ class Fail2banClient(Fail2banCmdLine, Thread): elif len(cmd) == 1 and cmd[0] == "restart": if self._conf.get("interactive", False): - print(' ## stop ... ') + output(' ## stop ... ') self.__processCommand(['stop']) self.__waitOnServer(False) # in interactive mode reset config, to make full-reload if there something changed: if self._conf.get("interactive", False): - print(' ## load configuration ... ') + output(' ## load configuration ... ') self.resetConf() ret = self.initCmdLine(self._argv) if ret is not None: return ret if self._conf.get("interactive", False): - print(' ## start ... ') + output(' ## start ... ') return self.__processCommand(['start']) elif len(cmd) >= 1 and cmd[0] == "reload": @@ -283,7 +294,9 @@ class Fail2banClient(Fail2banCmdLine, Thread): return False return True - def __waitOnServer(self, alive=True, maxtime=30): + def __waitOnServer(self, alive=True, maxtime=None): + if maxtime is None: + maxtime = MAX_WAITTIME # Wait for the server to start (the server has 30 seconds to answer ping) starttime = time.time() with VisualWait(self._conf["verbose"]) as vis: @@ -301,53 +314,59 @@ class Fail2banClient(Fail2banCmdLine, Thread): def start(self, argv): # Install signal handlers - signal.signal(signal.SIGTERM, self.__sigTERMhandler) - signal.signal(signal.SIGINT, self.__sigTERMhandler) - - # Command line options - if self._argv is None: - ret = self.initCmdLine(argv) - if ret is not None: - return ret - - # Commands - args = self._args - - # Interactive mode - if self._conf.get("interactive", False): - try: - import readline - except ImportError: - logSys.error("Readline not available") - return False - try: - ret = True - if len(args) > 0: - ret = self.__processCommand(args) - if ret: - readline.parse_and_bind("tab: complete") - self.dispInteractive() - while True: - cmd = raw_input(self.PROMPT) - if cmd == "exit" or cmd == "quit": - # Exit - return True - if cmd == "help": - self.dispUsage() - elif not cmd == "": - try: - self.__processCommand(shlex.split(cmd)) - except Exception, e: - logSys.error(e) - except (EOFError, KeyboardInterrupt): - print - return True - # Single command mode - else: - if len(args) < 1: - self.dispUsage() - return False - return self.__processCommand(args) + _prev_signals = {} + if _thread_name() == '_MainThread': + for s in (signal.SIGTERM, signal.SIGINT): + _prev_signals[s] = signal.getsignal(s) + signal.signal(s, self.__sigTERMhandler) + try: + # Command line options + if self._argv is None: + ret = self.initCmdLine(argv) + if ret is not None: + return ret + + # Commands + args = self._args + + # Interactive mode + if self._conf.get("interactive", False): + try: + import readline + except ImportError: + logSys.error("Readline not available") + return False + try: + ret = True + if len(args) > 0: + ret = self.__processCommand(args) + if ret: + readline.parse_and_bind("tab: complete") + self.dispInteractive() + while True: + cmd = raw_input(self.PROMPT) + if cmd == "exit" or cmd == "quit": + # Exit + return True + if cmd == "help": + self.dispUsage() + elif not cmd == "": + try: + self.__processCommand(shlex.split(cmd)) + except Exception, e: + logSys.error(e) + except (EOFError, KeyboardInterrupt): + output("") + return True + # Single command mode + else: + if len(args) < 1: + self.dispUsage() + return False + return self.__processCommand(args) + finally: + for s, sh in _prev_signals.iteritems(): + signal.signal(s, sh) class ServerExecutionException(Exception): @@ -361,7 +380,8 @@ class ServerExecutionException(Exception): class _VisualWait: pos = 0 delta = 1 - maxpos = 10 + def __init__(self, maxpos=10): + self.maxpos = maxpos def __enter__(self): return self def __exit__(self, *args): @@ -390,14 +410,14 @@ class _NotVisualWait: def heartbeat(self): pass -def VisualWait(verbose): - return _VisualWait() if verbose > 1 else _NotVisualWait() +def VisualWait(verbose, *args, **kwargs): + return _VisualWait(*args, **kwargs) if verbose > 1 else _NotVisualWait() -def exec_command_line(): # pragma: no cover - can't test main +def exec_command_line(argv): client = Fail2banClient() # Exit with correct return value - if client.start(sys.argv): + if client.start(argv): exit(0) else: exit(-1) diff --git a/fail2ban/client/fail2bancmdline.py b/fail2ban/client/fail2bancmdline.py index e23f9b19..2ed6a499 100644 --- a/fail2ban/client/fail2bancmdline.py +++ b/fail2ban/client/fail2bancmdline.py @@ -33,7 +33,12 @@ from ..helpers import getLogger # Gets the instance of the logger. logSys = getLogger("fail2ban") +def output(s): + print(s) + CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket",) +# Used to signal - we are in test cases (ex: prevents change logging params, log capturing, etc) +PRODUCTION = True class Fail2banCmdLine(): @@ -71,50 +76,50 @@ class Fail2banCmdLine(): self.__dict__[o] = obj.__dict__[o] def dispVersion(self): - print "Fail2Ban v" + version - print - print "Copyright (c) 2004-2008 Cyril Jaquier, 2008- Fail2Ban Contributors" - print "Copyright of modifications held by their respective authors." - print "Licensed under the GNU General Public License v2 (GPL)." - print - print "Written by Cyril Jaquier ." - print "Many contributions by Yaroslav O. Halchenko ." + output("Fail2Ban v" + version) + output("") + output("Copyright (c) 2004-2008 Cyril Jaquier, 2008- Fail2Ban Contributors") + output("Copyright of modifications held by their respective authors.") + output("Licensed under the GNU General Public License v2 (GPL).") + output("") + output("Written by Cyril Jaquier .") + output("Many contributions by Yaroslav O. Halchenko .") def dispUsage(self): """ Prints Fail2Ban command line options and exits """ caller = os.path.basename(self._argv[0]) - print "Usage: "+caller+" [OPTIONS]" + (" " if not caller.endswith('server') else "") - print - print "Fail2Ban v" + version + " reads log file that contains password failure report" - print "and bans the corresponding IP addresses using firewall rules." - print - print "Options:" - print " -c configuration directory" - print " -s socket path" - print " -p pidfile path" - print " --loglevel logging level" - print " --logtarget |STDOUT|STDERR|SYSLOG" - print " --syslogsocket auto|" - print " -d dump configuration. For debugging" - print " -i interactive mode" - print " -v increase verbosity" - print " -q decrease verbosity" - print " -x force execution of the server (remove socket file)" - print " -b start server in background (default)" - print " -f start server in foreground" - print " --async start server in async mode (for internal usage only, don't read configuration)" - print " -h, --help display this help message" - print " -V, --version print the version" + output("Usage: "+caller+" [OPTIONS]" + (" " if not caller.endswith('server') else "")) + output("") + output("Fail2Ban v" + version + " reads log file that contains password failure report") + output("and bans the corresponding IP addresses using firewall rules.") + output("") + output("Options:") + output(" -c configuration directory") + output(" -s socket path") + output(" -p pidfile path") + output(" --loglevel logging level") + output(" --logtarget |STDOUT|STDERR|SYSLOG") + output(" --syslogsocket auto|") + output(" -d dump configuration. For debugging") + output(" -i interactive mode") + output(" -v increase verbosity") + output(" -q decrease verbosity") + output(" -x force execution of the server (remove socket file)") + output(" -b start server in background (default)") + output(" -f start server in foreground") + output(" --async start server in async mode (for internal usage only, don't read configuration)") + output(" -h, --help display this help message") + output(" -V, --version print the version") if not caller.endswith('server'): - print - print "Command:" + output("") + output("Command:") # Prints the protocol printFormatted() - print - print "Report bugs to https://github.com/fail2ban/fail2ban/issues" + output("") + output("Report bugs to https://github.com/fail2ban/fail2ban/issues") def __getCmdLineOptions(self, optList): """ Gets the command line options @@ -147,71 +152,80 @@ class Fail2banCmdLine(): self._conf["background"] = False elif o in ["-h", "--help"]: self.dispUsage() - exit(0) + return True elif o in ["-V", "--version"]: self.dispVersion() - exit(0) + return True + return None def initCmdLine(self, argv): - # First time? - initial = (self._argv is None) + try: + # First time? + initial = (self._argv is None) - # Command line options - self._argv = argv + # Command line options + self._argv = argv + logSys.info("Using start params %s", argv[1:]) - # Reads the command line options. - try: - cmdOpts = 'hc:s:p:xfbdviqV' - cmdLongOpts = ['loglevel=', 'logtarget=', 'syslogsocket=', 'async', 'help', 'version'] - optList, self._args = getopt.getopt(self._argv[1:], cmdOpts, cmdLongOpts) - except getopt.GetoptError: - self.dispUsage() - exit(-1) - - self.__getCmdLineOptions(optList) - - if initial: - verbose = self._conf["verbose"] - if verbose <= 0: - logSys.setLevel(logging.ERROR) - elif verbose == 1: - logSys.setLevel(logging.WARNING) - elif verbose == 2: - logSys.setLevel(logging.INFO) - elif verbose == 3: - logSys.setLevel(logging.DEBUG) - else: - logSys.setLevel(logging.HEAVYDEBUG) - # Add the default logging handler to dump to stderr - logout = logging.StreamHandler(sys.stderr) - # set a format which is simpler for console use - formatter = logging.Formatter('%(levelname)-6s %(message)s') - # tell the handler to use this format - logout.setFormatter(formatter) - logSys.addHandler(logout) - - # Set expected parameters (like socket, pidfile, etc) from configuration, - # if those not yet specified, in which read configuration only if needed here: - conf = None - for o in CONFIG_PARAMS: - if self._conf.get(o, None) is None: - if not conf: - self.configurator.readEarly() - conf = self.configurator.getEarlyOptions() - self._conf[o] = conf[o] - - logSys.info("Using socket file %s", self._conf["socket"]) - - logSys.info("Using pid file %s, [%s] logging to %s", - self._conf["pidfile"], self._conf["loglevel"], self._conf["logtarget"]) - - if self._conf.get("dump", False): - ret, stream = self.readConfig() - self.dumpConfig(stream) - return ret - - # Nothing to do here, process in client/server - return None + # Reads the command line options. + try: + cmdOpts = 'hc:s:p:xfbdviqV' + cmdLongOpts = ['loglevel=', 'logtarget=', 'syslogsocket=', 'async', 'help', 'version'] + optList, self._args = getopt.getopt(self._argv[1:], cmdOpts, cmdLongOpts) + except getopt.GetoptError: + self.dispUsage() + return False + + ret = self.__getCmdLineOptions(optList) + if ret is not None: + return ret + + if initial and PRODUCTION: # pragma: no cover - can't test + verbose = self._conf["verbose"] + if verbose <= 0: + logSys.setLevel(logging.ERROR) + elif verbose == 1: + logSys.setLevel(logging.WARNING) + elif verbose == 2: + logSys.setLevel(logging.INFO) + elif verbose == 3: + logSys.setLevel(logging.DEBUG) + else: + logSys.setLevel(logging.HEAVYDEBUG) + # Add the default logging handler to dump to stderr + logout = logging.StreamHandler(sys.stderr) + # set a format which is simpler for console use + formatter = logging.Formatter('%(levelname)-6s %(message)s') + # tell the handler to use this format + logout.setFormatter(formatter) + logSys.addHandler(logout) + + # Set expected parameters (like socket, pidfile, etc) from configuration, + # if those not yet specified, in which read configuration only if needed here: + conf = None + for o in CONFIG_PARAMS: + if self._conf.get(o, None) is None: + if not conf: + self.configurator.readEarly() + conf = self.configurator.getEarlyOptions() + self._conf[o] = conf[o] + + logSys.info("Using socket file %s", self._conf["socket"]) + + logSys.info("Using pid file %s, [%s] logging to %s", + self._conf["pidfile"], self._conf["loglevel"], self._conf["logtarget"]) + + if self._conf.get("dump", False): + ret, stream = self.readConfig() + self.dumpConfig(stream) + return ret + + # Nothing to do here, process in client/server + return None + except Exception as e: + output("ERROR: %s" % (e,)) + #logSys.exception(e) + return False def readConfig(self, jail=None): # Read the configuration @@ -244,4 +258,8 @@ class Fail2banCmdLine(): sys.exit(code) # global exit handler: -exit = Fail2banCmdLine.exit \ No newline at end of file +exit = Fail2banCmdLine.exit + + +class ExitException: + pass \ No newline at end of file diff --git a/fail2ban/client/fail2banserver.py b/fail2ban/client/fail2banserver.py index 6c1dd694..da8e57b8 100644 --- a/fail2ban/client/fail2banserver.py +++ b/fail2ban/client/fail2banserver.py @@ -29,7 +29,11 @@ from ..server.server import Server, ServerDaemonize from ..server.utils import Utils from .fail2bancmdline import Fail2banCmdLine, logSys, exit +MAX_WAITTIME = 30 + SERVER = "fail2ban-server" + + ## # \mainpage Fail2Ban # @@ -72,8 +76,15 @@ class Fail2banServer(Fail2banCmdLine): @staticmethod def startServerAsync(conf): - # Forks the current process. - pid = os.fork() + # Directory of client (to try the first start from the same directory as client): + startdir = sys.path[0] + if startdir in ("", "."): # may be uresolved in test-cases, so get bin-directory: + startdir = os.path.dirname(sys.argv[0]) + # Forks the current process, don't fork if async specified (ex: test cases) + pid = 0 + frk = not conf["async"] + if frk: + pid = os.fork() if pid == 0: args = list() args.append(SERVER) @@ -96,14 +107,20 @@ class Fail2banServer(Fail2banCmdLine): try: # Use the current directory. - exe = os.path.abspath(os.path.join(sys.path[0], SERVER)) + exe = os.path.abspath(os.path.join(startdir, SERVER)) logSys.debug("Starting %r with args %r", exe, args) - os.execv(exe, args) - except OSError: + if frk: + os.execv(exe, args) + else: + os.spawnv(os.P_NOWAITO, exe, args) + except OSError as e: try: # Use the PATH env. - logSys.warning("Initial start attempt failed. Starting %r with the same args", SERVER) - os.execvp(SERVER, args) + logSys.warning("Initial start attempt failed (%s). Starting %r with the same args", e, SERVER) + if frk: + os.execvp(SERVER, args) + else: + os.spawnvp(os.P_NOWAITO, SERVER, args) except OSError: exit(-1) @@ -143,8 +160,8 @@ class Fail2banServer(Fail2banCmdLine): phase = dict() logSys.debug('Configure via async client thread') cli.configureServer(async=True, phase=phase) - # wait up to 30 seconds, do not continue if configuration is not 100% valid: - Utils.wait_for(lambda: phase.get('ready', None) is not None, 30) + # wait up to MAX_WAITTIME, do not continue if configuration is not 100% valid: + Utils.wait_for(lambda: phase.get('ready', None) is not None, MAX_WAITTIME) if not phase.get('start', False): return False @@ -158,7 +175,7 @@ class Fail2banServer(Fail2banCmdLine): # wait for client answer "done": if not async and cli: - Utils.wait_for(lambda: phase.get('done', None) is not None, 30) + Utils.wait_for(lambda: phase.get('done', None) is not None, MAX_WAITTIME) if not phase.get('done', False): if server: server.quit() @@ -179,9 +196,9 @@ class Fail2banServer(Fail2banCmdLine): logSys.error("Could not start %s", SERVER) exit(code) -def exec_command_line(): # pragma: no cover - can't test main +def exec_command_line(argv): server = Fail2banServer() - if server.start(sys.argv): + if server.start(argv): exit(0) else: exit(-1) diff --git a/fail2ban/protocol.py b/fail2ban/protocol.py index 857d5fa6..648666a1 100644 --- a/fail2ban/protocol.py +++ b/fail2ban/protocol.py @@ -26,6 +26,9 @@ __license__ = "GPL" import textwrap +def output(s): + print(s) + ## # Describes the protocol used to communicate with the server. @@ -143,7 +146,7 @@ def printFormatted(): firstHeading = False for m in protocol: if m[0] == '' and firstHeading: - print + output("") firstHeading = True first = True if len(m[0]) >= MARGIN: @@ -154,7 +157,7 @@ def printFormatted(): first = False else: line = ' ' * (INDENT + MARGIN) + n.strip() - print line + output(line) ## @@ -165,20 +168,20 @@ def printWiki(): for m in protocol: if m[0] == '': if firstHeading: - print "|}" + output("|}") __printWikiHeader(m[1], m[2]) firstHeading = True else: - print "|-" - print "| " + m[0] + " || || " + m[1] - print "|}" + output("|-") + output("| " + m[0] + " || || " + m[1]) + output("|}") def __printWikiHeader(section, desc): - print - print "=== " + section + " ===" - print - print desc - print - print "{|" - print "| '''Command''' || || '''Description'''" + output("") + output("=== " + section + " ===") + output("") + output(desc) + output("") + output("{|") + output("| '''Command''' || || '''Description'''") diff --git a/fail2ban/server/server.py b/fail2ban/server/server.py index d7f212c8..d4786f29 100644 --- a/fail2ban/server/server.py +++ b/fail2ban/server/server.py @@ -24,6 +24,7 @@ __author__ = "Cyril Jaquier" __copyright__ = "Copyright (c) 2004 Cyril Jaquier" __license__ = "GPL" +import threading from threading import Lock, RLock import logging import logging.handlers @@ -42,6 +43,10 @@ from ..helpers import getLogger, excepthook # Gets the instance of the logger. logSys = getLogger(__name__) +DEF_SYSLOGSOCKET = "auto" +DEF_LOGLEVEL = "INFO" +DEF_LOGTARGET = "STDOUT" + try: from .database import Fail2BanDb except ImportError: # pragma: no cover @@ -49,6 +54,10 @@ except ImportError: # pragma: no cover Fail2BanDb = None +def _thread_name(): + return threading.current_thread().__class__.__name__ + + class Server: def __init__(self, daemon = False): @@ -67,11 +76,7 @@ class Server: 'FreeBSD': '/var/run/log', 'Linux': '/dev/log', } - # todo: remove that, if test cases are fixed - self.setSyslogSocket("auto") - # Set logging level - self.setLogLevel("INFO") - self.setLogTarget("STDOUT") + self.__prev_signals = {} def __sigTERMhandler(self, signum, frame): logSys.debug("Caught signal %d. Exiting" % signum) @@ -93,9 +98,12 @@ class Server: raise ServerInitializationError("Could not create daemon") # Set all logging parameters (or use default if not specified): - self.setSyslogSocket(conf.get("syslogsocket", self.__syslogSocket)) - self.setLogLevel(conf.get("loglevel", self.__logLevel)) - self.setLogTarget(conf.get("logtarget", self.__logTarget)) + self.setSyslogSocket(conf.get("syslogsocket", + self.__syslogSocket if self.__syslogSocket is not None else DEF_SYSLOGSOCKET)) + self.setLogLevel(conf.get("loglevel", + self.__logLevel if self.__logLevel is not None else DEF_LOGLEVEL)) + self.setLogTarget(conf.get("logtarget", + self.__logTarget if self.__logTarget is not None else DEF_LOGTARGET)) logSys.info("-"*50) logSys.info("Starting Fail2ban v%s", version.version) @@ -104,10 +112,10 @@ class Server: logSys.info("Daemon started") # Install signal handlers - signal.signal(signal.SIGTERM, self.__sigTERMhandler) - signal.signal(signal.SIGINT, self.__sigTERMhandler) - signal.signal(signal.SIGUSR1, self.__sigUSR1handler) - + if _thread_name() == '_MainThread': + for s in (signal.SIGTERM, signal.SIGINT, signal.SIGUSR1): + self.__prev_signals[s] = signal.getsignal(s) + signal.signal(s, self.__sigTERMhandler if s != signal.SIGUSR1 else self.__sigUSR1handler) # Ensure unhandled exceptions are logged sys.excepthook = excepthook @@ -150,6 +158,10 @@ class Server: with self.__loggingLock: logging.shutdown() + # Restore default signal handlers: + for s, sh in self.__prev_signals.iteritems(): + signal.signal(s, sh) + def addJail(self, name, backend): self.__jails.add(name, backend, self.__db) if self.__db is not None: @@ -405,8 +417,13 @@ class Server: def setLogTarget(self, target): with self.__loggingLock: + # don't set new handlers if already the same + # or if "INHERITED" (foreground worker of the test cases, to prevent stop logging): if self.__logTarget == target: return True + if target == "INHERITED": + self.__logTarget = target + return True # set a format which is simpler for console use formatter = logging.Formatter("%(asctime)s %(name)-24s[%(process)d]: %(levelname)-7s %(message)s") if target == "SYSLOG": @@ -549,7 +566,10 @@ class Server: # We need to set this in the parent process, so it gets inherited by the # child process, and this makes sure that it is effect even if the parent # terminates quickly. - signal.signal(signal.SIGHUP, signal.SIG_IGN) + if _thread_name() == '_MainThread': + for s in (signal.SIGHUP,): + self.__prev_signals[s] = signal.getsignal(s) + signal.signal(s, signal.SIG_IGN) try: # Fork a child process so the parent can exit. This will return control diff --git a/fail2ban/tests/fail2banclienttestcase.py b/fail2ban/tests/fail2banclienttestcase.py new file mode 100644 index 00000000..7db48fe8 --- /dev/null +++ b/fail2ban/tests/fail2banclienttestcase.py @@ -0,0 +1,411 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*- +# vi: set ft=python sts=4 ts=4 sw=4 noet : + +# This file is part of Fail2Ban. +# +# Fail2Ban is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# Fail2Ban is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Fail2Ban; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +# Fail2Ban developers + +__author__ = "Serg Brester" +__copyright__ = "Copyright (c) 2014- Serg G. Brester (sebres), 2008- Fail2Ban Contributors" +__license__ = "GPL" + +import fileinput +import os +import re +import time +import unittest + +from threading import Thread + +from ..client import fail2banclient, fail2banserver, fail2bancmdline +from ..client.fail2banclient import Fail2banClient, exec_command_line as _exec_client, VisualWait +from ..client.fail2banserver import Fail2banServer, exec_command_line as _exec_server +from .. import protocol +from ..server import server +from ..server.utils import Utils +from .utils import LogCaptureTestCase, logSys, withtmpdir, shutil, logging + + +STOCK_CONF_DIR = "config" +STOCK = os.path.exists(os.path.join(STOCK_CONF_DIR,'fail2ban.conf')) +TEST_CONF_DIR = os.path.join(os.path.dirname(__file__), "config") +if STOCK: + CONF_DIR = STOCK_CONF_DIR +else: + CONF_DIR = TEST_CONF_DIR + +CLIENT = "fail2ban-client" +SERVER = "fail2ban-server" +BIN = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "bin") + +MAX_WAITTIME = 10 +MAX_WAITTIME = unittest.F2B.maxWaitTime(MAX_WAITTIME) + +## +# Several wrappers and settings for proper testing: +# + +fail2banclient.MAX_WAITTIME = \ +fail2banserver.MAX_WAITTIME = MAX_WAITTIME + + +fail2bancmdline.logSys = \ +fail2banclient.logSys = \ +fail2banserver.logSys = logSys + +LOG_LEVEL = logSys.level + +server.DEF_LOGTARGET = "/dev/null" + +def _test_output(*args): + logSys.info(args[0]) +fail2bancmdline.output = \ +fail2banclient.output = \ +fail2banserver.output = \ +protocol.output = _test_output + +def _test_exit(code=0): + logSys.debug("Exit with code %s", code) + if code == 0: + raise ExitException() + else: + raise FailExitException() +fail2bancmdline.exit = \ +fail2banclient.exit = \ +fail2banserver.exit = _test_exit + +INTERACT = [] +def _test_raw_input(*args): + if len(INTERACT): + #print('--- interact command: ', INTERACT[0]) + return INTERACT.pop(0) + else: + return "exit" +fail2banclient.raw_input = _test_raw_input + +# prevents change logging params, log capturing, etc: +fail2bancmdline.PRODUCTION = False + + +class ExitException(fail2bancmdline.ExitException): + pass +class FailExitException(fail2bancmdline.ExitException): + pass + + +def _out_file(fn): # pragma: no cover + logSys.debug('---- ' + fn + ' ----') + for line in fileinput.input(fn): + line = line.rstrip('\n') + logSys.debug(line) + logSys.debug('-'*30) + +def _start_params(tmp, use_stock=False, logtarget="/dev/null"): + cfg = tmp+"/config" + if use_stock and STOCK: + # copy config: + def ig_dirs(dir, files): + return [f for f in files if not os.path.isfile(os.path.join(dir, f))] + shutil.copytree(STOCK_CONF_DIR, cfg, ignore=ig_dirs) + os.symlink(STOCK_CONF_DIR+"/action.d", cfg+"/action.d") + os.symlink(STOCK_CONF_DIR+"/filter.d", cfg+"/filter.d") + # replace fail2ban params (database with memory): + r = re.compile(r'^dbfile\s*=') + for line in fileinput.input(cfg+"/fail2ban.conf", inplace=True): + line = line.rstrip('\n') + if r.match(line): + line = "dbfile = :memory:" + print(line) + # replace jail params (polling as backend to be fast in initialize): + r = re.compile(r'^backend\s*=') + for line in fileinput.input(cfg+"/jail.conf", inplace=True): + line = line.rstrip('\n') + if r.match(line): + line = "backend = polling" + print(line) + else: + # just empty config directory without anything (only fail2ban.conf/jail.conf): + os.mkdir(cfg) + f = open(cfg+"/fail2ban.conf", "wb") + f.write('\n'.join(( + "[Definition]", + "loglevel = INFO", + "logtarget = " + logtarget, + "syslogsocket = auto", + "socket = "+tmp+"/f2b.sock", + "pidfile = "+tmp+"/f2b.pid", + "backend = polling", + "dbfile = :memory:", + "dbpurgeage = 1d", + "", + ))) + f.close() + f = open(cfg+"/jail.conf", "wb") + f.write('\n'.join(( + "[INCLUDES]", "", + "[DEFAULT]", "", + "", + ))) + f.close() + if LOG_LEVEL < logging.DEBUG: # if HEAVYDEBUG + _out_file(cfg+"/fail2ban.conf") + _out_file(cfg+"/jail.conf") + # parameters: + return ("-c", cfg, + "--logtarget", logtarget, "--loglevel", "DEBUG", "--syslogsocket", "auto", + "-s", tmp+"/f2b.sock", "-p", tmp+"/f2b.pid") + + +class Fail2banClientTest(LogCaptureTestCase): + + def setUp(self): + """Call before every test case.""" + LogCaptureTestCase.setUp(self) + + def tearDown(self): + """Call after every test case.""" + LogCaptureTestCase.tearDown(self) + + def testClientUsage(self): + self.assertRaises(ExitException, _exec_client, + (CLIENT, "-h",)) + self.assertLogged("Usage: " + CLIENT) + self.assertLogged("Report bugs to ") + + @withtmpdir + def testClientStartBackgroundInside(self, tmp): + # always add "--async" by start inside, should don't fork by async (not replace client with server, just start in new process) + # (we can't fork the test cases process): + startparams = _start_params(tmp, True) + # start: + self.assertRaises(ExitException, _exec_client, + (CLIENT, "--async", "-b") + startparams + ("start",)) + self.assertLogged("Server ready") + self.assertLogged("Exit with code 0") + try: + self.assertRaises(ExitException, _exec_client, + (CLIENT,) + startparams + ("echo", "TEST-ECHO",)) + self.assertRaises(FailExitException, _exec_client, + (CLIENT,) + startparams + ("~~unknown~cmd~failed~~",)) + finally: + self.pruneLog() + # stop: + self.assertRaises(ExitException, _exec_client, + (CLIENT,) + startparams + ("stop",)) + self.assertLogged("Shutdown successful") + self.assertLogged("Exit with code 0") + + @withtmpdir + def testClientStartBackgroundCall(self, tmp): + global INTERACT + startparams = _start_params(tmp) + # start (without async in new process): + cmd = os.path.join(os.path.join(BIN), CLIENT) + logSys.debug('Start %s ...', cmd) + Utils.executeCmd((cmd,) + startparams + ("start",), + timeout=MAX_WAITTIME, shell=False, output=False) + self.pruneLog() + try: + # echo from client (inside): + self.assertRaises(ExitException, _exec_client, + (CLIENT,) + startparams + ("echo", "TEST-ECHO",)) + self.assertLogged("TEST-ECHO") + self.assertLogged("Exit with code 0") + self.pruneLog() + # interactive client chat with started server: + INTERACT += [ + "echo INTERACT-ECHO", + "status", + "exit" + ] + self.assertRaises(ExitException, _exec_client, + (CLIENT,) + startparams + ("-i",)) + self.assertLogged("INTERACT-ECHO") + self.assertLogged("Status", "Number of jail:") + self.assertLogged("Exit with code 0") + self.pruneLog() + # test reload and restart over interactive client: + INTERACT += [ + "reload", + "restart", + "exit" + ] + self.assertRaises(ExitException, _exec_client, + (CLIENT,) + startparams + ("-i",)) + self.assertLogged("Reading config files:") + self.assertLogged("Shutdown successful") + self.assertLogged("Server ready") + self.assertLogged("Exit with code 0") + self.pruneLog() + finally: + self.pruneLog() + # stop: + self.assertRaises(ExitException, _exec_client, + (CLIENT,) + startparams + ("stop",)) + self.assertLogged("Shutdown successful") + self.assertLogged("Exit with code 0") + + def _testClientStartForeground(self, tmp, startparams, phase): + # start and wait to end (foreground): + phase['start'] = True + self.assertRaises(ExitException, _exec_client, + (CLIENT, "-f") + startparams + ("start",)) + # end : + phase['end'] = True + + @withtmpdir + def testClientStartForeground(self, tmp): + # started directly here, so prevent overwrite test cases logger with "INHERITED" + startparams = _start_params(tmp, logtarget="INHERITED") + # because foreground block execution - start it in thread: + phase = dict() + Thread(name="_TestCaseWorker", + target=Fail2banClientTest._testClientStartForeground, args=(self, tmp, startparams, phase)).start() + try: + # wait for start thread: + Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME) + self.assertTrue(phase.get('start', None)) + # wait for server (socket): + Utils.wait_for(lambda: os.path.exists(tmp+"/f2b.sock"), MAX_WAITTIME) + self.assertLogged("Starting communication") + self.assertRaises(ExitException, _exec_client, + (CLIENT,) + startparams + ("ping",)) + self.assertRaises(FailExitException, _exec_client, + (CLIENT,) + startparams + ("~~unknown~cmd~failed~~",)) + self.assertRaises(ExitException, _exec_client, + (CLIENT,) + startparams + ("echo", "TEST-ECHO",)) + finally: + self.pruneLog() + # stop: + self.assertRaises(ExitException, _exec_client, + (CLIENT,) + startparams + ("stop",)) + # wait for end: + Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME) + self.assertTrue(phase.get('end', None)) + self.assertLogged("Shutdown successful", "Exiting Fail2ban") + self.assertLogged("Exit with code 0") + + @withtmpdir + def testClientFailStart(self, tmp): + self.assertRaises(FailExitException, _exec_client, + (CLIENT, "--async", "-c", tmp+"/miss", "start",)) + self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist") + + self.assertRaises(FailExitException, _exec_client, + (CLIENT, "--async", "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock", "start",)) + self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file") + + def testVisualWait(self): + sleeptime = 0.035 + for verbose in (2, 0): + cntr = 15 + with VisualWait(verbose, 5) as vis: + while cntr: + vis.heartbeat() + if verbose and not unittest.F2B.fast: + time.sleep(sleeptime) + cntr -= 1 + + +class Fail2banServerTest(LogCaptureTestCase): + + def setUp(self): + """Call before every test case.""" + LogCaptureTestCase.setUp(self) + + def tearDown(self): + """Call after every test case.""" + LogCaptureTestCase.tearDown(self) + + def testServerUsage(self): + self.assertRaises(ExitException, _exec_server, + (SERVER, "-h",)) + self.assertLogged("Usage: " + SERVER) + self.assertLogged("Report bugs to ") + + @withtmpdir + def testServerStartBackground(self, tmp): + # don't add "--async" by start, because if will fork current process by daemonize + # (we can't fork the test cases process), + # because server started internal communication in new thread use INHERITED as logtarget here: + startparams = _start_params(tmp, logtarget="INHERITED") + # start: + self.assertRaises(ExitException, _exec_server, + (SERVER, "-b") + startparams) + self.assertLogged("Server ready") + self.assertLogged("Exit with code 0") + try: + self.assertRaises(ExitException, _exec_server, + (SERVER,) + startparams + ("echo", "TEST-ECHO",)) + self.assertRaises(FailExitException, _exec_server, + (SERVER,) + startparams + ("~~unknown~cmd~failed~~",)) + finally: + self.pruneLog() + # stop: + self.assertRaises(ExitException, _exec_server, + (SERVER,) + startparams + ("stop",)) + self.assertLogged("Shutdown successful") + self.assertLogged("Exit with code 0") + + def _testServerStartForeground(self, tmp, startparams, phase): + # start and wait to end (foreground): + phase['start'] = True + self.assertRaises(ExitException, _exec_server, + (SERVER, "-f") + startparams + ("start",)) + # end : + phase['end'] = True + @withtmpdir + def testServerStartForeground(self, tmp): + # started directly here, so prevent overwrite test cases logger with "INHERITED" + startparams = _start_params(tmp, logtarget="INHERITED") + # because foreground block execution - start it in thread: + phase = dict() + Thread(name="_TestCaseWorker", + target=Fail2banServerTest._testServerStartForeground, args=(self, tmp, startparams, phase)).start() + try: + # wait for start thread: + Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME) + self.assertTrue(phase.get('start', None)) + # wait for server (socket): + Utils.wait_for(lambda: os.path.exists(tmp+"/f2b.sock"), MAX_WAITTIME) + self.assertLogged("Starting communication") + self.assertRaises(ExitException, _exec_server, + (SERVER,) + startparams + ("ping",)) + self.assertRaises(FailExitException, _exec_server, + (SERVER,) + startparams + ("~~unknown~cmd~failed~~",)) + self.assertRaises(ExitException, _exec_server, + (SERVER,) + startparams + ("echo", "TEST-ECHO",)) + finally: + self.pruneLog() + # stop: + self.assertRaises(ExitException, _exec_server, + (SERVER,) + startparams + ("stop",)) + # wait for end: + Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME) + self.assertTrue(phase.get('end', None)) + self.assertLogged("Shutdown successful", "Exiting Fail2ban") + self.assertLogged("Exit with code 0") + + @withtmpdir + def testServerFailStart(self, tmp): + self.assertRaises(FailExitException, _exec_server, + (SERVER, "-c", tmp+"/miss",)) + self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist") + + self.assertRaises(FailExitException, _exec_server, + (SERVER, "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock",)) + self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file") diff --git a/fail2ban/tests/fail2banregextestcase.py b/fail2ban/tests/fail2banregextestcase.py index 49d6a3a6..d9f4081f 100644 --- a/fail2ban/tests/fail2banregextestcase.py +++ b/fail2ban/tests/fail2banregextestcase.py @@ -23,19 +23,7 @@ __author__ = "Serg Brester" __copyright__ = "Copyright (c) 2015 Serg G. Brester (sebres), 2008- Fail2Ban Contributors" __license__ = "GPL" -from __builtin__ import open as fopen -import unittest -import getpass import os -import sys -import time -import tempfile -import uuid - -try: - from systemd import journal -except ImportError: - journal = None from ..client import fail2banregex from ..client.fail2banregex import Fail2banRegex, get_opt_parser, output diff --git a/fail2ban/tests/utils.py b/fail2ban/tests/utils.py index a0036979..e373aa5f 100644 --- a/fail2ban/tests/utils.py +++ b/fail2ban/tests/utils.py @@ -26,10 +26,14 @@ import logging import optparse import os import re +import tempfile +import shutil import sys import time import unittest + from StringIO import StringIO +from functools import wraps from ..helpers import getLogger from ..server.ipdns import DNSUtils @@ -71,6 +75,17 @@ class F2B(optparse.Values): return wtime +def withtmpdir(f): + @wraps(f) + def wrapper(self, *args, **kwargs): + tmp = tempfile.mkdtemp(prefix="f2b-temp") + try: + return f(self, tmp, *args, **kwargs) + finally: + # clean up + shutil.rmtree(tmp) + return wrapper + def initTests(opts): unittest.F2B = F2B(opts) # --fast : @@ -156,6 +171,7 @@ def gatherTests(regexps=None, opts=None): from . import misctestcase from . import databasetestcase from . import samplestestcase + from . import fail2banclienttestcase from . import fail2banregextestcase if not regexps: # pragma: no cover @@ -239,6 +255,9 @@ def gatherTests(regexps=None, opts=None): # Filter Regex tests with sample logs tests.addTest(unittest.makeSuite(samplestestcase.FilterSamplesRegex)) + # bin/fail2ban-client, bin/fail2ban-server + tests.addTest(unittest.makeSuite(fail2banclienttestcase.Fail2banClientTest)) + tests.addTest(unittest.makeSuite(fail2banclienttestcase.Fail2banServerTest)) # bin/fail2ban-regex tests.addTest(unittest.makeSuite(fail2banregextestcase.Fail2banRegexTest)) @@ -321,8 +340,11 @@ class LogCaptureTestCase(unittest.TestCase): # Let's log everything into a string self._log = StringIO() logSys.handlers = [logging.StreamHandler(self._log)] + if self._old_level <= logging.DEBUG: + print("") if self._old_level < logging.DEBUG: # so if HEAVYDEBUG etc -- show them! logSys.handlers += self._old_handlers + logSys.debug('--'*40) logSys.setLevel(getattr(logging, 'DEBUG')) def tearDown(self): @@ -386,6 +408,9 @@ class LogCaptureTestCase(unittest.TestCase): def pruneLog(self): self._log.truncate(0) + def pruneLog(self): + self._log.truncate(0) + def getLog(self): return self._log.getvalue()