client/server (bin) test cases introduced, ultimate closes #1121, closes #1139

small code review and fixing of some bugs during client-server communication process (in the test cases);
pull/1483/head
sebres 2016-02-11 17:57:23 +01:00
parent 5a053f4b74
commit afa1cdc3ae
10 changed files with 707 additions and 205 deletions

View File

@ -31,7 +31,7 @@ __author__ = "Fail2Ban Developers"
__copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester" __copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester"
__license__ = "GPL" __license__ = "GPL"
from fail2ban.client.fail2banclient import exec_command_line from fail2ban.client.fail2banclient import exec_command_line, sys
if __name__ == "__main__": if __name__ == "__main__":
exec_command_line() exec_command_line(sys.argv)

View File

@ -31,7 +31,7 @@ __author__ = "Fail2Ban Developers"
__copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester" __copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester"
__license__ = "GPL" __license__ = "GPL"
from fail2ban.client.fail2banserver import exec_command_line from fail2ban.client.fail2banserver import exec_command_line, sys
if __name__ == "__main__": if __name__ == "__main__":
exec_command_line() exec_command_line(sys.argv)

View File

@ -28,12 +28,20 @@ import socket
import sys import sys
import time import time
import threading
from threading import Thread from threading import Thread
from ..version import version from ..version import version
from .csocket import CSocket from .csocket import CSocket
from .beautifier import Beautifier from .beautifier import Beautifier
from .fail2bancmdline import Fail2banCmdLine, logSys, exit from .fail2bancmdline import Fail2banCmdLine, ExitException, logSys, exit, output
MAX_WAITTIME = 30
def _thread_name():
return threading.current_thread().__class__.__name__
## ##
# #
@ -51,13 +59,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
self._beautifier = None self._beautifier = None
def dispInteractive(self): def dispInteractive(self):
print "Fail2Ban v" + version + " reads log file that contains password failure report" output("Fail2Ban v" + version + " reads log file that contains password failure report")
print "and bans the corresponding IP addresses using firewall rules." output("and bans the corresponding IP addresses using firewall rules.")
print output("")
def __sigTERMhandler(self, signum, frame): def __sigTERMhandler(self, signum, frame):
# Print a new line because we probably come from wait # Print a new line because we probably come from wait
print output("")
logSys.warning("Caught signal %d. Exiting" % signum) logSys.warning("Caught signal %d. Exiting" % signum)
exit(-1) exit(-1)
@ -85,11 +93,11 @@ class Fail2banClient(Fail2banCmdLine, Thread):
if ret[0] == 0: if ret[0] == 0:
logSys.debug("OK : " + `ret[1]`) logSys.debug("OK : " + `ret[1]`)
if showRet or c[0] == 'echo': if showRet or c[0] == 'echo':
print beautifier.beautify(ret[1]) output(beautifier.beautify(ret[1]))
else: else:
logSys.error("NOK: " + `ret[1].args`) logSys.error("NOK: " + `ret[1].args`)
if showRet: if showRet:
print beautifier.beautifyError(ret[1]) output(beautifier.beautifyError(ret[1]))
streamRet = False streamRet = False
except socket.error: except socket.error:
if showRet or self._conf["verbose"] > 1: if showRet or self._conf["verbose"] > 1:
@ -182,10 +190,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
# Start server direct here in main thread (not fork): # Start server direct here in main thread (not fork):
self._server = Fail2banServer.startServerDirect(self._conf, False) self._server = Fail2banServer.startServerDirect(self._conf, False)
except ExitException:
pass
except Exception as e: except Exception as e:
print output("")
logSys.error("Exception while starting server foreground") logSys.error("Exception while starting server " + ("background" if background else "foreground"))
logSys.error(e) logSys.error(e)
return False
finally: finally:
self._alive = False self._alive = False
@ -229,18 +240,18 @@ class Fail2banClient(Fail2banCmdLine, Thread):
elif len(cmd) == 1 and cmd[0] == "restart": elif len(cmd) == 1 and cmd[0] == "restart":
if self._conf.get("interactive", False): if self._conf.get("interactive", False):
print(' ## stop ... ') output(' ## stop ... ')
self.__processCommand(['stop']) self.__processCommand(['stop'])
self.__waitOnServer(False) self.__waitOnServer(False)
# in interactive mode reset config, to make full-reload if there something changed: # in interactive mode reset config, to make full-reload if there something changed:
if self._conf.get("interactive", False): if self._conf.get("interactive", False):
print(' ## load configuration ... ') output(' ## load configuration ... ')
self.resetConf() self.resetConf()
ret = self.initCmdLine(self._argv) ret = self.initCmdLine(self._argv)
if ret is not None: if ret is not None:
return ret return ret
if self._conf.get("interactive", False): if self._conf.get("interactive", False):
print(' ## start ... ') output(' ## start ... ')
return self.__processCommand(['start']) return self.__processCommand(['start'])
elif len(cmd) >= 1 and cmd[0] == "reload": elif len(cmd) >= 1 and cmd[0] == "reload":
@ -283,7 +294,9 @@ class Fail2banClient(Fail2banCmdLine, Thread):
return False return False
return True return True
def __waitOnServer(self, alive=True, maxtime=30): def __waitOnServer(self, alive=True, maxtime=None):
if maxtime is None:
maxtime = MAX_WAITTIME
# Wait for the server to start (the server has 30 seconds to answer ping) # Wait for the server to start (the server has 30 seconds to answer ping)
starttime = time.time() starttime = time.time()
with VisualWait(self._conf["verbose"]) as vis: with VisualWait(self._conf["verbose"]) as vis:
@ -301,53 +314,59 @@ class Fail2banClient(Fail2banCmdLine, Thread):
def start(self, argv): def start(self, argv):
# Install signal handlers # Install signal handlers
signal.signal(signal.SIGTERM, self.__sigTERMhandler) _prev_signals = {}
signal.signal(signal.SIGINT, self.__sigTERMhandler) if _thread_name() == '_MainThread':
for s in (signal.SIGTERM, signal.SIGINT):
_prev_signals[s] = signal.getsignal(s)
signal.signal(s, self.__sigTERMhandler)
try:
# Command line options
if self._argv is None:
ret = self.initCmdLine(argv)
if ret is not None:
return ret
# Command line options # Commands
if self._argv is None: args = self._args
ret = self.initCmdLine(argv)
if ret is not None:
return ret
# Commands # Interactive mode
args = self._args if self._conf.get("interactive", False):
try:
# Interactive mode import readline
if self._conf.get("interactive", False): except ImportError:
try: logSys.error("Readline not available")
import readline return False
except ImportError: try:
logSys.error("Readline not available") ret = True
return False if len(args) > 0:
try: ret = self.__processCommand(args)
ret = True if ret:
if len(args) > 0: readline.parse_and_bind("tab: complete")
ret = self.__processCommand(args) self.dispInteractive()
if ret: while True:
readline.parse_and_bind("tab: complete") cmd = raw_input(self.PROMPT)
self.dispInteractive() if cmd == "exit" or cmd == "quit":
while True: # Exit
cmd = raw_input(self.PROMPT) return True
if cmd == "exit" or cmd == "quit": if cmd == "help":
# Exit self.dispUsage()
return True elif not cmd == "":
if cmd == "help": try:
self.dispUsage() self.__processCommand(shlex.split(cmd))
elif not cmd == "": except Exception, e:
try: logSys.error(e)
self.__processCommand(shlex.split(cmd)) except (EOFError, KeyboardInterrupt):
except Exception, e: output("")
logSys.error(e) return True
except (EOFError, KeyboardInterrupt): # Single command mode
print else:
return True if len(args) < 1:
# Single command mode self.dispUsage()
else: return False
if len(args) < 1: return self.__processCommand(args)
self.dispUsage() finally:
return False for s, sh in _prev_signals.iteritems():
return self.__processCommand(args) signal.signal(s, sh)
class ServerExecutionException(Exception): class ServerExecutionException(Exception):
@ -361,7 +380,8 @@ class ServerExecutionException(Exception):
class _VisualWait: class _VisualWait:
pos = 0 pos = 0
delta = 1 delta = 1
maxpos = 10 def __init__(self, maxpos=10):
self.maxpos = maxpos
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, *args): def __exit__(self, *args):
@ -390,14 +410,14 @@ class _NotVisualWait:
def heartbeat(self): def heartbeat(self):
pass pass
def VisualWait(verbose): def VisualWait(verbose, *args, **kwargs):
return _VisualWait() if verbose > 1 else _NotVisualWait() return _VisualWait(*args, **kwargs) if verbose > 1 else _NotVisualWait()
def exec_command_line(): # pragma: no cover - can't test main def exec_command_line(argv):
client = Fail2banClient() client = Fail2banClient()
# Exit with correct return value # Exit with correct return value
if client.start(sys.argv): if client.start(argv):
exit(0) exit(0)
else: else:
exit(-1) exit(-1)

