client/server (bin) test cases introduced, ultimate closes #1121, closes #1139

small code review and fixing of some bugs during client-server communication process (in the test cases);
pull/1321/head
sebres 9 years ago
parent 4d696d69a0
commit f120877756

@ -31,7 +31,7 @@ __author__ = "Fail2Ban Developers"
__copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester"
__license__ = "GPL"
from fail2ban.client.fail2banclient import exec_command_line
from fail2ban.client.fail2banclient import exec_command_line, sys
if __name__ == "__main__":
exec_command_line()
exec_command_line(sys.argv)

@ -31,7 +31,7 @@ __author__ = "Fail2Ban Developers"
__copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester"
__license__ = "GPL"
from fail2ban.client.fail2banserver import exec_command_line
from fail2ban.client.fail2banserver import exec_command_line, sys
if __name__ == "__main__":
exec_command_line()
exec_command_line(sys.argv)

@ -28,12 +28,20 @@ import socket
import sys
import time
import threading
from threading import Thread
from ..version import version
from .csocket import CSocket
from .beautifier import Beautifier
from .fail2bancmdline import Fail2banCmdLine, logSys, exit
from .fail2bancmdline import Fail2banCmdLine, ExitException, logSys, exit, output
MAX_WAITTIME = 30
def _thread_name():
return threading.current_thread().__class__.__name__
##
#
@ -51,13 +59,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
self._beautifier = None
def dispInteractive(self):
print "Fail2Ban v" + version + " reads log file that contains password failure report"
print "and bans the corresponding IP addresses using firewall rules."
print
output("Fail2Ban v" + version + " reads log file that contains password failure report")
output("and bans the corresponding IP addresses using firewall rules.")
output("")
def __sigTERMhandler(self, signum, frame):
# Print a new line because we probably come from wait
print
output("")
logSys.warning("Caught signal %d. Exiting" % signum)
exit(-1)
@ -85,11 +93,11 @@ class Fail2banClient(Fail2banCmdLine, Thread):
if ret[0] == 0:
logSys.debug("OK : " + `ret[1]`)
if showRet or c[0] == 'echo':
print beautifier.beautify(ret[1])
output(beautifier.beautify(ret[1]))
else:
logSys.error("NOK: " + `ret[1].args`)
if showRet:
print beautifier.beautifyError(ret[1])
output(beautifier.beautifyError(ret[1]))
streamRet = False
except socket.error:
if showRet or self._conf["verbose"] > 1:
@ -182,10 +190,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
# Start server direct here in main thread (not fork):
self._server = Fail2banServer.startServerDirect(self._conf, False)
except ExitException:
pass
except Exception as e:
print
logSys.error("Exception while starting server foreground")
output("")
logSys.error("Exception while starting server " + ("background" if background else "foreground"))
logSys.error(e)
return False
finally:
self._alive = False
@ -229,18 +240,18 @@ class Fail2banClient(Fail2banCmdLine, Thread):
elif len(cmd) == 1 and cmd[0] == "restart":
if self._conf.get("interactive", False):
print(' ## stop ... ')
output(' ## stop ... ')
self.__processCommand(['stop'])
self.__waitOnServer(False)
# in interactive mode reset config, to make full-reload if there something changed:
if self._conf.get("interactive", False):
print(' ## load configuration ... ')
output(' ## load configuration ... ')
self.resetConf()
ret = self.initCmdLine(self._argv)
if ret is not None:
return ret
if self._conf.get("interactive", False):
print(' ## start ... ')
output(' ## start ... ')
return self.__processCommand(['start'])
elif len(cmd) >= 1 and cmd[0] == "reload":
@ -283,7 +294,9 @@ class Fail2banClient(Fail2banCmdLine, Thread):
return False
return True
def __waitOnServer(self, alive=True, maxtime=30):
def __waitOnServer(self, alive=True, maxtime=None):
if maxtime is None:
maxtime = MAX_WAITTIME
# Wait for the server to start (the server has 30 seconds to answer ping)
starttime = time.time()
with VisualWait(self._conf["verbose"]) as vis:
@ -301,53 +314,59 @@ class Fail2banClient(Fail2banCmdLine, Thread):
def start(self, argv):
# Install signal handlers
signal.signal(signal.SIGTERM, self.__sigTERMhandler)
signal.signal(signal.SIGINT, self.__sigTERMhandler)
# Command line options
if self._argv is None:
ret = self.initCmdLine(argv)
if ret is not None:
return ret
# Commands
args = self._args
# Interactive mode
if self._conf.get("interactive", False):
try:
import readline
except ImportError:
logSys.error("Readline not available")
return False
try:
ret = True
if len(args) > 0:
ret = self.__processCommand(args)
if ret:
readline.parse_and_bind("tab: complete")
self.dispInteractive()
while True:
cmd = raw_input(self.PROMPT)
if cmd == "exit" or cmd == "quit":
# Exit
return True
if cmd == "help":
self.dispUsage()
elif not cmd == "":
try:
self.__processCommand(shlex.split(cmd))
except Exception, e:
logSys.error(e)
except (EOFError, KeyboardInterrupt):
print
return True
# Single command mode
else:
if len(args) < 1:
self.dispUsage()
return False
return self.__processCommand(args)
_prev_signals = {}
if _thread_name() == '_MainThread':
for s in (signal.SIGTERM, signal.SIGINT):
_prev_signals[s] = signal.getsignal(s)
signal.signal(s, self.__sigTERMhandler)
try:
# Command line options
if self._argv is None:
ret = self.initCmdLine(argv)
if ret is not None:
return ret
# Commands
args = self._args
# Interactive mode
if self._conf.get("interactive", False):
try:
import readline
except ImportError:
logSys.error("Readline not available")
return False
try:
ret = True
if len(args) > 0:
ret = self.__processCommand(args)
if ret:
readline.parse_and_bind("tab: complete")
self.dispInteractive()
while True:
cmd = raw_input(self.PROMPT)
if cmd == "exit" or cmd == "quit":
# Exit
return True
if cmd == "help":
self.dispUsage()
elif not cmd == "":
try:
self.__processCommand(shlex.split(cmd))
except Exception, e:
logSys.error(e)
except (EOFError, KeyboardInterrupt):
output("")
return True
# Single command mode
else:
if len(args) < 1:
self.dispUsage()
return False
return self.__processCommand(args)
finally:
for s, sh in _prev_signals.iteritems():
signal.signal(s, sh)
class ServerExecutionException(Exception):
@ -361,7 +380,8 @@ class ServerExecutionException(Exception):
class _VisualWait:
pos = 0
delta = 1
maxpos = 10
def __init__(self, maxpos=10):
self.maxpos = maxpos
def __enter__(self):
return self
def __exit__(self, *args):
@ -390,14 +410,14 @@ class _NotVisualWait:
def heartbeat(self):
pass
def VisualWait(verbose):
return _VisualWait() if verbose > 1 else _NotVisualWait()
def VisualWait(verbose, *args, **kwargs):
return _VisualWait(*args, **kwargs) if verbose > 1 else _NotVisualWait()
def exec_command_line(): # pragma: no cover - can't test main
def exec_command_line(argv):
client = Fail2banClient()
# Exit with correct return value
if client.start(sys.argv):
if client.start(argv):
exit(0)
else:
exit(-1)

