mirror of https://github.com/fail2ban/fail2ban
small code review and fixing of some bugs during client-server communication process (in the test cases);pull/1321/head
parent
4d696d69a0
commit
f120877756
|
@ -31,7 +31,7 @@ __author__ = "Fail2Ban Developers"
|
|||
__copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester"
|
||||
__license__ = "GPL"
|
||||
|
||||
from fail2ban.client.fail2banclient import exec_command_line
|
||||
from fail2ban.client.fail2banclient import exec_command_line, sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
exec_command_line()
|
||||
exec_command_line(sys.argv)
|
||||
|
|
|
@ -31,7 +31,7 @@ __author__ = "Fail2Ban Developers"
|
|||
__copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2014 Yaroslav Halchenko, 2014-2016 Serg G. Brester"
|
||||
__license__ = "GPL"
|
||||
|
||||
from fail2ban.client.fail2banserver import exec_command_line
|
||||
from fail2ban.client.fail2banserver import exec_command_line, sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
exec_command_line()
|
||||
exec_command_line(sys.argv)
|
||||
|
|
|
@ -28,12 +28,20 @@ import socket
|
|||
import sys
|
||||
import time
|
||||
|
||||
import threading
|
||||
from threading import Thread
|
||||
|
||||
from ..version import version
|
||||
from .csocket import CSocket
|
||||
from .beautifier import Beautifier
|
||||
from .fail2bancmdline import Fail2banCmdLine, logSys, exit
|
||||
from .fail2bancmdline import Fail2banCmdLine, ExitException, logSys, exit, output
|
||||
|
||||
MAX_WAITTIME = 30
|
||||
|
||||
|
||||
def _thread_name():
|
||||
return threading.current_thread().__class__.__name__
|
||||
|
||||
|
||||
##
|
||||
#
|
||||
|
@ -51,13 +59,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
self._beautifier = None
|
||||
|
||||
def dispInteractive(self):
|
||||
print "Fail2Ban v" + version + " reads log file that contains password failure report"
|
||||
print "and bans the corresponding IP addresses using firewall rules."
|
||||
print
|
||||
output("Fail2Ban v" + version + " reads log file that contains password failure report")
|
||||
output("and bans the corresponding IP addresses using firewall rules.")
|
||||
output("")
|
||||
|
||||
def __sigTERMhandler(self, signum, frame):
|
||||
# Print a new line because we probably come from wait
|
||||
print
|
||||
output("")
|
||||
logSys.warning("Caught signal %d. Exiting" % signum)
|
||||
exit(-1)
|
||||
|
||||
|
@ -85,11 +93,11 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
if ret[0] == 0:
|
||||
logSys.debug("OK : " + `ret[1]`)
|
||||
if showRet or c[0] == 'echo':
|
||||
print beautifier.beautify(ret[1])
|
||||
output(beautifier.beautify(ret[1]))
|
||||
else:
|
||||
logSys.error("NOK: " + `ret[1].args`)
|
||||
if showRet:
|
||||
print beautifier.beautifyError(ret[1])
|
||||
output(beautifier.beautifyError(ret[1]))
|
||||
streamRet = False
|
||||
except socket.error:
|
||||
if showRet or self._conf["verbose"] > 1:
|
||||
|
@ -182,10 +190,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
# Start server direct here in main thread (not fork):
|
||||
self._server = Fail2banServer.startServerDirect(self._conf, False)
|
||||
|
||||
except ExitException:
|
||||
pass
|
||||
except Exception as e:
|
||||
print
|
||||
logSys.error("Exception while starting server foreground")
|
||||
output("")
|
||||
logSys.error("Exception while starting server " + ("background" if background else "foreground"))
|
||||
logSys.error(e)
|
||||
return False
|
||||
finally:
|
||||
self._alive = False
|
||||
|
||||
|
@ -229,18 +240,18 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
elif len(cmd) == 1 and cmd[0] == "restart":
|
||||
|
||||
if self._conf.get("interactive", False):
|
||||
print(' ## stop ... ')
|
||||
output(' ## stop ... ')
|
||||
self.__processCommand(['stop'])
|
||||
self.__waitOnServer(False)
|
||||
# in interactive mode reset config, to make full-reload if there something changed:
|
||||
if self._conf.get("interactive", False):
|
||||
print(' ## load configuration ... ')
|
||||
output(' ## load configuration ... ')
|
||||
self.resetConf()
|
||||
ret = self.initCmdLine(self._argv)
|
||||
if ret is not None:
|
||||
return ret
|
||||
if self._conf.get("interactive", False):
|
||||
print(' ## start ... ')
|
||||
output(' ## start ... ')
|
||||
return self.__processCommand(['start'])
|
||||
|
||||
elif len(cmd) >= 1 and cmd[0] == "reload":
|
||||
|
@ -283,7 +294,9 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
return False
|
||||
return True
|
||||
|
||||
def __waitOnServer(self, alive=True, maxtime=30):
|
||||
def __waitOnServer(self, alive=True, maxtime=None):
|
||||
if maxtime is None:
|
||||
maxtime = MAX_WAITTIME
|
||||
# Wait for the server to start (the server has 30 seconds to answer ping)
|
||||
starttime = time.time()
|
||||
with VisualWait(self._conf["verbose"]) as vis:
|
||||
|
@ -301,53 +314,59 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
|||
|
||||
def start(self, argv):
|
||||
# Install signal handlers
|
||||
signal.signal(signal.SIGTERM, self.__sigTERMhandler)
|
||||
signal.signal(signal.SIGINT, self.__sigTERMhandler)
|
||||
_prev_signals = {}
|
||||
if _thread_name() == '_MainThread':
|
||||
for s in (signal.SIGTERM, signal.SIGINT):
|
||||
_prev_signals[s] = signal.getsignal(s)
|
||||
signal.signal(s, self.__sigTERMhandler)
|
||||
try:
|
||||
# Command line options
|
||||
if self._argv is None:
|
||||
ret = self.initCmdLine(argv)
|
||||
if ret is not None:
|
||||
return ret
|
||||
|
||||
# Command line options
|
||||
if self._argv is None:
|
||||
ret = self.initCmdLine(argv)
|
||||
if ret is not None:
|
||||
return ret
|
||||
# Commands
|
||||
args = self._args
|
||||
|
||||
# Commands
|
||||
args = self._args
|
||||
|
||||
# Interactive mode
|
||||
if self._conf.get("interactive", False):
|
||||
try:
|
||||
import readline
|
||||
except ImportError:
|
||||
logSys.error("Readline not available")
|
||||
return False
|
||||
try:
|
||||
ret = True
|
||||
if len(args) > 0:
|
||||
ret = self.__processCommand(args)
|
||||
if ret:
|
||||
readline.parse_and_bind("tab: complete")
|
||||
self.dispInteractive()
|
||||
while True:
|
||||
cmd = raw_input(self.PROMPT)
|
||||
if cmd == "exit" or cmd == "quit":
|
||||
# Exit
|
||||
return True
|
||||
if cmd == "help":
|
||||
self.dispUsage()
|
||||
elif not cmd == "":
|
||||
try:
|
||||
self.__processCommand(shlex.split(cmd))
|
||||
except Exception, e:
|
||||
logSys.error(e)
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print
|
||||
return True
|
||||
# Single command mode
|
||||
else:
|
||||
if len(args) < 1:
|
||||
self.dispUsage()
|
||||
return False
|
||||
return self.__processCommand(args)
|
||||
# Interactive mode
|
||||
if self._conf.get("interactive", False):
|
||||
try:
|
||||
import readline
|
||||
except ImportError:
|
||||
logSys.error("Readline not available")
|
||||
return False
|
||||
try:
|
||||
ret = True
|
||||
if len(args) > 0:
|
||||
ret = self.__processCommand(args)
|
||||
if ret:
|
||||
readline.parse_and_bind("tab: complete")
|
||||
self.dispInteractive()
|
||||
while True:
|
||||
cmd = raw_input(self.PROMPT)
|
||||
if cmd == "exit" or cmd == "quit":
|
||||
# Exit
|
||||
return True
|
||||
if cmd == "help":
|
||||
self.dispUsage()
|
||||
elif not cmd == "":
|
||||
try:
|
||||
self.__processCommand(shlex.split(cmd))
|
||||
except Exception, e:
|
||||
logSys.error(e)
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
output("")
|
||||
return True
|
||||
# Single command mode
|
||||
else:
|
||||
if len(args) < 1:
|
||||
self.dispUsage()
|
||||
return False
|
||||
return self.__processCommand(args)
|
||||
finally:
|
||||
for s, sh in _prev_signals.iteritems():
|
||||
signal.signal(s, sh)
|
||||
|
||||
|
||||
class ServerExecutionException(Exception):
|
||||
|
@ -361,7 +380,8 @@ class ServerExecutionException(Exception):
|
|||
class _VisualWait:
|
||||
pos = 0
|
||||
delta = 1
|
||||
maxpos = 10
|
||||
def __init__(self, maxpos=10):
|
||||
self.maxpos = maxpos
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
|
@ -390,14 +410,14 @@ class _NotVisualWait:
|
|||
def heartbeat(self):
|
||||
pass
|
||||
|
||||
def VisualWait(verbose):
|
||||
return _VisualWait() if verbose > 1 else _NotVisualWait()
|
||||
def VisualWait(verbose, *args, **kwargs):
|
||||
return _VisualWait(*args, **kwargs) if verbose > 1 else _NotVisualWait()
|
||||
|
||||
|
||||
def exec_command_line(): # pragma: no cover - can't test main
|
||||
def exec_command_line(argv):
|
||||
client = Fail2banClient()
|
||||
# Exit with correct return value
|
||||
if client.start(sys.argv):
|
||||
if client.start(argv):
|
||||
exit(0)
|
||||
else:
|
||||
exit(-1)
|
||||
|
|
|
@ -33,7 +33,12 @@ from ..helpers import getLogger
|
|||
# Gets the instance of the logger.
|
||||
logSys = getLogger("fail2ban")
|
||||
|
||||
def output(s):
|
||||
print(s)
|
||||
|
||||
CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket",)
|
||||
# Used to signal - we are in test cases (ex: prevents change logging params, log capturing, etc)
|
||||
PRODUCTION = True
|
||||
|
||||
|
||||
class Fail2banCmdLine():
|
||||
|
@ -71,50 +76,50 @@ class Fail2banCmdLine():
|
|||
self.__dict__[o] = obj.__dict__[o]
|
||||
|
||||
def dispVersion(self):
|
||||
print "Fail2Ban v" + version
|
||||
print
|
||||
print "Copyright (c) 2004-2008 Cyril Jaquier, 2008- Fail2Ban Contributors"
|
||||
print "Copyright of modifications held by their respective authors."
|
||||
print "Licensed under the GNU General Public License v2 (GPL)."
|
||||
print
|
||||
print "Written by Cyril Jaquier <cyril.jaquier@fail2ban.org>."
|
||||
print "Many contributions by Yaroslav O. Halchenko <debian@onerussian.com>."
|
||||
output("Fail2Ban v" + version)
|
||||
output("")
|
||||
output("Copyright (c) 2004-2008 Cyril Jaquier, 2008- Fail2Ban Contributors")
|
||||
output("Copyright of modifications held by their respective authors.")
|
||||
output("Licensed under the GNU General Public License v2 (GPL).")
|
||||
output("")
|
||||
output("Written by Cyril Jaquier <cyril.jaquier@fail2ban.org>.")
|
||||
output("Many contributions by Yaroslav O. Halchenko <debian@onerussian.com>.")
|
||||
|
||||
def dispUsage(self):
|
||||
""" Prints Fail2Ban command line options and exits
|
||||
"""
|
||||
caller = os.path.basename(self._argv[0])
|
||||
print "Usage: "+caller+" [OPTIONS]" + (" <COMMAND>" if not caller.endswith('server') else "")
|
||||
print
|
||||
print "Fail2Ban v" + version + " reads log file that contains password failure report"
|
||||
print "and bans the corresponding IP addresses using firewall rules."
|
||||
print
|
||||
print "Options:"
|
||||
print " -c <DIR> configuration directory"
|
||||
print " -s <FILE> socket path"
|
||||
print " -p <FILE> pidfile path"
|
||||
print " --loglevel <LEVEL> logging level"
|
||||
print " --logtarget <FILE>|STDOUT|STDERR|SYSLOG"
|
||||
print " --syslogsocket auto|<FILE>"
|
||||
print " -d dump configuration. For debugging"
|
||||
print " -i interactive mode"
|
||||
print " -v increase verbosity"
|
||||
print " -q decrease verbosity"
|
||||
print " -x force execution of the server (remove socket file)"
|
||||
print " -b start server in background (default)"
|
||||
print " -f start server in foreground"
|
||||
print " --async start server in async mode (for internal usage only, don't read configuration)"
|
||||
print " -h, --help display this help message"
|
||||
print " -V, --version print the version"
|
||||
output("Usage: "+caller+" [OPTIONS]" + (" <COMMAND>" if not caller.endswith('server') else ""))
|
||||
output("")
|
||||
output("Fail2Ban v" + version + " reads log file that contains password failure report")
|
||||
output("and bans the corresponding IP addresses using firewall rules.")
|
||||
output("")
|
||||
output("Options:")
|
||||
output(" -c <DIR> configuration directory")
|
||||
output(" -s <FILE> socket path")
|
||||
output(" -p <FILE> pidfile path")
|
||||
output(" --loglevel <LEVEL> logging level")
|
||||
output(" --logtarget <FILE>|STDOUT|STDERR|SYSLOG")
|
||||
output(" --syslogsocket auto|<FILE>")
|
||||
output(" -d dump configuration. For debugging")
|
||||
output(" -i interactive mode")
|
||||
output(" -v increase verbosity")
|
||||
output(" -q decrease verbosity")
|
||||
output(" -x force execution of the server (remove socket file)")
|
||||
output(" -b start server in background (default)")
|
||||
output(" -f start server in foreground")
|
||||
output(" --async start server in async mode (for internal usage only, don't read configuration)")
|
||||
output(" -h, --help display this help message")
|
||||
output(" -V, --version print the version")
|
||||
|
||||
if not caller.endswith('server'):
|
||||
print
|
||||
print "Command:"
|
||||
output("")
|
||||
output("Command:")
|
||||
# Prints the protocol
|
||||
printFormatted()
|
||||
|
||||
print
|
||||
print "Report bugs to https://github.com/fail2ban/fail2ban/issues"
|
||||
output("")
|
||||
output("Report bugs to https://github.com/fail2ban/fail2ban/issues")
|
||||
|
||||
def __getCmdLineOptions(self, optList):
|
||||
""" Gets the command line options
|
||||
|
@ -147,69 +152,78 @@ class Fail2banCmdLine():
|
|||
self._conf["background"] = False
|
||||
elif o in ["-h", "--help"]:
|
||||
self.dispUsage()
|
||||
exit(0)
|
||||
return True
|
||||
elif o in ["-V", "--version"]:
|
||||
self.dispVersion()
|
||||
exit(0)
|
||||
return True
|
||||
return None
|
||||
|
||||
def initCmdLine(self, argv):
|
||||
# First time?
|
||||
initial = (self._argv is None)
|
||||
|
||||
# Command line options
|
||||
self._argv = argv
|
||||
|
||||
# Reads the command line options.
|
||||
try:
|
||||
cmdOpts = 'hc:s:p:xfbdviqV'
|
||||
cmdLongOpts = ['loglevel=', 'logtarget=', 'syslogsocket=', 'async', 'help', 'version']
|
||||
optList, self._args = getopt.getopt(self._argv[1:], cmdOpts, cmdLongOpts)
|
||||
except getopt.GetoptError:
|
||||
self.dispUsage()
|
||||
exit(-1)
|
||||
# First time?
|
||||
initial = (self._argv is None)
|
||||
|
||||
self.__getCmdLineOptions(optList)
|
||||
# Command line options
|
||||
self._argv = argv
|
||||
logSys.info("Using start params %s", argv[1:])
|
||||
|
||||
if initial:
|
||||
verbose = self._conf["verbose"]
|
||||
if verbose <= 0:
|
||||
logSys.setLevel(logging.ERROR)
|
||||
elif verbose == 1:
|
||||
logSys.setLevel(logging.WARNING)
|
||||
elif verbose == 2:
|
||||
logSys.setLevel(logging.INFO)
|
||||
else:
|
||||
logSys.setLevel(logging.DEBUG)
|
||||
# Add the default logging handler to dump to stderr
|
||||
logout = logging.StreamHandler(sys.stderr)
|
||||
# set a format which is simpler for console use
|
||||
formatter = logging.Formatter('%(levelname)-6s %(message)s')
|
||||
# tell the handler to use this format
|
||||
logout.setFormatter(formatter)
|
||||
logSys.addHandler(logout)
|
||||
# Reads the command line options.
|
||||
try:
|
||||
cmdOpts = 'hc:s:p:xfbdviqV'
|
||||
cmdLongOpts = ['loglevel=', 'logtarget=', 'syslogsocket=', 'async', 'help', 'version']
|
||||
optList, self._args = getopt.getopt(self._argv[1:], cmdOpts, cmdLongOpts)
|
||||
except getopt.GetoptError:
|
||||
self.dispUsage()
|
||||
return False
|
||||
|
||||
# Set expected parameters (like socket, pidfile, etc) from configuration,
|
||||
# if those not yet specified, in which read configuration only if needed here:
|
||||
conf = None
|
||||
for o in CONFIG_PARAMS:
|
||||
if self._conf.get(o, None) is None:
|
||||
if not conf:
|
||||
self.configurator.readEarly()
|
||||
conf = self.configurator.getEarlyOptions()
|
||||
self._conf[o] = conf[o]
|
||||
ret = self.__getCmdLineOptions(optList)
|
||||
if ret is not None:
|
||||
return ret
|
||||
|
||||
logSys.info("Using socket file %s", self._conf["socket"])
|
||||
if initial and PRODUCTION: # pragma: no cover - can't test
|
||||
verbose = self._conf["verbose"]
|
||||
if verbose <= 0:
|
||||
logSys.setLevel(logging.ERROR)
|
||||
elif verbose == 1:
|
||||
logSys.setLevel(logging.WARNING)
|
||||
elif verbose == 2:
|
||||
logSys.setLevel(logging.INFO)
|
||||
else:
|
||||
logSys.setLevel(logging.DEBUG)
|
||||
# Add the default logging handler to dump to stderr
|
||||
logout = logging.StreamHandler(sys.stderr)
|
||||
# set a format which is simpler for console use
|
||||
formatter = logging.Formatter('%(levelname)-6s %(message)s')
|
||||
# tell the handler to use this format
|
||||
logout.setFormatter(formatter)
|
||||
logSys.addHandler(logout)
|
||||
|
||||
logSys.info("Using pid file %s, [%s] logging to %s",
|
||||
self._conf["pidfile"], self._conf["loglevel"], self._conf["logtarget"])
|
||||
# Set expected parameters (like socket, pidfile, etc) from configuration,
|
||||
# if those not yet specified, in which read configuration only if needed here:
|
||||
conf = None
|
||||
for o in CONFIG_PARAMS:
|
||||
if self._conf.get(o, None) is None:
|
||||
if not conf:
|
||||
self.configurator.readEarly()
|
||||
conf = self.configurator.getEarlyOptions()
|
||||
self._conf[o] = conf[o]
|
||||
|
||||
if self._conf.get("dump", False):
|
||||
ret, stream = self.readConfig()
|
||||
self.dumpConfig(stream)
|
||||
return ret
|
||||
logSys.info("Using socket file %s", self._conf["socket"])
|
||||
|
||||
# Nothing to do here, process in client/server
|
||||
return None
|
||||
logSys.info("Using pid file %s, [%s] logging to %s",
|
||||
self._conf["pidfile"], self._conf["loglevel"], self._conf["logtarget"])
|
||||
|
||||
if self._conf.get("dump", False):
|
||||
ret, stream = self.readConfig()
|
||||
self.dumpConfig(stream)
|
||||
return ret
|
||||
|
||||
# Nothing to do here, process in client/server
|
||||
return None
|
||||
except Exception as e:
|
||||
output("ERROR: %s" % (e,))
|
||||
#logSys.exception(e)
|
||||
return False
|
||||
|
||||
def readConfig(self, jail=None):
|
||||
# Read the configuration
|
||||
|
@ -242,4 +256,8 @@ class Fail2banCmdLine():
|
|||
sys.exit(code)
|
||||
|
||||
# global exit handler:
|
||||
exit = Fail2banCmdLine.exit
|
||||
exit = Fail2banCmdLine.exit
|
||||
|
||||
|
||||
class ExitException:
|
||||
pass
|
|
@ -29,7 +29,11 @@ from ..server.server import Server, ServerDaemonize
|
|||
from ..server.utils import Utils
|
||||
from .fail2bancmdline import Fail2banCmdLine, logSys, exit
|
||||
|
||||
MAX_WAITTIME = 30
|
||||
|
||||
SERVER = "fail2ban-server"
|
||||
|
||||
|
||||
##
|
||||
# \mainpage Fail2Ban
|
||||
#
|
||||
|
@ -72,8 +76,15 @@ class Fail2banServer(Fail2banCmdLine):
|
|||
|
||||
@staticmethod
|
||||
def startServerAsync(conf):
|
||||
# Forks the current process.
|
||||
pid = os.fork()
|
||||
# Directory of client (to try the first start from the same directory as client):
|
||||
startdir = sys.path[0]
|
||||
if startdir in ("", "."): # may be uresolved in test-cases, so get bin-directory:
|
||||
startdir = os.path.dirname(sys.argv[0])
|
||||
# Forks the current process, don't fork if async specified (ex: test cases)
|
||||
pid = 0
|
||||
frk = not conf["async"]
|
||||
if frk:
|
||||
pid = os.fork()
|
||||
if pid == 0:
|
||||
args = list()
|
||||
args.append(SERVER)
|
||||
|
@ -96,14 +107,20 @@ class Fail2banServer(Fail2banCmdLine):
|
|||
|
||||
try:
|
||||
# Use the current directory.
|
||||
exe = os.path.abspath(os.path.join(sys.path[0], SERVER))
|
||||
exe = os.path.abspath(os.path.join(startdir, SERVER))
|
||||
logSys.debug("Starting %r with args %r", exe, args)
|
||||
os.execv(exe, args)
|
||||
except OSError:
|
||||
if frk:
|
||||
os.execv(exe, args)
|
||||
else:
|
||||
os.spawnv(os.P_NOWAITO, exe, args)
|
||||
except OSError as e:
|
||||
try:
|
||||
# Use the PATH env.
|
||||
logSys.warning("Initial start attempt failed. Starting %r with the same args", SERVER)
|
||||
os.execvp(SERVER, args)
|
||||
logSys.warning("Initial start attempt failed (%s). Starting %r with the same args", e, SERVER)
|
||||
if frk:
|
||||
os.execvp(SERVER, args)
|
||||
else:
|
||||
os.spawnvp(os.P_NOWAITO, SERVER, args)
|
||||
except OSError:
|
||||
exit(-1)
|
||||
|
||||
|
@ -143,8 +160,8 @@ class Fail2banServer(Fail2banCmdLine):
|
|||
phase = dict()
|
||||
logSys.debug('Configure via async client thread')
|
||||
cli.configureServer(async=True, phase=phase)
|
||||
# wait up to 30 seconds, do not continue if configuration is not 100% valid:
|
||||
Utils.wait_for(lambda: phase.get('ready', None) is not None, 30)
|
||||
# wait up to MAX_WAITTIME, do not continue if configuration is not 100% valid:
|
||||
Utils.wait_for(lambda: phase.get('ready', None) is not None, MAX_WAITTIME)
|
||||
if not phase.get('start', False):
|
||||
return False
|
||||
|
||||
|
@ -158,7 +175,7 @@ class Fail2banServer(Fail2banCmdLine):
|
|||
|
||||
# wait for client answer "done":
|
||||
if not async and cli:
|
||||
Utils.wait_for(lambda: phase.get('done', None) is not None, 30)
|
||||
Utils.wait_for(lambda: phase.get('done', None) is not None, MAX_WAITTIME)
|
||||
if not phase.get('done', False):
|
||||
if server:
|
||||
server.quit()
|
||||
|
@ -179,9 +196,9 @@ class Fail2banServer(Fail2banCmdLine):
|
|||
logSys.error("Could not start %s", SERVER)
|
||||
exit(code)
|
||||
|
||||
def exec_command_line(): # pragma: no cover - can't test main
|
||||
def exec_command_line(argv):
|
||||
server = Fail2banServer()
|
||||
if server.start(sys.argv):
|
||||
if server.start(argv):
|
||||
exit(0)
|
||||
else:
|
||||
exit(-1)
|
||||
|
|
|
@ -26,6 +26,9 @@ __license__ = "GPL"
|
|||
|
||||
import textwrap
|
||||
|
||||
def output(s):
|
||||
print(s)
|
||||
|
||||
##
|
||||
# Describes the protocol used to communicate with the server.
|
||||
|
||||
|
@ -143,7 +146,7 @@ def printFormatted():
|
|||
firstHeading = False
|
||||
for m in protocol:
|
||||
if m[0] == '' and firstHeading:
|
||||
print
|
||||
output("")
|
||||
firstHeading = True
|
||||
first = True
|
||||
if len(m[0]) >= MARGIN:
|
||||
|
@ -154,7 +157,7 @@ def printFormatted():
|
|||
first = False
|
||||
else:
|
||||
line = ' ' * (INDENT + MARGIN) + n.strip()
|
||||
print line
|
||||
output(line)
|
||||
|
||||
|
||||
##
|
||||
|
@ -165,20 +168,20 @@ def printWiki():
|
|||
for m in protocol:
|
||||
if m[0] == '':
|
||||
if firstHeading:
|
||||
print "|}"
|
||||
output("|}")
|
||||
__printWikiHeader(m[1], m[2])
|
||||
firstHeading = True
|
||||
else:
|
||||
print "|-"
|
||||
print "| <span style=\"white-space:nowrap;\"><tt>" + m[0] + "</tt></span> || || " + m[1]
|
||||
print "|}"
|
||||
output("|-")
|
||||
output("| <span style=\"white-space:nowrap;\"><tt>" + m[0] + "</tt></span> || || " + m[1])
|
||||
output("|}")
|
||||
|
||||
|
||||
def __printWikiHeader(section, desc):
|
||||
print
|
||||
print "=== " + section + " ==="
|
||||
print
|
||||
print desc
|
||||
print
|
||||
print "{|"
|
||||
print "| '''Command''' || || '''Description'''"
|
||||
output("")
|
||||
output("=== " + section + " ===")
|
||||
output("")
|
||||
output(desc)
|
||||
output("")
|
||||
output("{|")
|
||||
output("| '''Command''' || || '''Description'''")
|
||||
|
|
|
@ -24,6 +24,7 @@ __author__ = "Cyril Jaquier"
|
|||
__copyright__ = "Copyright (c) 2004 Cyril Jaquier"
|
||||
__license__ = "GPL"
|
||||
|
||||
import threading
|
||||
from threading import Lock, RLock
|
||||
import logging
|
||||
import logging.handlers
|
||||
|
@ -42,6 +43,10 @@ from ..helpers import getLogger, excepthook
|
|||
# Gets the instance of the logger.
|
||||
logSys = getLogger(__name__)
|
||||
|
||||
DEF_SYSLOGSOCKET = "auto"
|
||||
DEF_LOGLEVEL = "INFO"
|
||||
DEF_LOGTARGET = "STDOUT"
|
||||
|
||||
try:
|
||||
from .database import Fail2BanDb
|
||||
except ImportError: # pragma: no cover
|
||||
|
@ -49,6 +54,10 @@ except ImportError: # pragma: no cover
|
|||
Fail2BanDb = None
|
||||
|
||||
|
||||
def _thread_name():
|
||||
return threading.current_thread().__class__.__name__
|
||||
|
||||
|
||||
class Server:
|
||||
|
||||
def __init__(self, daemon = False):
|
||||
|
@ -67,11 +76,7 @@ class Server:
|
|||
'FreeBSD': '/var/run/log',
|
||||
'Linux': '/dev/log',
|
||||
}
|
||||
# todo: remove that, if test cases are fixed
|
||||
self.setSyslogSocket("auto")
|
||||
# Set logging level
|
||||
self.setLogLevel("INFO")
|
||||
self.setLogTarget("STDOUT")
|
||||
self.__prev_signals = {}
|
||||
|
||||
def __sigTERMhandler(self, signum, frame):
|
||||
logSys.debug("Caught signal %d. Exiting" % signum)
|
||||
|
@ -93,9 +98,12 @@ class Server:
|
|||
raise ServerInitializationError("Could not create daemon")
|
||||
|
||||
# Set all logging parameters (or use default if not specified):
|
||||
self.setSyslogSocket(conf.get("syslogsocket", self.__syslogSocket))
|
||||
self.setLogLevel(conf.get("loglevel", self.__logLevel))
|
||||
self.setLogTarget(conf.get("logtarget", self.__logTarget))
|
||||
self.setSyslogSocket(conf.get("syslogsocket",
|
||||
self.__syslogSocket if self.__syslogSocket is not None else DEF_SYSLOGSOCKET))
|
||||
self.setLogLevel(conf.get("loglevel",
|
||||
self.__logLevel if self.__logLevel is not None else DEF_LOGLEVEL))
|
||||
self.setLogTarget(conf.get("logtarget",
|
||||
self.__logTarget if self.__logTarget is not None else DEF_LOGTARGET))
|
||||
|
||||
logSys.info("-"*50)
|
||||
logSys.info("Starting Fail2ban v%s", version.version)
|
||||
|
@ -104,10 +112,10 @@ class Server:
|
|||
logSys.info("Daemon started")
|
||||
|
||||
# Install signal handlers
|
||||
signal.signal(signal.SIGTERM, self.__sigTERMhandler)
|
||||
signal.signal(signal.SIGINT, self.__sigTERMhandler)
|
||||
signal.signal(signal.SIGUSR1, self.__sigUSR1handler)
|
||||
|
||||
if _thread_name() == '_MainThread':
|
||||
for s in (signal.SIGTERM, signal.SIGINT, signal.SIGUSR1):
|
||||
self.__prev_signals[s] = signal.getsignal(s)
|
||||
signal.signal(s, self.__sigTERMhandler if s != signal.SIGUSR1 else self.__sigUSR1handler)
|
||||
# Ensure unhandled exceptions are logged
|
||||
sys.excepthook = excepthook
|
||||
|
||||
|
@ -150,6 +158,10 @@ class Server:
|
|||
with self.__loggingLock:
|
||||
logging.shutdown()
|
||||
|
||||
# Restore default signal handlers:
|
||||
for s, sh in self.__prev_signals.iteritems():
|
||||
signal.signal(s, sh)
|
||||
|
||||
def addJail(self, name, backend):
|
||||
self.__jails.add(name, backend, self.__db)
|
||||
if self.__db is not None:
|
||||
|
@ -395,8 +407,13 @@ class Server:
|
|||
|
||||
def setLogTarget(self, target):
|
||||
with self.__loggingLock:
|
||||
# don't set new handlers if already the same
|
||||
# or if "INHERITED" (foreground worker of the test cases, to prevent stop logging):
|
||||
if self.__logTarget == target:
|
||||
return True
|
||||
if target == "INHERITED":
|
||||
self.__logTarget = target
|
||||
return True
|
||||
# set a format which is simpler for console use
|
||||
formatter = logging.Formatter("%(asctime)s %(name)-24s[%(process)d]: %(levelname)-7s %(message)s")
|
||||
if target == "SYSLOG":
|
||||
|
@ -539,7 +556,10 @@ class Server:
|
|||
# We need to set this in the parent process, so it gets inherited by the
|
||||
# child process, and this makes sure that it is effect even if the parent
|
||||
# terminates quickly.
|
||||
signal.signal(signal.SIGHUP, signal.SIG_IGN)
|
||||
if _thread_name() == '_MainThread':
|
||||
for s in (signal.SIGHUP,):
|
||||
self.__prev_signals[s] = signal.getsignal(s)
|
||||
signal.signal(s, signal.SIG_IGN)
|
||||
|
||||
try:
|
||||
# Fork a child process so the parent can exit. This will return control
|
||||
|
|
|
@ -125,6 +125,7 @@ class Utils():
|
|||
timeout_expr = lambda: time.time() - stime <= timeout
|
||||
else:
|
||||
timeout_expr = timeout
|
||||
popen = None
|
||||
try:
|
||||
popen = subprocess.Popen(
|
||||
realCmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell,
|
||||
|
@ -151,7 +152,10 @@ class Utils():
|
|||
if retcode is None and not Utils.pid_exists(pgid):
|
||||
retcode = signal.SIGKILL
|
||||
except OSError as e:
|
||||
logSys.error("%s -- failed with %s" % (realCmd, e))
|
||||
stderr = "%s -- failed with %s" % (realCmd, e)
|
||||
logSys.error(stderr)
|
||||
if not popen:
|
||||
return False if not output else (False, stdout, stderr, retcode)
|
||||
|
||||
std_level = retcode == 0 and logging.DEBUG or logging.ERROR
|
||||
# if we need output (to return or to log it):
|
||||
|
|
|
@ -0,0 +1,411 @@
|
|||
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
|
||||
# vi: set ft=python sts=4 ts=4 sw=4 noet :
|
||||
|
||||
# This file is part of Fail2Ban.
|
||||
#
|
||||
# Fail2Ban is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation; either version 2 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Fail2Ban is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with Fail2Ban; if not, write to the Free Software
|
||||
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||
|
||||
# Fail2Ban developers
|
||||
|
||||
__author__ = "Serg Brester"
|
||||
__copyright__ = "Copyright (c) 2014- Serg G. Brester (sebres), 2008- Fail2Ban Contributors"
|
||||
__license__ = "GPL"
|
||||
|
||||
import fileinput
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from threading import Thread
|
||||
|
||||
from ..client import fail2banclient, fail2banserver, fail2bancmdline
|
||||
from ..client.fail2banclient import Fail2banClient, exec_command_line as _exec_client, VisualWait
|
||||
from ..client.fail2banserver import Fail2banServer, exec_command_line as _exec_server
|
||||
from .. import protocol
|
||||
from ..server import server
|
||||
from ..server.utils import Utils
|
||||
from .utils import LogCaptureTestCase, logSys, withtmpdir, shutil, logging
|
||||
|
||||
|
||||
STOCK_CONF_DIR = "config"
|
||||
STOCK = os.path.exists(os.path.join(STOCK_CONF_DIR,'fail2ban.conf'))
|
||||
TEST_CONF_DIR = os.path.join(os.path.dirname(__file__), "config")
|
||||
if STOCK:
|
||||
CONF_DIR = STOCK_CONF_DIR
|
||||
else:
|
||||
CONF_DIR = TEST_CONF_DIR
|
||||
|
||||
CLIENT = "fail2ban-client"
|
||||
SERVER = "fail2ban-server"
|
||||
BIN = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "bin")
|
||||
|
||||
MAX_WAITTIME = 10
|
||||
MAX_WAITTIME = unittest.F2B.maxWaitTime(MAX_WAITTIME)
|
||||
|
||||
##
|
||||
# Several wrappers and settings for proper testing:
|
||||
#
|
||||
|
||||
fail2banclient.MAX_WAITTIME = \
|
||||
fail2banserver.MAX_WAITTIME = MAX_WAITTIME
|
||||
|
||||
|
||||
fail2bancmdline.logSys = \
|
||||
fail2banclient.logSys = \
|
||||
fail2banserver.logSys = logSys
|
||||
|
||||
LOG_LEVEL = logSys.level
|
||||
|
||||
server.DEF_LOGTARGET = "/dev/null"
|
||||
|
||||
def _test_output(*args):
|
||||
logSys.info(args[0])
|
||||
fail2bancmdline.output = \
|
||||
fail2banclient.output = \
|
||||
fail2banserver.output = \
|
||||
protocol.output = _test_output
|
||||
|
||||
def _test_exit(code=0):
|
||||
logSys.debug("Exit with code %s", code)
|
||||
if code == 0:
|
||||
raise ExitException()
|
||||
else:
|
||||
raise FailExitException()
|
||||
fail2bancmdline.exit = \
|
||||
fail2banclient.exit = \
|
||||
fail2banserver.exit = _test_exit
|
||||
|
||||
INTERACT = []
|
||||
def _test_raw_input(*args):
|
||||
if len(INTERACT):
|
||||
#print('--- interact command: ', INTERACT[0])
|
||||
return INTERACT.pop(0)
|
||||
else:
|
||||
return "exit"
|
||||
fail2banclient.raw_input = _test_raw_input
|
||||
|
||||
# prevents change logging params, log capturing, etc:
|
||||
fail2bancmdline.PRODUCTION = False
|
||||
|
||||
|
||||
class ExitException(fail2bancmdline.ExitException):
|
||||
pass
|
||||
class FailExitException(fail2bancmdline.ExitException):
|
||||
pass
|
||||
|
||||
|
||||
def _out_file(fn): # pragma: no cover
|
||||
logSys.debug('---- ' + fn + ' ----')
|
||||
for line in fileinput.input(fn):
|
||||
line = line.rstrip('\n')
|
||||
logSys.debug(line)
|
||||
logSys.debug('-'*30)
|
||||
|
||||
def _start_params(tmp, use_stock=False, logtarget="/dev/null"):
|
||||
cfg = tmp+"/config"
|
||||
if use_stock and STOCK:
|
||||
# copy config:
|
||||
def ig_dirs(dir, files):
|
||||
return [f for f in files if not os.path.isfile(os.path.join(dir, f))]
|
||||
shutil.copytree(STOCK_CONF_DIR, cfg, ignore=ig_dirs)
|
||||
os.symlink(STOCK_CONF_DIR+"/action.d", cfg+"/action.d")
|
||||
os.symlink(STOCK_CONF_DIR+"/filter.d", cfg+"/filter.d")
|
||||
# replace fail2ban params (database with memory):
|
||||
r = re.compile(r'^dbfile\s*=')
|
||||
for line in fileinput.input(cfg+"/fail2ban.conf", inplace=True):
|
||||
line = line.rstrip('\n')
|
||||
if r.match(line):
|
||||
line = "dbfile = :memory:"
|
||||
print(line)
|
||||
# replace jail params (polling as backend to be fast in initialize):
|
||||
r = re.compile(r'^backend\s*=')
|
||||
for line in fileinput.input(cfg+"/jail.conf", inplace=True):
|
||||
line = line.rstrip('\n')
|
||||
if r.match(line):
|
||||
line = "backend = polling"
|
||||
print(line)
|
||||
else:
|
||||
# just empty config directory without anything (only fail2ban.conf/jail.conf):
|
||||
os.mkdir(cfg)
|
||||
f = open(cfg+"/fail2ban.conf", "wb")
|
||||
f.write('\n'.join((
|
||||
"[Definition]",
|
||||
"loglevel = INFO",
|
||||
"logtarget = " + logtarget,
|
||||
"syslogsocket = auto",
|
||||
"socket = "+tmp+"/f2b.sock",
|
||||
"pidfile = "+tmp+"/f2b.pid",
|
||||
"backend = polling",
|
||||
"dbfile = :memory:",
|
||||
"dbpurgeage = 1d",
|
||||
"",
|
||||
)))
|
||||
f.close()
|
||||
f = open(cfg+"/jail.conf", "wb")
|
||||
f.write('\n'.join((
|
||||
"[INCLUDES]", "",
|
||||
"[DEFAULT]", "",
|
||||
"",
|
||||
)))
|
||||
f.close()
|
||||
if LOG_LEVEL < logging.DEBUG: # if HEAVYDEBUG
|
||||
_out_file(cfg+"/fail2ban.conf")
|
||||
_out_file(cfg+"/jail.conf")
|
||||
# parameters:
|
||||
return ("-c", cfg,
|
||||
"--logtarget", logtarget, "--loglevel", "DEBUG", "--syslogsocket", "auto",
|
||||
"-s", tmp+"/f2b.sock", "-p", tmp+"/f2b.pid")
|
||||
|
||||
|
||||
class Fail2banClientTest(LogCaptureTestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Call before every test case."""
|
||||
LogCaptureTestCase.setUp(self)
|
||||
|
||||
def tearDown(self):
|
||||
"""Call after every test case."""
|
||||
LogCaptureTestCase.tearDown(self)
|
||||
|
||||
def testClientUsage(self):
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT, "-h",))
|
||||
self.assertLogged("Usage: " + CLIENT)
|
||||
self.assertLogged("Report bugs to ")
|
||||
|
||||
@withtmpdir
|
||||
def testClientStartBackgroundInside(self, tmp):
|
||||
# always add "--async" by start inside, should don't fork by async (not replace client with server, just start in new process)
|
||||
# (we can't fork the test cases process):
|
||||
startparams = _start_params(tmp, True)
|
||||
# start:
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT, "--async", "-b") + startparams + ("start",))
|
||||
self.assertLogged("Server ready")
|
||||
self.assertLogged("Exit with code 0")
|
||||
try:
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("echo", "TEST-ECHO",))
|
||||
self.assertRaises(FailExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("~~unknown~cmd~failed~~",))
|
||||
finally:
|
||||
self.pruneLog()
|
||||
# stop:
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("stop",))
|
||||
self.assertLogged("Shutdown successful")
|
||||
self.assertLogged("Exit with code 0")
|
||||
|
||||
@withtmpdir
|
||||
def testClientStartBackgroundCall(self, tmp):
|
||||
global INTERACT
|
||||
startparams = _start_params(tmp)
|
||||
# start (without async in new process):
|
||||
cmd = os.path.join(os.path.join(BIN), CLIENT)
|
||||
logSys.debug('Start %s ...', cmd)
|
||||
Utils.executeCmd((cmd,) + startparams + ("start",),
|
||||
timeout=MAX_WAITTIME, shell=False, output=False)
|
||||
self.pruneLog()
|
||||
try:
|
||||
# echo from client (inside):
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("echo", "TEST-ECHO",))
|
||||
self.assertLogged("TEST-ECHO")
|
||||
self.assertLogged("Exit with code 0")
|
||||
self.pruneLog()
|
||||
# interactive client chat with started server:
|
||||
INTERACT += [
|
||||
"echo INTERACT-ECHO",
|
||||
"status",
|
||||
"exit"
|
||||
]
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("-i",))
|
||||
self.assertLogged("INTERACT-ECHO")
|
||||
self.assertLogged("Status", "Number of jail:")
|
||||
self.assertLogged("Exit with code 0")
|
||||
self.pruneLog()
|
||||
# test reload and restart over interactive client:
|
||||
INTERACT += [
|
||||
"reload",
|
||||
"restart",
|
||||
"exit"
|
||||
]
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("-i",))
|
||||
self.assertLogged("Reading config files:")
|
||||
self.assertLogged("Shutdown successful")
|
||||
self.assertLogged("Server ready")
|
||||
self.assertLogged("Exit with code 0")
|
||||
self.pruneLog()
|
||||
finally:
|
||||
self.pruneLog()
|
||||
# stop:
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("stop",))
|
||||
self.assertLogged("Shutdown successful")
|
||||
self.assertLogged("Exit with code 0")
|
||||
|
||||
def _testClientStartForeground(self, tmp, startparams, phase):
|
||||
# start and wait to end (foreground):
|
||||
phase['start'] = True
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT, "-f") + startparams + ("start",))
|
||||
# end :
|
||||
phase['end'] = True
|
||||
|
||||
@withtmpdir
|
||||
def testClientStartForeground(self, tmp):
|
||||
# started directly here, so prevent overwrite test cases logger with "INHERITED"
|
||||
startparams = _start_params(tmp, logtarget="INHERITED")
|
||||
# because foreground block execution - start it in thread:
|
||||
phase = dict()
|
||||
Thread(name="_TestCaseWorker",
|
||||
target=Fail2banClientTest._testClientStartForeground, args=(self, tmp, startparams, phase)).start()
|
||||
try:
|
||||
# wait for start thread:
|
||||
Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME)
|
||||
self.assertTrue(phase.get('start', None))
|
||||
# wait for server (socket):
|
||||
Utils.wait_for(lambda: os.path.exists(tmp+"/f2b.sock"), MAX_WAITTIME)
|
||||
self.assertLogged("Starting communication")
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("ping",))
|
||||
self.assertRaises(FailExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("~~unknown~cmd~failed~~",))
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("echo", "TEST-ECHO",))
|
||||
finally:
|
||||
self.pruneLog()
|
||||
# stop:
|
||||
self.assertRaises(ExitException, _exec_client,
|
||||
(CLIENT,) + startparams + ("stop",))
|
||||
# wait for end:
|
||||
Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME)
|
||||
self.assertTrue(phase.get('end', None))
|
||||
self.assertLogged("Shutdown successful", "Exiting Fail2ban")
|
||||
self.assertLogged("Exit with code 0")
|
||||
|
||||
@withtmpdir
|
||||
def testClientFailStart(self, tmp):
|
||||
self.assertRaises(FailExitException, _exec_client,
|
||||
(CLIENT, "--async", "-c", tmp+"/miss", "start",))
|
||||
self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist")
|
||||
|
||||
self.assertRaises(FailExitException, _exec_client,
|
||||
(CLIENT, "--async", "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock", "start",))
|
||||
self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file")
|
||||
|
||||
def testVisualWait(self):
|
||||
sleeptime = 0.035
|
||||
for verbose in (2, 0):
|
||||
cntr = 15
|
||||
with VisualWait(verbose, 5) as vis:
|
||||
while cntr:
|
||||
vis.heartbeat()
|
||||
if verbose and not unittest.F2B.fast:
|
||||
time.sleep(sleeptime)
|
||||
cntr -= 1
|
||||
|
||||
|
||||
class Fail2banServerTest(LogCaptureTestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Call before every test case."""
|
||||
LogCaptureTestCase.setUp(self)
|
||||
|
||||
def tearDown(self):
|
||||
"""Call after every test case."""
|
||||
LogCaptureTestCase.tearDown(self)
|
||||
|
||||
def testServerUsage(self):
|
||||
self.assertRaises(ExitException, _exec_server,
|
||||
(SERVER, "-h",))
|
||||
self.assertLogged("Usage: " + SERVER)
|
||||
self.assertLogged("Report bugs to ")
|
||||
|
||||
@withtmpdir
|
||||
def testServerStartBackground(self, tmp):
|
||||
# don't add "--async" by start, because if will fork current process by daemonize
|
||||
# (we can't fork the test cases process),
|
||||
# because server started internal communication in new thread use INHERITED as logtarget here:
|
||||
startparams = _start_params(tmp, logtarget="INHERITED")
|
||||
# start:
|
||||
self.assertRaises(ExitException, _exec_server,
|
||||
(SERVER, "-b") + startparams)
|
||||
self.assertLogged("Server ready")
|
||||
self.assertLogged("Exit with code 0")
|
||||
try:
|
||||
self.assertRaises(ExitException, _exec_server,
|
||||
(SERVER,) + startparams + ("echo", "TEST-ECHO",))
|
||||
self.assertRaises(FailExitException, _exec_server,
|
||||
(SERVER,) + startparams + ("~~unknown~cmd~failed~~",))
|
||||
finally:
|
||||
self.pruneLog()
|
||||
# stop:
|
||||
self.assertRaises(ExitException, _exec_server,
|
||||
(SERVER,) + startparams + ("stop",))
|
||||
self.assertLogged("Shutdown successful")
|
||||
self.assertLogged("Exit with code 0")
|
||||
|
||||
def _testServerStartForeground(self, tmp, startparams, phase):
|
||||
# start and wait to end (foreground):
|
||||
phase['start'] = True
|
||||
self.assertRaises(ExitException, _exec_server,
|
||||
(SERVER, "-f") + startparams + ("start",))
|
||||
# end :
|
||||
phase['end'] = True
|
||||
@withtmpdir
|
||||
def testServerStartForeground(self, tmp):
|
||||
# started directly here, so prevent overwrite test cases logger with "INHERITED"
|
||||
startparams = _start_params(tmp, logtarget="INHERITED")
|
||||
# because foreground block execution - start it in thread:
|
||||
phase = dict()
|
||||
Thread(name="_TestCaseWorker",
|
||||
target=Fail2banServerTest._testServerStartForeground, args=(self, tmp, startparams, phase)).start()
|
||||
try:
|
||||
# wait for start thread:
|
||||
Utils.wait_for(lambda: phase.get('start', None) is not None, MAX_WAITTIME)
|
||||
self.assertTrue(phase.get('start', None))
|
||||
# wait for server (socket):
|
||||
Utils.wait_for(lambda: os.path.exists(tmp+"/f2b.sock"), MAX_WAITTIME)
|
||||
self.assertLogged("Starting communication")
|
||||
self.assertRaises(ExitException, _exec_server,
|
||||
(SERVER,) + startparams + ("ping",))
|
||||
self.assertRaises(FailExitException, _exec_server,
|
||||
(SERVER,) + startparams + ("~~unknown~cmd~failed~~",))
|
||||
self.assertRaises(ExitException, _exec_server,
|
||||
(SERVER,) + startparams + ("echo", "TEST-ECHO",))
|
||||
finally:
|
||||
self.pruneLog()
|
||||
# stop:
|
||||
self.assertRaises(ExitException, _exec_server,
|
||||
(SERVER,) + startparams + ("stop",))
|
||||
# wait for end:
|
||||
Utils.wait_for(lambda: phase.get('end', None) is not None, MAX_WAITTIME)
|
||||
self.assertTrue(phase.get('end', None))
|
||||
self.assertLogged("Shutdown successful", "Exiting Fail2ban")
|
||||
self.assertLogged("Exit with code 0")
|
||||
|
||||
@withtmpdir
|
||||
def testServerFailStart(self, tmp):
|
||||
self.assertRaises(FailExitException, _exec_server,
|
||||
(SERVER, "-c", tmp+"/miss",))
|
||||
self.assertLogged("Base configuration directory " + tmp+"/miss" + " does not exist")
|
||||
|
||||
self.assertRaises(FailExitException, _exec_server,
|
||||
(SERVER, "-c", CONF_DIR, "-s", tmp+"/miss/f2b.sock",))
|
||||
self.assertLogged("There is no directory " + tmp+"/miss" + " to contain the socket file")
|
|
@ -23,19 +23,7 @@ __author__ = "Serg Brester"
|
|||
__copyright__ = "Copyright (c) 2015 Serg G. Brester (sebres), 2008- Fail2Ban Contributors"
|
||||
__license__ = "GPL"
|
||||
|
||||
from __builtin__ import open as fopen
|
||||
import unittest
|
||||
import getpass
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
try:
|
||||
from systemd import journal
|
||||
except ImportError:
|
||||
journal = None
|
||||
|
||||
from ..client import fail2banregex
|
||||
from ..client.fail2banregex import Fail2banRegex, get_opt_parser, output
|
||||
|
|
|
@ -26,10 +26,14 @@ import logging
|
|||
import optparse
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from StringIO import StringIO
|
||||
from functools import wraps
|
||||
|
||||
from ..helpers import getLogger
|
||||
from ..server.filter import DNSUtils
|
||||
|
@ -71,6 +75,17 @@ class F2B(optparse.Values):
|
|||
return wtime
|
||||
|
||||
|
||||
def withtmpdir(f):
|
||||
@wraps(f)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
tmp = tempfile.mkdtemp(prefix="f2b-temp")
|
||||
try:
|
||||
return f(self, tmp, *args, **kwargs)
|
||||
finally:
|
||||
# clean up
|
||||
shutil.rmtree(tmp)
|
||||
return wrapper
|
||||
|
||||
def initTests(opts):
|
||||
unittest.F2B = F2B(opts)
|
||||
# --fast :
|
||||
|
@ -145,6 +160,7 @@ def gatherTests(regexps=None, opts=None):
|
|||
from . import misctestcase
|
||||
from . import databasetestcase
|
||||
from . import samplestestcase
|
||||
from . import fail2banclienttestcase
|
||||
from . import fail2banregextestcase
|
||||
|
||||
if not regexps: # pragma: no cover
|
||||
|
@ -223,6 +239,9 @@ def gatherTests(regexps=None, opts=None):
|
|||
# Filter Regex tests with sample logs
|
||||
tests.addTest(unittest.makeSuite(samplestestcase.FilterSamplesRegex))
|
||||
|
||||
# bin/fail2ban-client, bin/fail2ban-server
|
||||
tests.addTest(unittest.makeSuite(fail2banclienttestcase.Fail2banClientTest))
|
||||
tests.addTest(unittest.makeSuite(fail2banclienttestcase.Fail2banServerTest))
|
||||
# bin/fail2ban-regex
|
||||
tests.addTest(unittest.makeSuite(fail2banregextestcase.Fail2banRegexTest))
|
||||
|
||||
|
@ -293,8 +312,11 @@ class LogCaptureTestCase(unittest.TestCase):
|
|||
# Let's log everything into a string
|
||||
self._log = StringIO()
|
||||
logSys.handlers = [logging.StreamHandler(self._log)]
|
||||
if self._old_level <= logging.DEBUG:
|
||||
print("")
|
||||
if self._old_level < logging.DEBUG: # so if HEAVYDEBUG etc -- show them!
|
||||
logSys.handlers += self._old_handlers
|
||||
logSys.debug('--'*40)
|
||||
logSys.setLevel(getattr(logging, 'DEBUG'))
|
||||
|
||||
def tearDown(self):
|
||||
|
@ -340,6 +362,9 @@ class LogCaptureTestCase(unittest.TestCase):
|
|||
raise AssertionError("All of the %r were found present in the log: %r" % (s, logged))
|
||||
|
||||
|
||||
def pruneLog(self):
|
||||
self._log.truncate(0)
|
||||
|
||||
def getLog(self):
|
||||
return self._log.getvalue()
|
||||
|
||||
|
|
Loading…
Reference in New Issue