View File

@ -33,7 +33,12 @@ from ..helpers import getLogger
# Gets the instance of the logger. # Gets the instance of the logger.
logSys = getLogger("fail2ban") logSys = getLogger("fail2ban")
def output(s):
print(s)
CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket",) CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket",)
# Used to signal - we are in test cases (ex: prevents change logging params, log capturing, etc)
PRODUCTION = True
class Fail2banCmdLine(): class Fail2banCmdLine():
@ -71,50 +76,50 @@ class Fail2banCmdLine():
self.__dict__[o] = obj.__dict__[o] self.__dict__[o] = obj.__dict__[o]
def dispVersion(self): def dispVersion(self):
print "Fail2Ban v" + version output("Fail2Ban v" + version)
print output("")
print "Copyright (c) 2004-2008 Cyril Jaquier, 2008- Fail2Ban Contributors" output("Copyright (c) 2004-2008 Cyril Jaquier, 2008- Fail2Ban Contributors")
print "Copyright of modifications held by their respective authors." output("Copyright of modifications held by their respective authors.")
print "Licensed under the GNU General Public License v2 (GPL)." output("Licensed under the GNU General Public License v2 (GPL).")
print output("")
print "Written by Cyril Jaquier <cyril.jaquier@fail2ban.org>." output("Written by Cyril Jaquier <cyril.jaquier@fail2ban.org>.")
print "Many contributions by Yaroslav O. Halchenko <debian@onerussian.com>." output("Many contributions by Yaroslav O. Halchenko <debian@onerussian.com>.")
def dispUsage(self): def dispUsage(self):
""" Prints Fail2Ban command line options and exits """ Prints Fail2Ban command line options and exits
""" """
caller = os.path.basename(self._argv[0]) caller = os.path.basename(self._argv[0])
print "Usage: "+caller+" [OPTIONS]" + (" <COMMAND>" if not caller.endswith('server') else "") output("Usage: "+caller+" [OPTIONS]" + (" <COMMAND>" if not caller.endswith('server') else ""))
print output("")
print "Fail2Ban v" + version + " reads log file that contains password failure report" output("Fail2Ban v" + version + " reads log file that contains password failure report")
print "and bans the corresponding IP addresses using firewall rules." output("and bans the corresponding IP addresses using firewall rules.")
print output("")
print "Options:" output("Options:")
print " -c <DIR> configuration directory" output(" -c <DIR> configuration directory")
print " -s <FILE> socket path" output(" -s <FILE> socket path")
print " -p <FILE> pidfile path" output(" -p <FILE> pidfile path")
print " --loglevel <LEVEL> logging level" output(" --loglevel <LEVEL> logging level")
print " --logtarget <FILE>|STDOUT|STDERR|SYSLOG" output(" --logtarget <FILE>|STDOUT|STDERR|SYSLOG")
print " --syslogsocket auto|<FILE>" output(" --syslogsocket auto|<FILE>")
print " -d dump configuration. For debugging" output(" -d dump configuration. For debugging")
print " -i interactive mode" output(" -i interactive mode")
print " -v increase verbosity" output(" -v increase verbosity")
print " -q decrease verbosity" output(" -q decrease verbosity")
print " -x force execution of the server (remove socket file)" output(" -x force execution of the server (remove socket file)")
print " -b start server in background (default)" output(" -b start server in background (default)")
print " -f start server in foreground" output(" -f start server in foreground")
print " --async start server in async mode (for internal usage only, don't read configuration)" output(" --async start server in async mode (for internal usage only, don't read configuration)")
print " -h, --help display this help message" output(" -h, --help display this help message")
print " -V, --version print the version" output(" -V, --version print the version")
if not caller.endswith('server'): if not caller.endswith('server'):
print output("")
print "Command:" output("Command:")
# Prints the protocol # Prints the protocol
printFormatted() printFormatted()
print output("")
print "Report bugs to https://github.com/fail2ban/fail2ban/issues" output("Report bugs to https://github.com/fail2ban/fail2ban/issues")
def __getCmdLineOptions(self, optList): def __getCmdLineOptions(self, optList):
""" Gets the command line options """ Gets the command line options
@ -147,71 +152,80 @@ class Fail2banCmdLine():
self._conf["background"] = False self._conf["background"] = False
elif o in ["-h", "--help"]: elif o in ["-h", "--help"]:
self.dispUsage() self.dispUsage()
exit(0) return True
elif o in ["-V", "--version"]: elif o in ["-V", "--version"]:
self.dispVersion() self.dispVersion()
exit(0) return True
return None
def initCmdLine(self, argv): def initCmdLine(self, argv):
# First time?
initial = (self._argv is None)
# Command line options
self._argv = argv
# Reads the command line options.
try: try:
cmdOpts = 'hc:s:p:xfbdviqV' # First time?
cmdLongOpts = ['loglevel=', 'logtarget=', 'syslogsocket=', 'async', 'help', 'version'] initial = (self._argv is None)
optList, self._args = getopt.getopt(self._argv[1:], cmdOpts, cmdLongOpts)
except getopt.GetoptError:
self.dispUsage()
exit(-1)
self.__getCmdLineOptions(optList) # Command line options
self._argv = argv
logSys.info("Using start params %s", argv[1:])
if initial: # Reads the command line options.
verbose = self._conf["verbose"] try:
if verbose <= 0: cmdOpts = 'hc:s:p:xfbdviqV'
logSys.setLevel(logging.ERROR) cmdLongOpts = ['loglevel=', 'logtarget=', 'syslogsocket=', 'async', 'help', 'version']
elif verbose == 1: optList, self._args = getopt.getopt(self._argv[1:], cmdOpts, cmdLongOpts)
logSys.setLevel(logging.WARNING) except getopt.GetoptError:
elif verbose == 2: self.dispUsage()
logSys.setLevel(logging.INFO) return False
elif verbose == 3:
logSys.setLevel(logging.DEBUG)
else:
logSys.setLevel(logging.HEAVYDEBUG)
# Add the default logging handler to dump to stderr
logout = logging.StreamHandler(sys.stderr)
# set a format which is simpler for console use
formatter = logging.Formatter('%(levelname)-6s %(message)s')
# tell the handler to use this format
logout.setFormatter(formatter)
logSys.addHandler(logout)
# Set expected parameters (like socket, pidfile, etc) from configuration, ret = self.__getCmdLineOptions(optList)
# if those not yet specified, in which read configuration only if needed here: if ret is not None:
conf = None return ret
for o in CONFIG_PARAMS:
if self._conf.get(o, None) is None:
if not conf:
self.configurator.readEarly()
conf = self.configurator.getEarlyOptions()
self._conf[o] = conf[o]
logSys.info("Using socket file %s", self._conf["socket"]) if initial and PRODUCTION: # pragma: no cover - can't test
verbose = self._conf["verbose"]
if verbose <= 0:
logSys.setLevel(logging.ERROR)
elif verbose == 1:
logSys.setLevel(logging.WARNING)
elif verbose == 2:
logSys.setLevel(logging.INFO)
elif verbose == 3:
logSys.setLevel(logging.DEBUG)
else:
logSys.setLevel(logging.HEAVYDEBUG)
# Add the default logging handler to dump to stderr
logout = logging.StreamHandler(sys.stderr)
# set a format which is simpler for console use
formatter = logging.Formatter('%(levelname)-6s %(message)s')
# tell the handler to use this format
logout.setFormatter(formatter)
logSys.addHandler(logout)
logSys.info("Using pid file %s, [%s] logging to %s", # Set expected parameters (like socket, pidfile, etc) from configuration,
self._conf["pidfile"], self._conf["loglevel"], self._conf["logtarget"]) # if those not yet specified, in which read configuration only if needed here:
conf = None
for o in CONFIG_PARAMS:
if self._conf.get(o, None) is None:
if not conf:
self.configurator.readEarly()
conf = self.configurator.getEarlyOptions()
self._conf[o] = conf[o]
if self._conf.get("dump", False): logSys.info("Using socket file %s", self._conf["socket"])
ret, stream = self.readConfig()
self.dumpConfig(stream)
return ret
# Nothing to do here, process in client/server logSys.info("Using pid file %s, [%s] logging to %s",
return None self._conf["pidfile"], self._conf["loglevel"], self._conf["logtarget"])
if self._conf.get("dump", False):
ret, stream = self.readConfig()
self.dumpConfig(stream)
return ret
# Nothing to do here, process in client/server
return None
except Exception as e:
output("ERROR: %s" % (e,))
#logSys.exception(e)
return False
def readConfig(self, jail=None): def readConfig(self, jail=None):
# Read the configuration # Read the configuration
@ -244,4 +258,8 @@ class Fail2banCmdLine():
sys.exit(code) sys.exit(code)
# global exit handler: # global exit handler:
exit = Fail2banCmdLine.exit exit = Fail2banCmdLine.exit
class ExitException:
pass