@ -33,7 +33,12 @@ from ..helpers import getLogger
# Gets the instance of the logger.
logSys = getLogger("fail2ban")
def output(s):
print(s)
CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket",)
# Used to signal - we are in test cases (ex: prevents change logging params, log capturing, etc)
PRODUCTION = True
class Fail2banCmdLine():
@ -71,50 +76,50 @@ class Fail2banCmdLine():
self.__dict__[o] = obj.__dict__[o]
def dispVersion(self):
print "Fail2Ban v" + version
print
print "Copyright (c) 2004-2008 Cyril Jaquier, 2008- Fail2Ban Contributors"
print "Copyright of modifications held by their respective authors."
print "Licensed under the GNU General Public License v2 (GPL)."
print
print "Written by Cyril Jaquier <cyril.jaquier@fail2ban.org>."
print "Many contributions by Yaroslav O. Halchenko <debian@onerussian.com>."
output("Fail2Ban v" + version)
output("")
output("Copyright (c) 2004-2008 Cyril Jaquier, 2008- Fail2Ban Contributors")
output("Copyright of modifications held by their respective authors.")
output("Licensed under the GNU General Public License v2 (GPL).")
output("")
output("Written by Cyril Jaquier <cyril.jaquier@fail2ban.org>.")
output("Many contributions by Yaroslav O. Halchenko <debian@onerussian.com>.")
def dispUsage(self):
""" Prints Fail2Ban command line options and exits
"""
caller = os.path.basename(self._argv[0])
print "Usage: "+caller+" [OPTIONS]" + (" <COMMAND>" if not caller.endswith('server') else "")
print
print "Fail2Ban v" + version + " reads log file that contains password failure report"
print "and bans the corresponding IP addresses using firewall rules."
print
print "Options:"
print " -c <DIR> configuration directory"
print " -s <FILE> socket path"
print " -p <FILE> pidfile path"
print " --loglevel <LEVEL> logging level"
print " --logtarget <FILE>|STDOUT|STDERR|SYSLOG"
print " --syslogsocket auto|<FILE>"
print " -d dump configuration. For debugging"
print " -i interactive mode"
print " -v increase verbosity"
print " -q decrease verbosity"
print " -x force execution of the server (remove socket file)"
print " -b start server in background (default)"
print " -f start server in foreground"
print " --async start server in async mode (for internal usage only, don't read configuration)"
print " -h, --help display this help message"
print " -V, --version print the version"
output("Usage: "+caller+" [OPTIONS]" + (" <COMMAND>" if not caller.endswith('server') else ""))
output("")
output("Fail2Ban v" + version + " reads log file that contains password failure report")
output("and bans the corresponding IP addresses using firewall rules.")
output("")
output("Options:")
output(" -c <DIR> configuration directory")
output(" -s <FILE> socket path")
output(" -p <FILE> pidfile path")
output(" --loglevel <LEVEL> logging level")
output(" --logtarget <FILE>|STDOUT|STDERR|SYSLOG")
output(" --syslogsocket auto|<FILE>")
output(" -d dump configuration. For debugging")
output(" -i interactive mode")
output(" -v increase verbosity")
output(" -q decrease verbosity")
output(" -x force execution of the server (remove socket file)")
output(" -b start server in background (default)")
output(" -f start server in foreground")
output(" --async start server in async mode (for internal usage only, don't read configuration)")
output(" -h, --help display this help message")
output(" -V, --version print the version")
if not caller.endswith('server'):
print
print "Command:"
output("")
output("Command:")
# Prints the protocol
printFormatted()
print
print "Report bugs to https://github.com/fail2ban/fail2ban/issues"
output("")
output("Report bugs to https://github.com/fail2ban/fail2ban/issues")
def __getCmdLineOptions(self, optList):
""" Gets the command line options
@ -147,69 +152,78 @@ class Fail2banCmdLine():
self._conf["background"] = False
elif o in ["-h", "--help"]:
self.dispUsage()
exit(0)
return True
elif o in ["-V", "--version"]:
self.dispVersion()
exit(0)
return True
return None
def initCmdLine(self, argv):
# First time?
initial = (self._argv is None)
try:
# First time?
initial = (self._argv is None)
# Command line options
self._argv = argv
# Command line options
self._argv = argv
logSys.info("Using start params %s", argv[1:])
# Reads the command line options.
try:
cmdOpts = 'hc:s:p:xfbdviqV'
cmdLongOpts = ['loglevel=', 'logtarget=', 'syslogsocket=', 'async', 'help', 'version']
optList, self._args = getopt.getopt(self._argv[1:], cmdOpts, cmdLongOpts)
except getopt.GetoptError:
self.dispUsage()
exit(-1)
self.__getCmdLineOptions(optList)
if initial:
verbose = self._conf["verbose"]
if verbose <= 0:
logSys.setLevel(logging.ERROR)
elif verbose == 1:
logSys.setLevel(logging.WARNING)
elif verbose == 2:
logSys.setLevel(logging.INFO)
else:
logSys.setLevel(logging.DEBUG)
# Add the default logging handler to dump to stderr
logout = logging.StreamHandler(sys.stderr)
# set a format which is simpler for console use
formatter = logging.Formatter('%(levelname)-6s %(message)s')
# tell the handler to use this format
logout.setFormatter(formatter)
logSys.addHandler(logout)
# Set expected parameters (like socket, pidfile, etc) from configuration,
# if those not yet specified, in which read configuration only if needed here:
conf = None
for o in CONFIG_PARAMS:
if self._conf.get(o, None) is None:
if not conf:
self.configurator.readEarly()
conf = self.configurator.getEarlyOptions()
self._conf[o] = conf[o]
logSys.info("Using socket file %s", self._conf["socket"])
logSys.info("Using pid file %s, [%s] logging to %s",
self._conf["pidfile"], self._conf["loglevel"], self._conf["logtarget"])
if self._conf.get("dump", False):
ret, stream = self.readConfig()
self.dumpConfig(stream)
return ret
# Nothing to do here, process in client/server
return None
# Reads the command line options.
try:
cmdOpts = 'hc:s:p:xfbdviqV'
cmdLongOpts = ['loglevel=', 'logtarget=', 'syslogsocket=', 'async', 'help', 'version']
optList, self._args = getopt.getopt(self._argv[1:], cmdOpts, cmdLongOpts)
except getopt.GetoptError:
self.dispUsage()
return False
ret = self.__getCmdLineOptions(optList)
if ret is not None:
return ret
if initial and PRODUCTION: # pragma: no cover - can't test
verbose = self._conf["verbose"]
if verbose <= 0:
logSys.setLevel(logging.ERROR)
elif verbose == 1:
logSys.setLevel(logging.WARNING)
elif verbose == 2:
logSys.setLevel(logging.INFO)
else:
logSys.setLevel(logging.DEBUG)
# Add the default logging handler to dump to stderr
logout = logging.StreamHandler(sys.stderr)
# set a format which is simpler for console use
formatter = logging.Formatter('%(levelname)-6s %(message)s')
# tell the handler to use this format
logout.setFormatter(formatter)
logSys.addHandler(logout)
# Set expected parameters (like socket, pidfile, etc) from configuration,
# if those not yet specified, in which read configuration only if needed here:
conf = None
for o in CONFIG_PARAMS:
if self._conf.get(o, None) is None:
if not conf:
self.configurator.readEarly()
conf = self.configurator.getEarlyOptions()
self._conf[o] = conf[o]
logSys.info("Using socket file %s", self._conf["socket"])
logSys.info("Using pid file %s, [%s] logging to %s",
self._conf["pidfile"], self._conf["loglevel"], self._conf["logtarget"])
if self._conf.get("dump", False):
ret, stream = self.readConfig()
self.dumpConfig(stream)
return ret
# Nothing to do here, process in client/server
return None
except Exception as e:
output("ERROR: %s" % (e,))
#logSys.exception(e)
return False
def readConfig(self, jail=None):
# Read the configuration
@ -242,4 +256,8 @@ class Fail2banCmdLine():
sys.exit(code)
# global exit handler:
exit = Fail2banCmdLine.exit
exit = Fail2banCmdLine.exit
class ExitException:
pass

@ -29,7 +29,11 @@ from ..server.server import Server, ServerDaemonize
from ..server.utils import Utils
from .fail2bancmdline import Fail2banCmdLine, logSys, exit
MAX_WAITTIME = 30
SERVER = "fail2ban-server"
##
# \mainpage Fail2Ban
#
@ -72,8 +76,15 @@ class Fail2banServer(Fail2banCmdLine):
@staticmethod
def startServerAsync(conf):
# Forks the current process.
pid = os.fork()
# Directory of client (to try the first start from the same directory as client):
startdir = sys.path[0]
if startdir in ("", "."): # may be uresolved in test-cases, so get bin-directory:
startdir = os.path.dirname(sys.argv[0])
# Forks the current process, don't fork if async specified (ex: test cases)
pid = 0
frk = not conf["async"]
if frk:
pid = os.fork()
if pid == 0:
args = list()
args.append(SERVER)
@ -96,14 +107,20 @@ class Fail2banServer(Fail2banCmdLine):
try:
# Use the current directory.
exe = os.path.abspath(os.path.join(sys.path[0], SERVER))
exe = os.path.abspath(os.path.join(startdir, SERVER))
logSys.debug("Starting %r with args %r", exe, args)
os.execv(exe, args)
except OSError:
if frk:
os.execv(exe, args)
else:
os.spawnv(os.P_NOWAITO, exe, args)
except OSError as e:
try:
# Use the PATH env.
logSys.warning("Initial start attempt failed. Starting %r with the same args", SERVER)
os.execvp(SERVER, args)
logSys.warning("Initial start attempt failed (%s). Starting %r with the same args", e, SERVER)
if frk:
os.execvp(SERVER, args)
else:
os.spawnvp(os.P_NOWAITO, SERVER, args)
except OSError:
exit(-1)
@ -143,8 +160,8 @@ class Fail2banServer(Fail2banCmdLine):
phase = dict()
logSys.debug('Configure via async client thread')
cli.configureServer(async=True, phase=phase)
# wait up to 30 seconds, do not continue if configuration is not 100% valid:
Utils.wait_for(lambda: phase.get('ready', None) is not None, 30)
# wait up to MAX_WAITTIME, do not continue if configuration is not 100% valid:
Utils.wait_for(lambda: phase.get('ready', None) is not None, MAX_WAITTIME)
if not phase.get('start', False):
return False
@ -158,7 +175,7 @@ class Fail2banServer(Fail2banCmdLine):
# wait for client answer "done":
if not async and cli:
Utils.wait_for(lambda: phase.get('done', None) is not None, 30)
Utils.wait_for(lambda: phase.get('done', None) is not None, MAX_WAITTIME)
if not phase.get('done', False):
if server:
server.quit()
@ -179,9 +196,9 @@ class Fail2banServer(Fail2banCmdLine):
logSys.error("Could not start %s", SERVER)
exit(code)
def exec_command_line(): # pragma: no cover - can't test main
def exec_command_line(argv):
server = Fail2banServer()
if server.start(sys.argv):
if server.start(argv):
exit(0)
else:
exit(-1)

@ -26,6 +26,9 @@ __license__ = "GPL"
import textwrap
def output(s):
print(s)
##
# Describes the protocol used to communicate with the server.
@ -143,7 +146,7 @@ def printFormatted():
firstHeading = False
for m in protocol:
if m[0] == '' and firstHeading:
print
output("")
firstHeading = True
first = True
if len(m[0]) >= MARGIN:
@ -154,7 +157,7 @@ def printFormatted():
first = False
else:
line = ' ' * (INDENT + MARGIN) + n.strip()
print line
output(line)
##
@ -165,20 +168,20 @@ def printWiki():
for m in protocol:
if m[0] == '':
if firstHeading:
print "|}"
output("|}")
__printWikiHeader(m[1], m[2])
firstHeading = True
else:
print "|-"
print "| <span style=\"white-space:nowrap;\"><tt>" + m[0] + "</tt></span> || || " + m[1]
print "|}"
output("|-")
output("| <span style=\"white-space:nowrap;\"><tt>" + m[0] + "</tt></span> || || " + m[1])
output("|}")
def __printWikiHeader(section, desc):
print
print "=== " + section + " ==="
print
print desc
print
print "{|"
print "| '''Command''' || || '''Description'''"
output("")
output("=== " + section + " ===")
output("")
output(desc)
output("")
output("{|")
output("| '''Command''' || || '''Description'''")

@ -24,6 +24,7 @@ __author__ = "Cyril Jaquier"
__copyright__ = "Copyright (c) 2004 Cyril Jaquier"
__license__ = "GPL"
import threading
from threading import Lock, RLock
import logging
import logging.handlers
@ -42,6 +43,10 @@ from ..helpers import getLogger, excepthook
# Gets the instance of the logger.
logSys = getLogger(__name__)
DEF_SYSLOGSOCKET = "auto"
DEF_LOGLEVEL = "INFO"
DEF_LOGTARGET = "STDOUT"
try:
from .database import Fail2BanDb
except ImportError: # pragma: no cover
@ -49,6 +54,10 @@ except ImportError: # pragma: no cover
Fail2BanDb = None
def _thread_name():
return threading.current_thread().__class__.__name__
class Server:
def __init__(self, daemon = False):
@ -67,11 +76,7 @@ class Server:
'FreeBSD': '/var/run/log',
'Linux': '/dev/log',
}
# todo: remove that, if test cases are fixed
self.setSyslogSocket("auto")
# Set logging level
self.setLogLevel("INFO")
self.setLogTarget("STDOUT")
self.__prev_signals = {}
def __sigTERMhandler(self, signum, frame):
logSys.debug("Caught signal %d. Exiting" % signum)
@ -93,9 +98,12 @@ class Server:
raise ServerInitializationError("Could not create daemon")
# Set all logging parameters (or use default if not specified):
self.setSyslogSocket(conf.get("syslogsocket", self.__syslogSocket))
self.setLogLevel(conf.get("loglevel", self.__logLevel))
self.setLogTarget(conf.get("logtarget", self.__logTarget))
self.setSyslogSocket(conf.get("syslogsocket",
self.__syslogSocket if self.__syslogSocket is not None else DEF_SYSLOGSOCKET))
self.setLogLevel(conf.get("loglevel",
self.__logLevel if self.__logLevel is not None else DEF_LOGLEVEL))
self.setLogTarget(conf.get("logtarget",
self.__logTarget if self.__logTarget is not None else DEF_LOGTARGET))
logSys.info("-"*50)
logSys.info("Starting Fail2ban v%s", version.version)
@ -104,10 +112,10 @@ class Server:
logSys.info("Daemon started")
# Install signal handlers
signal.signal(signal.SIGTERM, self.__sigTERMhandler)
signal.signal(signal.SIGINT, self.__sigTERMhandler)
signal.signal(signal.SIGUSR1, self.__sigUSR1handler)
if _thread_name() == '_MainThread':
for s in (signal.SIGTERM, signal.SIGINT, signal.SIGUSR1):
self.__prev_signals[s] = signal.getsignal(s)
signal.signal(s, self.__sigTERMhandler if s != signal.SIGUSR1 else self.__sigUSR1handler)
# Ensure unhandled exceptions are logged
sys.excepthook = excepthook
@ -150,6 +158,10 @@ class Server:
with self.__loggingLock:
logging.shutdown()
# Restore default signal handlers:
for s, sh in self.__prev_signals.iteritems():
signal.signal(s, sh)
def addJail(self, name, backend):
self.__jails.add(name, backend, self.__db)
if self.__db is not None:
@ -395,8 +407,13 @@ class Server:
def setLogTarget(self, target):
with self.__loggingLock:
# don't set new handlers if already the same
# or if "INHERITED" (foreground worker of the test cases, to prevent stop logging):
if self.__logTarget == target:
return True
if target == "INHERITED":
self.__logTarget = target
return True
# set a format which is simpler for console use
formatter = logging.Formatter("%(asctime)s %(name)-24s[%(process)d]: %(levelname)-7s %(message)s")
if target == "SYSLOG":
@ -539,7 +556,10 @@ class Server:
# We need to set this in the parent process, so it gets inherited by the
# child process, and this makes sure that it is effect even if the parent
# terminates quickly.
signal.signal(signal.SIGHUP, signal.SIG_IGN)
if _thread_name() == '_MainThread':
for s in (signal.SIGHUP,):
self.__prev_signals[s] = signal.getsignal(s)
signal.signal(s, signal.SIG_IGN)
try:
# Fork a child process so the parent can exit. This will return control

@ -125,6 +125,7 @@ class Utils():
timeout_expr = lambda: time.time() - stime <= timeout
else:
timeout_expr = timeout
popen = None
try:
popen = subprocess.Popen(
realCmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell,
@ -151,7 +152,10 @@ class Utils():
if retcode is None and not Utils.pid_exists(pgid):
retcode = signal.SIGKILL
except OSError as e:
logSys.error("%s -- failed with %s" % (realCmd, e))
stderr = "%s -- failed with %s" % (realCmd, e)
logSys.error(stderr)
if not popen:
return False if not output else (False, stdout, stderr, retcode)
std_level = retcode == 0 and logging.DEBUG or logging.ERROR
# if we need output (to return or to log it):

@ -0,0 +1,411 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
# vi: set ft=python sts=4 ts=4 sw=4 noet :
# This file is part of Fail2Ban.
#
# Fail2Ban is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Fail2Ban is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Fail2Ban; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Fail2Ban developers
__author__ = "Serg Brester"
__copyright__ = "Copyright (c) 2014- Serg G. Brester (sebres), 2008- Fail2Ban Contributors"
__license__ = "GPL"
import fileinput
import os
import re
import time
import unittest
from threading import Thread
from ..client import fail2banclient, fail2banserver, fail2bancmdline
from ..client.fail2banclient import Fail2banClient, exec_command_line as _exec_client, VisualWait
from ..client.fail2banserver import Fail2banServer, exec_command_line as _exec_server
from .. import protocol
from ..server import server
from ..server.utils import Utils
from .utils import LogCaptureTestCase, logSys, withtmpdir, shutil, logging
STOCK_CONF_DIR = "config"
STOCK = os.path.exists(os.path.join(STOCK_CONF_DIR,'fail2ban.conf'))
TEST_CONF_DIR = os.path.join(os.path.dirname(__file__), "config")
if STOCK:
CONF_DIR = STOCK_CONF_DIR
else:
CONF_DIR = TEST_CONF_DIR
CLIENT = "fail2ban-client"
SERVER = "fail2ban-server"
BIN = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "bin")
MAX_WAITTIME = 10
MAX_WAITTIME = unittest.F2B.maxWaitTime(MAX_WAITTIME)
##
# Several wrappers and settings for proper testing:
#
fail2banclient.MAX_WAITTIME = \
fail2banserver.MAX_WAITTIME = MAX_WAITTIME
fail2bancmdline.logSys = \
fail2banclient.logSys = \
fail2banserver.logSys = logSys
LOG_LEVEL = logSys.level
server.DEF_LOGTARGET = "/dev/null"
def _test_output(*args):
logSys.info(args[0])
fail2bancmdline.output = \
fail2banclient.output = \
fail2banserver.output = \
protocol.output = _test_output
def _test_exit(code=0):
logSys.debug("Exit with code %s", code)
if code == 0:
raise ExitException()
else:
raise FailExitException()
fail2bancmdline.exit = \
fail2banclient.exit = \
fail2banserver.exit = _test_exit
INTERACT = []
def _test_raw_input(*args):
if len(INTERACT):
#print('--- interact command: ', INTERACT[0])
return INTERACT.pop(0)
else:
return "exit"
fail2banclient.raw_input = _test_raw_input
# prevents change logging params, log capturing, etc:
fail2bancmdline.PRODUCTION = False
class ExitException(fail2bancmdline.ExitException):
pass
class FailExitException(fail2bancmdline.ExitException):
pass
def _out_file(fn): # pragma: no cover
logSys.debug('---- ' + fn + ' ----')
for line in fileinput.input(fn):
line = line.rstrip('\n')
logSys.debug(line)
logSys.debug('-'*30)
def _start_params(tmp, use_stock=False, logtarget="/dev/null"):
cfg = tmp+"/config"
if use_stock and STOCK:
# copy config:
def ig_dirs(dir, files):
return [f for f in files if not os.path.isfile(os.path.join(dir, f))]
shutil.copytree(STOCK_CONF_DIR, cfg, ignore=ig_dirs)
os.symlink(STOCK_CONF_DIR+"/action.d", cfg+"/action.d")
os.symlink(STOCK_CONF_DIR+"/filter.d", cfg+"/filter.d")
# replace fail2ban params (database with memory):
r = re.compile(r'^dbfile\s*=')
for line in fileinput.input(cfg+"/fail2ban.conf", inplace=True):
line = line.rstrip('\n')
if r.match(line):
line = "dbfile = :memory:"
print(line)
# replace jail params (polling as backend to be fast in initialize):
r = re.compile(r'^backend\s*=')
for line in fileinput.input(cfg+"/jail.conf", inplace=True):
line = line.rstrip('\n')
if r.match(line):
line = "backend = polling"
print(line)
else:
# just empty config directory without anything (only fail2ban.conf/jail.conf):
os.mkdir(cfg)
f = open(cfg+"/fail2ban.conf", "wb")
f.write('\n'.join((
"[Definition]",
"loglevel = INFO",
"logtarget = " + logtarget,
"syslogsocket = auto",
"socket = "+tmp+"/f2b.sock",
"pidfile = "+tmp+"/f2b.pid",
"backend = polling",
"dbfile = :memory:",
"dbpurgeage = 1d",
"",
)))
f.close()
f = open(cfg+"/jail.conf", "wb")
f.write('\n'.join((
"[INCLUDES]", "",
"[DEFAULT]", "",
"",
)))
f.close()
if LOG_LEVEL < logging.DEBUG: # if HEAVYDEBUG
_out_file(cfg+"/fail2ban.conf")
_out_file(cfg+"/jail.conf")
# parameters:
return ("-c", cfg,
"--logtarget", logtarget, "--loglevel", "DEBUG", "--syslogsocket", "auto",
"-s", tmp+"/f2b.sock", "-p", tmp+"/f2b.pid")
class Fail2banClientTest(LogCaptureTestCase):
def setUp(self):
"""Call before every test case."""
LogCaptureTestCase.setUp(self)
def tearDown(self):
"""Call after every test case."""
LogCaptureTestCase.tearDown(self)
def testClientUsage(self):
self.assertRaises(ExitException, _exec_client,
(CLIENT, "-h",))
self.assertLogged("Usage: " + CLIENT)
self.assertLogged("Report bugs to ")
@withtmpdir
def testClientStartBackgroundInside(self, tmp):
# always add "--async" by start inside, should don't fork by async (not replace client with server, just start in new process)
# (we can't fork the test cases process):
startparams = _start_params(tmp, True)
# start:
self.assertRaises(ExitException, _exec_client,
(CLIENT, "--async", "-b") + startparams + ("start",))
self.assertLogged("Server ready")
self.assertLogged("Exit with code 0")
try:
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("echo", "TEST-ECHO",))
self.assertRaises(FailExitException, _exec_client,
(CLIENT,) + startparams + ("~~unknown~cmd~failed~~",))
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("stop",))
self.assertLogged("Shutdown successful")
self.assertLogged("Exit with code 0")
@withtmpdir
def testClientStartBackgroundCall(self, tmp):
global INTERACT
startparams = _start_params(tmp)
# start (without async in new process):
cmd = os.path.join(os.path.join(BIN), CLIENT)
logSys.debug('Start %s ...', cmd)
Utils.executeCmd((cmd,) + startparams + ("start",),
timeout=MAX_WAITTIME, shell=False, output=False)
self.pruneLog()
try:
# echo from client (inside):
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("echo", "TEST-ECHO",))
self.assertLogged("TEST-ECHO")
self.assertLogged("Exit with code 0")
self.pruneLog()
# interactive client chat with started server:
INTERACT += [
"echo INTERACT-ECHO",
"status",
"exit"
]
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("-i",))
self.assertLogged("INTERACT-ECHO")
self.assertLogged("Status", "Number of jail:")
self.assertLogged("Exit with code 0")
self.pruneLog()
# test reload and restart over interactive client:
INTERACT += [
"reload",
"restart",
"exit"
]
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("-i",))
self.assertLogged("Reading config files:")
self.assertLogged("Shutdown successful")
self.assertLogged("Server ready")
self.assertLogged("Exit with code 0")
self.pruneLog()
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("stop",))
self.assertLogged("Shutdown successful")
self.assertLogged("Exit with code 0")
def _testClientStartForeground(self, tmp, startparams, phase):
# start and wait to end (foreground):
phase['start'] = True
self.assertRaises(ExitException, _exec_client,
(CLIENT, "-f") + startparams + ("start",))
# end :
phase['end'] = True
@withtmpdir
def testClientStartForeground(self, tmp):
# started directly here, so prevent overwrite test cases logger with "INHERITED"
startparams = _start_params(tmp, logtarget="INHERITED")
# because foreground block execution - start it in thread:
phase = dict()
Thread(name="_TestCaseWorker",
target=Fail2banClientTest._testClientStartForeground, args=(self, tmp, startparams, phase)).start()
try:
# wait for start thread:
Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME)
self.assertTrue(phase.get('start', None))
# wait for server (socket):
Utils.wait_for(lambda: os.path.exists(tmp+"/f2b.sock"), MAX_WAITTIME)
self.assertLogged("Starting communication")
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("ping",))
self.assertRaises(FailExitException, _exec_client,
(CLIENT,) + startparams + ("~~unknown~cmd~failed~~",))
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("echo", "TEST-ECHO",))
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("stop",))
# wait for end:
Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME)
self.assertTrue(phase.get('end', None))
self.assertLogged("Shutdown successful", "Exiting Fail2ban")
self.assertLogged("Exit with code 0")
@withtmpdir
def testClientFailStart(self, tmp):
self.assertRaises(FailExitException, _exec_client,
(CLIENT, "--async", "-c", tmp+"/miss", "start",))
self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist")
self.assertRaises(FailExitException, _exec_client,
(CLIENT, "--async", "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock", "start",))
self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file")
def testVisualWait(self):
sleeptime = 0.035
for verbose in (2, 0):
cntr = 15
with VisualWait(verbose, 5) as vis:
while cntr:
vis.heartbeat()
if verbose and not unittest.F2B.fast:
time.sleep(sleeptime)
cntr -= 1
class Fail2banServerTest(LogCaptureTestCase):
def setUp(self):
"""Call before every test case."""
LogCaptureTestCase.setUp(self)
def tearDown(self):
"""Call after every test case."""
LogCaptureTestCase.tearDown(self)
def testServerUsage(self):
self.assertRaises(ExitException, _exec_server,
(SERVER, "-h",))
self.assertLogged("Usage: " + SERVER)
self.assertLogged("Report bugs to ")
@withtmpdir
def testServerStartBackground(self, tmp):
# don't add "--async" by start, because if will fork current process by daemonize
# (we can't fork the test cases process),
# because server started internal communication in new thread use INHERITED as logtarget here:
startparams = _start_params(tmp, logtarget="INHERITED")
# start:
self.assertRaises(ExitException, _exec_server,
(SERVER, "-b") + startparams)
self.assertLogged("Server ready")
self.assertLogged("Exit with code 0")
try:
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("echo", "TEST-ECHO",))
self.assertRaises(FailExitException, _exec_server,
(SERVER,) + startparams + ("~~unknown~cmd~failed~~",))
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("stop",))
self.assertLogged("Shutdown successful")
self.assertLogged("Exit with code 0")
def _testServerStartForeground(self, tmp, startparams, phase):
# start and wait to end (foreground):
phase['start'] = True
self.assertRaises(ExitException, _exec_server,
(SERVER, "-f") + startparams + ("start",))
# end :
phase['end'] = True
@withtmpdir
def testServerStartForeground(self, tmp):
# started directly here, so prevent overwrite test cases logger with "INHERITED"
startparams = _start_params(tmp, logtarget="INHERITED")
# because foreground block execution - start it in thread:
phase = dict()
Thread(name="_TestCaseWorker",
target=Fail2banServerTest._testServerStartForeground, args=(self, tmp, startparams, phase)).start()
try:
# wait for start thread:
Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME)
self.assertTrue(phase.get('start', None))
# wait for server (socket):
Utils.wait_for(lambda: os.path.exists(tmp+"/f2b.sock"), MAX_WAITTIME)
self.assertLogged("Starting communication")
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("ping",))
self.assertRaises(FailExitException, _exec_server,
(SERVER,) + startparams + ("~~unknown~cmd~failed~~",))
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("echo", "TEST-ECHO",))
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("stop",))
# wait for end:
Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME)
self.assertTrue(phase.get('end', None))
self.assertLogged("Shutdown successful", "Exiting Fail2ban")
self.assertLogged("Exit with code 0")
@withtmpdir
def testServerFailStart(self, tmp):
self.assertRaises(FailExitException, _exec_server,
(SERVER, "-c", tmp+"/miss",))
self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist")
self.assertRaises(FailExitException, _exec_server,
(SERVER, "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock",))
self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file")

@ -23,19 +23,7 @@ __author__ = "Serg Brester"
__copyright__ = "Copyright (c) 2015 Serg G. Brester (sebres), 2008- Fail2Ban Contributors"
__license__ = "GPL"
from __builtin__ import open as fopen
import unittest
import getpass
import os
import sys
import time
import tempfile
import uuid
try:
from systemd import journal
except ImportError:
journal = None
from ..client import fail2banregex
from ..client.fail2banregex import Fail2banRegex, get_opt_parser, output

@ -26,10 +26,14 @@ import logging
import optparse
import os
import re
import tempfile
import shutil
import sys
import time
import unittest
from StringIO import StringIO
from functools import wraps
from ..helpers import getLogger
from ..server.filter import DNSUtils
@ -71,6 +75,17 @@ class F2B(optparse.Values):
return wtime
def withtmpdir(f):
@wraps(f)
def wrapper(self, *args, **kwargs):
tmp = tempfile.mkdtemp(prefix="f2b-temp")
try:
return f(self, tmp, *args, **kwargs)
finally:
# clean up
shutil.rmtree(tmp)
return wrapper
def initTests(opts):
unittest.F2B = F2B(opts)
# --fast :
@ -145,6 +160,7 @@ def gatherTests(regexps=None, opts=None):
from . import misctestcase
from . import databasetestcase
from . import samplestestcase
from . import fail2banclienttestcase
from . import fail2banregextestcase
if not regexps: # pragma: no cover
@ -223,6 +239,9 @@ def gatherTests(regexps=None, opts=None):
# Filter Regex tests with sample logs
tests.addTest(unittest.makeSuite(samplestestcase.FilterSamplesRegex))
# bin/fail2ban-client, bin/fail2ban-server
tests.addTest(unittest.makeSuite(fail2banclienttestcase.Fail2banClientTest))
tests.addTest(unittest.makeSuite(fail2banclienttestcase.Fail2banServerTest))
# bin/fail2ban-regex
tests.addTest(unittest.makeSuite(fail2banregextestcase.Fail2banRegexTest))
@ -293,8 +312,11 @@ class LogCaptureTestCase(unittest.TestCase):
# Let's log everything into a string
self._log = StringIO()
logSys.handlers = [logging.StreamHandler(self._log)]
if self._old_level <= logging.DEBUG:
print("")
if self._old_level < logging.DEBUG: # so if HEAVYDEBUG etc -- show them!
logSys.handlers += self._old_handlers
logSys.debug('--'*40)
logSys.setLevel(getattr(logging, 'DEBUG'))
def tearDown(self):
@ -340,6 +362,9 @@ class LogCaptureTestCase(unittest.TestCase):
raise AssertionError("All of the %r were found present in the log: %r" % (s, logged))
def pruneLog(self):
self._log.truncate(0)
def getLog(self):
return self._log.getvalue()

Loading…
Cancel
Save