View File

@ -29,7 +29,11 @@ from ..server.server import Server, ServerDaemonize
from ..server.utils import Utils from ..server.utils import Utils
from .fail2bancmdline import Fail2banCmdLine, logSys, exit from .fail2bancmdline import Fail2banCmdLine, logSys, exit
MAX_WAITTIME = 30
SERVER = "fail2ban-server" SERVER = "fail2ban-server"
## ##
# \mainpage Fail2Ban # \mainpage Fail2Ban
# #
@ -72,8 +76,15 @@ class Fail2banServer(Fail2banCmdLine):
@staticmethod @staticmethod
def startServerAsync(conf): def startServerAsync(conf):
# Forks the current process. # Directory of client (to try the first start from the same directory as client):
pid = os.fork() startdir = sys.path[0]
if startdir in ("", "."): # may be uresolved in test-cases, so get bin-directory:
startdir = os.path.dirname(sys.argv[0])
# Forks the current process, don't fork if async specified (ex: test cases)
pid = 0
frk = not conf["async"]
if frk:
pid = os.fork()
if pid == 0: if pid == 0:
args = list() args = list()
args.append(SERVER) args.append(SERVER)
@ -96,14 +107,20 @@ class Fail2banServer(Fail2banCmdLine):
try: try:
# Use the current directory. # Use the current directory.
exe = os.path.abspath(os.path.join(sys.path[0], SERVER)) exe = os.path.abspath(os.path.join(startdir, SERVER))
logSys.debug("Starting %r with args %r", exe, args) logSys.debug("Starting %r with args %r", exe, args)
os.execv(exe, args) if frk:
except OSError: os.execv(exe, args)
else:
os.spawnv(os.P_NOWAITO, exe, args)
except OSError as e:
try: try:
# Use the PATH env. # Use the PATH env.
logSys.warning("Initial start attempt failed. Starting %r with the same args", SERVER) logSys.warning("Initial start attempt failed (%s). Starting %r with the same args", e, SERVER)
os.execvp(SERVER, args) if frk:
os.execvp(SERVER, args)
else:
os.spawnvp(os.P_NOWAITO, SERVER, args)
except OSError: except OSError:
exit(-1) exit(-1)
@ -143,8 +160,8 @@ class Fail2banServer(Fail2banCmdLine):
phase = dict() phase = dict()
logSys.debug('Configure via async client thread') logSys.debug('Configure via async client thread')
cli.configureServer(async=True, phase=phase) cli.configureServer(async=True, phase=phase)
# wait up to 30 seconds, do not continue if configuration is not 100% valid: # wait up to MAX_WAITTIME, do not continue if configuration is not 100% valid:
Utils.wait_for(lambda: phase.get('ready', None) is not None, 30) Utils.wait_for(lambda: phase.get('ready', None) is not None, MAX_WAITTIME)
if not phase.get('start', False): if not phase.get('start', False):
return False return False
@ -158,7 +175,7 @@ class Fail2banServer(Fail2banCmdLine):
# wait for client answer "done": # wait for client answer "done":
if not async and cli: if not async and cli:
Utils.wait_for(lambda: phase.get('done', None) is not None, 30) Utils.wait_for(lambda: phase.get('done', None) is not None, MAX_WAITTIME)
if not phase.get('done', False): if not phase.get('done', False):
if server: if server:
server.quit() server.quit()
@ -179,9 +196,9 @@ class Fail2banServer(Fail2banCmdLine):
logSys.error("Could not start %s", SERVER) logSys.error("Could not start %s", SERVER)
exit(code) exit(code)
def exec_command_line(): # pragma: no cover - can't test main def exec_command_line(argv):
server = Fail2banServer() server = Fail2banServer()
if server.start(sys.argv): if server.start(argv):
exit(0) exit(0)
else: else:
exit(-1) exit(-1)

View File

@ -26,6 +26,9 @@ __license__ = "GPL"
import textwrap import textwrap
def output(s):
print(s)
## ##
# Describes the protocol used to communicate with the server. # Describes the protocol used to communicate with the server.
@ -143,7 +146,7 @@ def printFormatted():
firstHeading = False firstHeading = False
for m in protocol: for m in protocol:
if m[0] == '' and firstHeading: if m[0] == '' and firstHeading:
print output("")
firstHeading = True firstHeading = True
first = True first = True
if len(m[0]) >= MARGIN: if len(m[0]) >= MARGIN:
@ -154,7 +157,7 @@ def printFormatted():
first = False first = False
else: else:
line = ' ' * (INDENT + MARGIN) + n.strip() line = ' ' * (INDENT + MARGIN) + n.strip()
print line output(line)
## ##
@ -165,20 +168,20 @@ def printWiki():
for m in protocol: for m in protocol:
if m[0] == '': if m[0] == '':
if firstHeading: if firstHeading:
print "|}" output("|}")
__printWikiHeader(m[1], m[2]) __printWikiHeader(m[1], m[2])
firstHeading = True firstHeading = True
else: else:
print "|-" output("|-")
print "| <span style=\"white-space:nowrap;\"><tt>" + m[0] + "</tt></span> || || " + m[1] output("| <span style=\"white-space:nowrap;\"><tt>" + m[0] + "</tt></span> || || " + m[1])
print "|}" output("|}")
def __printWikiHeader(section, desc): def __printWikiHeader(section, desc):
print output("")
print "=== " + section + " ===" output("=== " + section + " ===")
print output("")
print desc output(desc)
print output("")
print "{|" output("{|")
print "| '''Command''' || || '''Description'''" output("| '''Command''' || || '''Description'''")

View File

@ -24,6 +24,7 @@ __author__ = "Cyril Jaquier"
__copyright__ = "Copyright (c) 2004 Cyril Jaquier" __copyright__ = "Copyright (c) 2004 Cyril Jaquier"
__license__ = "GPL" __license__ = "GPL"
import threading
from threading import Lock, RLock from threading import Lock, RLock
import logging import logging
import logging.handlers import logging.handlers
@ -42,6 +43,10 @@ from ..helpers import getLogger, excepthook
# Gets the instance of the logger. # Gets the instance of the logger.
logSys = getLogger(__name__) logSys = getLogger(__name__)
DEF_SYSLOGSOCKET = "auto"
DEF_LOGLEVEL = "INFO"
DEF_LOGTARGET = "STDOUT"
try: try:
from .database import Fail2BanDb from .database import Fail2BanDb
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
@ -49,6 +54,10 @@ except ImportError: # pragma: no cover
Fail2BanDb = None Fail2BanDb = None
def _thread_name():
return threading.current_thread().__class__.__name__
class Server: class Server:
def __init__(self, daemon = False): def __init__(self, daemon = False):
@ -67,11 +76,7 @@ class Server:
'FreeBSD': '/var/run/log', 'FreeBSD': '/var/run/log',
'Linux': '/dev/log', 'Linux': '/dev/log',
} }
# todo: remove that, if test cases are fixed self.__prev_signals = {}
self.setSyslogSocket("auto")
# Set logging level
self.setLogLevel("INFO")
self.setLogTarget("STDOUT")
def __sigTERMhandler(self, signum, frame): def __sigTERMhandler(self, signum, frame):
logSys.debug("Caught signal %d. Exiting" % signum) logSys.debug("Caught signal %d. Exiting" % signum)
@ -93,9 +98,12 @@ class Server:
raise ServerInitializationError("Could not create daemon") raise ServerInitializationError("Could not create daemon")
# Set all logging parameters (or use default if not specified): # Set all logging parameters (or use default if not specified):
self.setSyslogSocket(conf.get("syslogsocket", self.__syslogSocket)) self.setSyslogSocket(conf.get("syslogsocket",
self.setLogLevel(conf.get("loglevel", self.__logLevel)) self.__syslogSocket if self.__syslogSocket is not None else DEF_SYSLOGSOCKET))
self.setLogTarget(conf.get("logtarget", self.__logTarget)) self.setLogLevel(conf.get("loglevel",
self.__logLevel if self.__logLevel is not None else DEF_LOGLEVEL))
self.setLogTarget(conf.get("logtarget",
self.__logTarget if self.__logTarget is not None else DEF_LOGTARGET))
logSys.info("-"*50) logSys.info("-"*50)
logSys.info("Starting Fail2ban v%s", version.version) logSys.info("Starting Fail2ban v%s", version.version)
@ -104,10 +112,10 @@ class Server:
logSys.info("Daemon started") logSys.info("Daemon started")
# Install signal handlers # Install signal handlers
signal.signal(signal.SIGTERM, self.__sigTERMhandler) if _thread_name() == '_MainThread':
signal.signal(signal.SIGINT, self.__sigTERMhandler) for s in (signal.SIGTERM, signal.SIGINT, signal.SIGUSR1):
signal.signal(signal.SIGUSR1, self.__sigUSR1handler) self.__prev_signals[s] = signal.getsignal(s)
signal.signal(s, self.__sigTERMhandler if s != signal.SIGUSR1 else self.__sigUSR1handler)
# Ensure unhandled exceptions are logged # Ensure unhandled exceptions are logged
sys.excepthook = excepthook sys.excepthook = excepthook
@ -150,6 +158,10 @@ class Server:
with self.__loggingLock: with self.__loggingLock:
logging.shutdown() logging.shutdown()
# Restore default signal handlers:
for s, sh in self.__prev_signals.iteritems():
signal.signal(s, sh)
def addJail(self, name, backend): def addJail(self, name, backend):
self.__jails.add(name, backend, self.__db) self.__jails.add(name, backend, self.__db)
if self.__db is not None: if self.__db is not None:
@ -405,8 +417,13 @@ class Server:
def setLogTarget(self, target): def setLogTarget(self, target):
with self.__loggingLock: with self.__loggingLock:
# don't set new handlers if already the same
# or if "INHERITED" (foreground worker of the test cases, to prevent stop logging):
if self.__logTarget == target: if self.__logTarget == target:
return True return True
if target == "INHERITED":
self.__logTarget = target
return True
# set a format which is simpler for console use # set a format which is simpler for console use
formatter = logging.Formatter("%(asctime)s %(name)-24s[%(process)d]: %(levelname)-7s %(message)s") formatter = logging.Formatter("%(asctime)s %(name)-24s[%(process)d]: %(levelname)-7s %(message)s")
if target == "SYSLOG": if target == "SYSLOG":
@ -549,7 +566,10 @@ class Server:
# We need to set this in the parent process, so it gets inherited by the # We need to set this in the parent process, so it gets inherited by the
# child process, and this makes sure that it is effect even if the parent # child process, and this makes sure that it is effect even if the parent
# terminates quickly. # terminates quickly.
signal.signal(signal.SIGHUP, signal.SIG_IGN) if _thread_name() == '_MainThread':
for s in (signal.SIGHUP,):
self.__prev_signals[s] = signal.getsignal(s)
signal.signal(s, signal.SIG_IGN)
try: try:
# Fork a child process so the parent can exit. This will return control # Fork a child process so the parent can exit. This will return control

View File

@ -0,0 +1,411 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
# vi: set ft=python sts=4 ts=4 sw=4 noet :
# This file is part of Fail2Ban.
#
# Fail2Ban is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Fail2Ban is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Fail2Ban; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Fail2Ban developers
__author__ = "Serg Brester"
__copyright__ = "Copyright (c) 2014- Serg G. Brester (sebres), 2008- Fail2Ban Contributors"
__license__ = "GPL"
import fileinput
import os
import re
import time
import unittest
from threading import Thread
from ..client import fail2banclient, fail2banserver, fail2bancmdline
from ..client.fail2banclient import Fail2banClient, exec_command_line as _exec_client, VisualWait
from ..client.fail2banserver import Fail2banServer, exec_command_line as _exec_server
from .. import protocol
from ..server import server
from ..server.utils import Utils
from .utils import LogCaptureTestCase, logSys, withtmpdir, shutil, logging
STOCK_CONF_DIR = "config"
STOCK = os.path.exists(os.path.join(STOCK_CONF_DIR,'fail2ban.conf'))
TEST_CONF_DIR = os.path.join(os.path.dirname(__file__), "config")
if STOCK:
CONF_DIR = STOCK_CONF_DIR
else:
CONF_DIR = TEST_CONF_DIR
CLIENT = "fail2ban-client"
SERVER = "fail2ban-server"
BIN = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "bin")
MAX_WAITTIME = 10
MAX_WAITTIME = unittest.F2B.maxWaitTime(MAX_WAITTIME)
##
# Several wrappers and settings for proper testing:
#
fail2banclient.MAX_WAITTIME = \
fail2banserver.MAX_WAITTIME = MAX_WAITTIME
fail2bancmdline.logSys = \
fail2banclient.logSys = \
fail2banserver.logSys = logSys
LOG_LEVEL = logSys.level
server.DEF_LOGTARGET = "/dev/null"
def _test_output(*args):
logSys.info(args[0])
fail2bancmdline.output = \
fail2banclient.output = \
fail2banserver.output = \
protocol.output = _test_output
def _test_exit(code=0):
logSys.debug("Exit with code %s", code)
if code == 0:
raise ExitException()
else:
raise FailExitException()
fail2bancmdline.exit = \
fail2banclient.exit = \
fail2banserver.exit = _test_exit
INTERACT = []
def _test_raw_input(*args):
if len(INTERACT):
#print('--- interact command: ', INTERACT[0])
return INTERACT.pop(0)
else:
return "exit"
fail2banclient.raw_input = _test_raw_input
# prevents change logging params, log capturing, etc:
fail2bancmdline.PRODUCTION = False
class ExitException(fail2bancmdline.ExitException):
pass
class FailExitException(fail2bancmdline.ExitException):
pass
def _out_file(fn): # pragma: no cover
logSys.debug('---- ' + fn + ' ----')
for line in fileinput.input(fn):
line = line.rstrip('\n')
logSys.debug(line)
logSys.debug('-'*30)
def _start_params(tmp, use_stock=False, logtarget="/dev/null"):
cfg = tmp+"/config"
if use_stock and STOCK:
# copy config:
def ig_dirs(dir, files):
return [f for f in files if not os.path.isfile(os.path.join(dir, f))]
shutil.copytree(STOCK_CONF_DIR, cfg, ignore=ig_dirs)
os.symlink(STOCK_CONF_DIR+"/action.d", cfg+"/action.d")
os.symlink(STOCK_CONF_DIR+"/filter.d", cfg+"/filter.d")
# replace fail2ban params (database with memory):
r = re.compile(r'^dbfile\s*=')
for line in fileinput.input(cfg+"/fail2ban.conf", inplace=True):
line = line.rstrip('\n')
if r.match(line):
line = "dbfile = :memory:"
print(line)
# replace jail params (polling as backend to be fast in initialize):
r = re.compile(r'^backend\s*=')
for line in fileinput.input(cfg+"/jail.conf", inplace=True):
line = line.rstrip('\n')
if r.match(line):
line = "backend = polling"
print(line)
else:
# just empty config directory without anything (only fail2ban.conf/jail.conf):
os.mkdir(cfg)
f = open(cfg+"/fail2ban.conf", "wb")
f.write('\n'.join((
"[Definition]",
"loglevel = INFO",
"logtarget = " + logtarget,
"syslogsocket = auto",
"socket = "+tmp+"/f2b.sock",
"pidfile = "+tmp+"/f2b.pid",
"backend = polling",
"dbfile = :memory:",
"dbpurgeage = 1d",
"",
)))
f.close()
f = open(cfg+"/jail.conf", "wb")
f.write('\n'.join((
"[INCLUDES]", "",
"[DEFAULT]", "",
"",
)))
f.close()
if LOG_LEVEL < logging.DEBUG: # if HEAVYDEBUG
_out_file(cfg+"/fail2ban.conf")
_out_file(cfg+"/jail.conf")
# parameters:
return ("-c", cfg,
"--logtarget", logtarget, "--loglevel", "DEBUG", "--syslogsocket", "auto",
"-s", tmp+"/f2b.sock", "-p", tmp+"/f2b.pid")
class Fail2banClientTest(LogCaptureTestCase):
def setUp(self):
"""Call before every test case."""
LogCaptureTestCase.setUp(self)
def tearDown(self):
"""Call after every test case."""
LogCaptureTestCase.tearDown(self)
def testClientUsage(self):
self.assertRaises(ExitException, _exec_client,
(CLIENT, "-h",))
self.assertLogged("Usage: " + CLIENT)
self.assertLogged("Report bugs to ")
@withtmpdir
def testClientStartBackgroundInside(self, tmp):
# always add "--async" by start inside, should don't fork by async (not replace client with server, just start in new process)
# (we can't fork the test cases process):
startparams = _start_params(tmp, True)
# start:
self.assertRaises(ExitException, _exec_client,
(CLIENT, "--async", "-b") + startparams + ("start",))
self.assertLogged("Server ready")
self.assertLogged("Exit with code 0")
try:
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("echo", "TEST-ECHO",))
self.assertRaises(FailExitException, _exec_client,
(CLIENT,) + startparams + ("~~unknown~cmd~failed~~",))
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("stop",))
self.assertLogged("Shutdown successful")
self.assertLogged("Exit with code 0")
@withtmpdir
def testClientStartBackgroundCall(self, tmp):
global INTERACT
startparams = _start_params(tmp)
# start (without async in new process):
cmd = os.path.join(os.path.join(BIN), CLIENT)
logSys.debug('Start %s ...', cmd)
Utils.executeCmd((cmd,) + startparams + ("start",),
timeout=MAX_WAITTIME, shell=False, output=False)
self.pruneLog()
try:
# echo from client (inside):
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("echo", "TEST-ECHO",))
self.assertLogged("TEST-ECHO")
self.assertLogged("Exit with code 0")
self.pruneLog()
# interactive client chat with started server:
INTERACT += [
"echo INTERACT-ECHO",
"status",
"exit"
]
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("-i",))
self.assertLogged("INTERACT-ECHO")
self.assertLogged("Status", "Number of jail:")
self.assertLogged("Exit with code 0")
self.pruneLog()
# test reload and restart over interactive client:
INTERACT += [
"reload",
"restart",
"exit"
]
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("-i",))
self.assertLogged("Reading config files:")
self.assertLogged("Shutdown successful")
self.assertLogged("Server ready")
self.assertLogged("Exit with code 0")
self.pruneLog()
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("stop",))
self.assertLogged("Shutdown successful")
self.assertLogged("Exit with code 0")
def _testClientStartForeground(self, tmp, startparams, phase):
# start and wait to end (foreground):
phase['start'] = True
self.assertRaises(ExitException, _exec_client,
(CLIENT, "-f") + startparams + ("start",))
# end :
phase['end'] = True
@withtmpdir
def testClientStartForeground(self, tmp):
# started directly here, so prevent overwrite test cases logger with "INHERITED"
startparams = _start_params(tmp, logtarget="INHERITED")
# because foreground block execution - start it in thread:
phase = dict()
Thread(name="_TestCaseWorker",
target=Fail2banClientTest._testClientStartForeground, args=(self, tmp, startparams, phase)).start()
try:
# wait for start thread:
Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME)
self.assertTrue(phase.get('start', None))
# wait for server (socket):
Utils.wait_for(lambda: os.path.exists(tmp+"/f2b.sock"), MAX_WAITTIME)
self.assertLogged("Starting communication")
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("ping",))
self.assertRaises(FailExitException, _exec_client,
(CLIENT,) + startparams + ("~~unknown~cmd~failed~~",))
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("echo", "TEST-ECHO",))
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_client,
(CLIENT,) + startparams + ("stop",))
# wait for end:
Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME)
self.assertTrue(phase.get('end', None))
self.assertLogged("Shutdown successful", "Exiting Fail2ban")
self.assertLogged("Exit with code 0")
@withtmpdir
def testClientFailStart(self, tmp):
self.assertRaises(FailExitException, _exec_client,
(CLIENT, "--async", "-c", tmp+"/miss", "start",))
self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist")
self.assertRaises(FailExitException, _exec_client,
(CLIENT, "--async", "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock", "start",))
self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file")
def testVisualWait(self):
sleeptime = 0.035
for verbose in (2, 0):
cntr = 15
with VisualWait(verbose, 5) as vis:
while cntr:
vis.heartbeat()
if verbose and not unittest.F2B.fast:
time.sleep(sleeptime)
cntr -= 1
class Fail2banServerTest(LogCaptureTestCase):
def setUp(self):
"""Call before every test case."""
LogCaptureTestCase.setUp(self)
def tearDown(self):
"""Call after every test case."""
LogCaptureTestCase.tearDown(self)
def testServerUsage(self):
self.assertRaises(ExitException, _exec_server,
(SERVER, "-h",))
self.assertLogged("Usage: " + SERVER)
self.assertLogged("Report bugs to ")
@withtmpdir
def testServerStartBackground(self, tmp):
# don't add "--async" by start, because if will fork current process by daemonize
# (we can't fork the test cases process),
# because server started internal communication in new thread use INHERITED as logtarget here:
startparams = _start_params(tmp, logtarget="INHERITED")
# start:
self.assertRaises(ExitException, _exec_server,
(SERVER, "-b") + startparams)
self.assertLogged("Server ready")
self.assertLogged("Exit with code 0")
try:
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("echo", "TEST-ECHO",))
self.assertRaises(FailExitException, _exec_server,
(SERVER,) + startparams + ("~~unknown~cmd~failed~~",))
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("stop",))
self.assertLogged("Shutdown successful")
self.assertLogged("Exit with code 0")
def _testServerStartForeground(self, tmp, startparams, phase):
# start and wait to end (foreground):
phase['start'] = True
self.assertRaises(ExitException, _exec_server,
(SERVER, "-f") + startparams + ("start",))
# end :
phase['end'] = True
@withtmpdir
def testServerStartForeground(self, tmp):
# started directly here, so prevent overwrite test cases logger with "INHERITED"
startparams = _start_params(tmp, logtarget="INHERITED")
# because foreground block execution - start it in thread:
phase = dict()
Thread(name="_TestCaseWorker",
target=Fail2banServerTest._testServerStartForeground, args=(self, tmp, startparams, phase)).start()
try:
# wait for start thread:
Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME)
self.assertTrue(phase.get('start', None))
# wait for server (socket):
Utils.wait_for(lambda: os.path.exists(tmp+"/f2b.sock"), MAX_WAITTIME)
self.assertLogged("Starting communication")
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("ping",))
self.assertRaises(FailExitException, _exec_server,
(SERVER,) + startparams + ("~~unknown~cmd~failed~~",))
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("echo", "TEST-ECHO",))
finally:
self.pruneLog()
# stop:
self.assertRaises(ExitException, _exec_server,
(SERVER,) + startparams + ("stop",))
# wait for end:
Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME)
self.assertTrue(phase.get('end', None))
self.assertLogged("Shutdown successful", "Exiting Fail2ban")
self.assertLogged("Exit with code 0")
@withtmpdir
def testServerFailStart(self, tmp):
self.assertRaises(FailExitException, _exec_server,
(SERVER, "-c", tmp+"/miss",))
self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist")
self.assertRaises(FailExitException, _exec_server,
(SERVER, "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock",))
self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file")

View File

@ -23,19 +23,7 @@ __author__ = "Serg Brester"
__copyright__ = "Copyright (c) 2015 Serg G. Brester (sebres), 2008- Fail2Ban Contributors" __copyright__ = "Copyright (c) 2015 Serg G. Brester (sebres), 2008- Fail2Ban Contributors"
__license__ = "GPL" __license__ = "GPL"
from __builtin__ import open as fopen
import unittest
import getpass
import os import os
import sys
import time
import tempfile
import uuid
try:
from systemd import journal
except ImportError:
journal = None
from ..client import fail2banregex from ..client import fail2banregex
from ..client.fail2banregex import Fail2banRegex, get_opt_parser, output from ..client.fail2banregex import Fail2banRegex, get_opt_parser, output

View File

@ -26,10 +26,14 @@ import logging
import optparse import optparse
import os import os
import re import re
import tempfile
import shutil
import sys import sys
import time import time
import unittest import unittest
from StringIO import StringIO from StringIO import StringIO
from functools import wraps
from ..helpers import getLogger from ..helpers import getLogger
from ..server.ipdns import DNSUtils from ..server.ipdns import DNSUtils
@ -71,6 +75,17 @@ class F2B(optparse.Values):
return wtime return wtime
def withtmpdir(f):
@wraps(f)
def wrapper(self, *args, **kwargs):
tmp = tempfile.mkdtemp(prefix="f2b-temp")
try:
return f(self, tmp, *args, **kwargs)
finally:
# clean up
shutil.rmtree(tmp)
return wrapper
def initTests(opts): def initTests(opts):
unittest.F2B = F2B(opts) unittest.F2B = F2B(opts)
# --fast : # --fast :
@ -156,6 +171,7 @@ def gatherTests(regexps=None, opts=None):
from . import misctestcase from . import misctestcase
from . import databasetestcase from . import databasetestcase
from . import samplestestcase from . import samplestestcase
from . import fail2banclienttestcase
from . import fail2banregextestcase from . import fail2banregextestcase
if not regexps: # pragma: no cover if not regexps: # pragma: no cover
@ -239,6 +255,9 @@ def gatherTests(regexps=None, opts=None):
# Filter Regex tests with sample logs # Filter Regex tests with sample logs
tests.addTest(unittest.makeSuite(samplestestcase.FilterSamplesRegex)) tests.addTest(unittest.makeSuite(samplestestcase.FilterSamplesRegex))
# bin/fail2ban-client, bin/fail2ban-server
tests.addTest(unittest.makeSuite(fail2banclienttestcase.Fail2banClientTest))
tests.addTest(unittest.makeSuite(fail2banclienttestcase.Fail2banServerTest))
# bin/fail2ban-regex # bin/fail2ban-regex
tests.addTest(unittest.makeSuite(fail2banregextestcase.Fail2banRegexTest)) tests.addTest(unittest.makeSuite(fail2banregextestcase.Fail2banRegexTest))
@ -321,8 +340,11 @@ class LogCaptureTestCase(unittest.TestCase):
# Let's log everything into a string # Let's log everything into a string
self._log = StringIO() self._log = StringIO()
logSys.handlers = [logging.StreamHandler(self._log)] logSys.handlers = [logging.StreamHandler(self._log)]
if self._old_level <= logging.DEBUG:
print("")
if self._old_level < logging.DEBUG: # so if HEAVYDEBUG etc -- show them! if self._old_level < logging.DEBUG: # so if HEAVYDEBUG etc -- show them!
logSys.handlers += self._old_handlers logSys.handlers += self._old_handlers
logSys.debug('--'*40)
logSys.setLevel(getattr(logging, 'DEBUG')) logSys.setLevel(getattr(logging, 'DEBUG'))
def tearDown(self): def tearDown(self):
@ -386,6 +408,9 @@ class LogCaptureTestCase(unittest.TestCase):
def pruneLog(self): def pruneLog(self):
self._log.truncate(0) self._log.truncate(0)
def pruneLog(self):
self._log.truncate(0)
def getLog(self): def getLog(self):
return self._log.getvalue() return self._log.getvalue()