- performance of fail2ban optimized

-- cache dnsToIp, ipToName to prevent long wait during retrieving of ip/name for wrong dns or lazy dns-system;
   -- instead of simple "sleep" used conditional wait "wait_for", that internal increases sleep interval up to sleeptime;
   -- ticket / banmanager / failmanager modules are performance optimized;
   -- api of filter (log files), jail, etc. rewritten and extended for performance purposes;
- performance of test cases optimized:
   -- added option "--fast" to decrease wait intervals, avoid passive waiting, and skip few very slow test cases;
- code review after partially cherry pick of branch 'ban-time-incr' (see gh-716)
   -- ticket module prepared to easy merge with newest version of 'ban-time-incr', now additionally holds banTime, banCount and json-data;
   -- executeCmd partially moved from action to new module utils, etc.
   -- python 2.6 compatibility;
- testExecuteTimeoutWithNastyChildren: test case repaired - wait for pid file inside bash, kill tree in any case (gh-1155);
- testSocket: test case repaired - wait for server thread starts a socket (listener)
pull/1346/head
sebres 2015-07-15 14:58:00 +02:00
parent 3540619a73
commit 59bf5013c0
29 changed files with 768 additions and 390 deletions

View File

@ -180,7 +180,6 @@ fail2ban/server/banmanager.py
fail2ban/server/database.py fail2ban/server/database.py
fail2ban/server/datedetector.py fail2ban/server/datedetector.py
fail2ban/server/datetemplate.py fail2ban/server/datetemplate.py
fail2ban/server/faildata.py
fail2ban/server/failmanager.py fail2ban/server/failmanager.py
fail2ban/server/failregex.py fail2ban/server/failregex.py
fail2ban/server/filter.py fail2ban/server/filter.py
@ -197,6 +196,7 @@ fail2ban/server/server.py
fail2ban/server/strptime.py fail2ban/server/strptime.py
fail2ban/server/ticket.py fail2ban/server/ticket.py
fail2ban/server/transmitter.py fail2ban/server/transmitter.py
fail2ban/server/utils.py
fail2ban/tests/__init__.py fail2ban/tests/__init__.py
fail2ban/tests/action_d/__init__.py fail2ban/tests/action_d/__init__.py
fail2ban/tests/action_d/test_badips.py fail2ban/tests/action_d/test_badips.py

View File

@ -58,6 +58,9 @@ def get_opt_parser():
Option('-n', "--no-network", action="store_true", Option('-n', "--no-network", action="store_true",
dest="no_network", dest="no_network",
help="Do not run tests that require the network"), help="Do not run tests that require the network"),
Option('-f', "--fast", action="store_true",
dest="fast",
help="Try to increase speed of the tests, decreasing of wait intervals, memory database"),
Option("-t", "--log-traceback", action='store_true', Option("-t", "--log-traceback", action='store_true',
help="Enrich log-messages with compressed tracebacks"), help="Enrich log-messages with compressed tracebacks"),
Option("--full-traceback", action='store_true', Option("--full-traceback", action='store_true',
@ -120,7 +123,7 @@ if not opts.log_level or opts.log_level != 'critical': # pragma: no cover
print("Fail2ban %s test suite. Python %s. Please wait..." \ print("Fail2ban %s test suite. Python %s. Please wait..." \
% (version, str(sys.version).replace('\n', ''))) % (version, str(sys.version).replace('\n', '')))
tests = gatherTests(regexps, opts.no_network) tests = gatherTests(regexps, opts)
# #
# Run the tests # Run the tests
# #

View File

@ -10,7 +10,6 @@ fail2ban.server package
fail2ban.server.database fail2ban.server.database
fail2ban.server.datedetector fail2ban.server.datedetector
fail2ban.server.datetemplate fail2ban.server.datetemplate
fail2ban.server.faildata
fail2ban.server.failmanager fail2ban.server.failmanager
fail2ban.server.failregex fail2ban.server.failregex
fail2ban.server.filter fail2ban.server.filter
@ -26,3 +25,4 @@ fail2ban.server package
fail2ban.server.strptime fail2ban.server.strptime
fail2ban.server.ticket fail2ban.server.ticket
fail2ban.server.transmitter fail2ban.server.transmitter
fail2ban.server.utils

View File

@ -1,7 +1,7 @@
fail2ban.server.faildata module fail2ban.server.utils module
=============================== ===============================
.. automodule:: fail2ban.server.faildata .. automodule:: fail2ban.server.utils
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:

View File

@ -32,6 +32,7 @@ import time
from abc import ABCMeta from abc import ABCMeta
from collections import MutableMapping from collections import MutableMapping
from .utils import Utils
from ..helpers import getLogger from ..helpers import getLogger
# Gets the instance of the logger. # Gets the instance of the logger.
@ -40,21 +41,6 @@ logSys = getLogger(__name__)
# Create a lock for running system commands # Create a lock for running system commands
_cmd_lock = threading.Lock() _cmd_lock = threading.Lock()
# Some hints on common abnormal exit codes
_RETCODE_HINTS = {
127: '"Command not found". Make sure that all commands in %(realCmd)r '
'are in the PATH of fail2ban-server process '
'(grep -a PATH= /proc/`pidof -x fail2ban-server`/environ). '
'You may want to start '
'"fail2ban-server -f" separately, initiate it with '
'"fail2ban-client reload" in another shell session and observe if '
'additional informative error messages appear in the terminals.'
}
# Dictionary to lookup signal name from number
signame = dict((num, name)
for name, num in signal.__dict__.iteritems() if name.startswith("SIG"))
class CallingMap(MutableMapping): class CallingMap(MutableMapping):
"""A Mapping type which returns the result of callable values. """A Mapping type which returns the result of callable values.
@ -561,61 +547,6 @@ class CommandAction(ActionBase):
_cmd_lock.acquire() _cmd_lock.acquire()
try: try:
retcode = None # to guarantee being defined upon early except return Utils.executeCmd(realCmd, timeout, shell=True, output=False)
stdout = tempfile.TemporaryFile(suffix=".stdout", prefix="fai2ban_")
stderr = tempfile.TemporaryFile(suffix=".stderr", prefix="fai2ban_")
popen = subprocess.Popen(
realCmd, stdout=stdout, stderr=stderr, shell=True,
preexec_fn=os.setsid # so that killpg does not kill our process
)
stime = time.time()
retcode = popen.poll()
while time.time() - stime <= timeout and retcode is None:
time.sleep(0.1)
retcode = popen.poll()
if retcode is None:
logSys.error("%s -- timed out after %i seconds." %
(realCmd, timeout))
pgid = os.getpgid(popen.pid)
os.killpg(pgid, signal.SIGTERM) # Terminate the process
time.sleep(0.1)
retcode = popen.poll()
if retcode is None: # Still going...
os.killpg(pgid, signal.SIGKILL) # Kill the process
time.sleep(0.1)
retcode = popen.poll()
except OSError as e:
logSys.error("%s -- failed with %s" % (realCmd, e))
finally: finally:
_cmd_lock.release() _cmd_lock.release()
std_level = retcode == 0 and logging.DEBUG or logging.ERROR
if std_level >= logSys.getEffectiveLevel():
stdout.seek(0); msg = stdout.read()
if msg != '':
logSys.log(std_level, "%s -- stdout: %r", realCmd, msg)
stderr.seek(0); msg = stderr.read()
if msg != '':
logSys.log(std_level, "%s -- stderr: %r", realCmd, msg)
stdout.close()
stderr.close()
if retcode == 0:
logSys.debug("%s -- returned successfully" % realCmd)
return True
elif retcode is None:
logSys.error("%s -- unable to kill PID %i" % (realCmd, popen.pid))
elif retcode < 0 or retcode > 128:
# dash would return negative while bash 128 + n
sigcode = -retcode if retcode < 0 else retcode - 128
logSys.error("%s -- killed with %s (return code: %s)" %
(realCmd, signame.get(sigcode, "signal %i" % sigcode), retcode))
else:
msg = _RETCODE_HINTS.get(retcode, None)
logSys.error("%s -- returned %i" % (realCmd, retcode))
if msg:
logSys.info("HINT on %i: %s"
% (retcode, msg % locals()))
return False

View File

@ -42,6 +42,7 @@ from .banmanager import BanManager
from .jailthread import JailThread from .jailthread import JailThread
from .action import ActionBase, CommandAction, CallingMap from .action import ActionBase, CommandAction, CallingMap
from .mytime import MyTime from .mytime import MyTime
from .utils import Utils
from ..helpers import getLogger from ..helpers import getLogger
# Gets the instance of the logger. # Gets the instance of the logger.
@ -225,14 +226,11 @@ class Actions(JailThread, Mapping):
self._jail.name, name, e, self._jail.name, name, e,
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
while self.active: while self.active:
if not self.idle: if self.idle:
#logSys.debug(self._jail.name + ": action") time.sleep(self.sleeptime)
ret = self.__checkBan() continue
if not ret: if not Utils.wait_for(self.__checkBan, self.sleeptime):
self.__checkUnBan() self.__checkUnBan()
time.sleep(self.sleeptime)
else:
time.sleep(self.sleeptime)
self.__flushBan() self.__flushBan()
actions = self._actions.items() actions = self._actions.items()

View File

@ -247,12 +247,10 @@ class BanManager:
@staticmethod @staticmethod
def createBanTicket(ticket): def createBanTicket(ticket):
ip = ticket.getIP() # we should always use correct time to calculate correct end time (ban time is variable now,
#lastTime = ticket.getTime() # + possible double banning by restore from database and from log file)
lastTime = MyTime.time() # so use as lastTime always time from ticket.
banTicket = BanTicket(ip, lastTime, ticket.getMatches()) return BanTicket(ticket=ticket)
banTicket.setAttempt(ticket.getAttempt())
return banTicket
## ##
# Add a ban ticket. # Add a ban ticket.
@ -264,11 +262,25 @@ class BanManager:
def addBanTicket(self, ticket): def addBanTicket(self, ticket):
try: try:
self.__lock.acquire() self.__lock.acquire()
if not self._inBanList(ticket): # check already banned
for i in self.__banList:
if ticket.getIP() == i.getIP():
# if already permanent
btorg, torg = i.getBanTime(self.__banTime), i.getTime()
if btorg == -1:
return False
# if given time is less than already banned time
btnew, tnew = ticket.getBanTime(self.__banTime), ticket.getTime()
if btnew != -1 and tnew + btnew <= torg + btorg:
return False
# we have longest ban - set new (increment) ban time
i.setTime(tnew)
i.setBanTime(btnew)
return False
# not yet banned - add new
self.__banList.append(ticket) self.__banList.append(ticket)
self.__banTotal += 1 self.__banTotal += 1
return True return True
return False
finally: finally:
self.__lock.release() self.__lock.release()
@ -313,8 +325,7 @@ class BanManager:
return list() return list()
# Gets the list of ticket to remove. # Gets the list of ticket to remove.
unBanList = [ticket for ticket in self.__banList unBanList = [ticket for ticket in self.__banList if ticket.isTimedOut(time, self.__banTime)]
if ticket.getTime() < time - self.__banTime]
# Removes tickets. # Removes tickets.
self.__banList = [ticket for ticket in self.__banList self.__banList = [ticket for ticket in self.__banList

View File

@ -418,8 +418,7 @@ class Fail2BanDb(object):
cur.execute( cur.execute(
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)", "INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
(jail.name, ticket.getIP(), int(round(ticket.getTime())), (jail.name, ticket.getIP(), int(round(ticket.getTime())),
{"matches": ticket.getMatches(), ticket.getData()))
"failures": ticket.getAttempt()}))
@commitandrollback @commitandrollback
def delBan(self, cur, jail, ip): def delBan(self, cur, jail, ip):
@ -477,8 +476,8 @@ class Fail2BanDb(object):
tickets = [] tickets = []
for ip, timeofban, data in self._getBans(**kwargs): for ip, timeofban, data in self._getBans(**kwargs):
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
tickets.append(FailTicket(ip, timeofban, data.get('matches'))) tickets.append(FailTicket(ip, timeofban))
tickets[-1].setAttempt(data.get('failures', 1)) tickets[-1].setData(data)
return tickets return tickets
def getBansMerged(self, ip=None, jail=None, bantime=None): def getBansMerged(self, ip=None, jail=None, bantime=None):
@ -520,6 +519,7 @@ class Fail2BanDb(object):
prev_banip = results[0][0] prev_banip = results[0][0]
matches = [] matches = []
failures = 0 failures = 0
tickdata = {}
for banip, timeofban, data in results: for banip, timeofban, data in results:
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
if banip != prev_banip: if banip != prev_banip:
@ -530,11 +530,14 @@ class Fail2BanDb(object):
prev_banip = banip prev_banip = banip
matches = [] matches = []
failures = 0 failures = 0
matches.extend(data.get('matches', [])) tickdata = {}
matches.extend(data.get('matches', ()))
failures += data.get('failures', 1) failures += data.get('failures', 1)
tickdata.update(data.get('data', {}))
prev_timeofban = timeofban prev_timeofban = timeofban
ticket = FailTicket(banip, prev_timeofban, matches) ticket = FailTicket(banip, prev_timeofban, matches)
ticket.setAttempt(failures) ticket.setAttempt(failures)
ticket.setData(**tickdata)
tickets.append(ticket) tickets.append(ticket)
if cacheKey: if cacheKey:

View File

@ -1,71 +0,0 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
# vi: set ft=python sts=4 ts=4 sw=4 noet :
# This file is part of Fail2Ban.
#
# Fail2Ban is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Fail2Ban is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Fail2Ban; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Author: Cyril Jaquier
#
__author__ = "Cyril Jaquier"
__copyright__ = "Copyright (c) 2004 Cyril Jaquier"
__license__ = "GPL"
from ..helpers import getLogger
# Gets the instance of the logger.
logSys = getLogger(__name__)
class FailData:
def __init__(self):
self.__retry = 0
self.__lastTime = 0
self.__lastReset = 0
self.__matches = []
def setRetry(self, value):
self.__retry = value
# keep only the last matches or reset entirely
# Explicit if/else for compatibility with Python 2.4
if value:
self.__matches = self.__matches[-min(len(self.__matches, value)):]
else:
self.__matches = []
def getRetry(self):
return self.__retry
def getMatches(self):
return self.__matches
def inc(self, matches=None):
self.__retry += 1
self.__matches += matches or []
def setLastTime(self, value):
if value > self.__lastTime:
self.__lastTime = value
def getLastTime(self):
return self.__lastTime
def getLastReset(self):
return self.__lastReset
def setLastReset(self, value):
self.__lastReset = value

View File

@ -27,7 +27,6 @@ __license__ = "GPL"
from threading import Lock from threading import Lock
import logging import logging
from .faildata import FailData
from .ticket import FailTicket from .ticket import FailTicket
from ..helpers import getLogger from ..helpers import getLogger
@ -86,26 +85,35 @@ class FailManager:
finally: finally:
self.__lock.release() self.__lock.release()
def addFailure(self, ticket): def addFailure(self, ticket, count=1):
attempts = 1
try: try:
self.__lock.acquire() self.__lock.acquire()
ip = ticket.getIP() ip = ticket.getIP()
unixTime = ticket.getTime()
matches = ticket.getMatches()
if ip in self.__failList: if ip in self.__failList:
fData = self.__failList[ip] fData = self.__failList[ip]
# if the same object:
if fData is ticket:
matches = None
else:
matches = ticket.getMatches()
unixTime = ticket.getTime()
if fData.getLastReset() < unixTime - self.__maxTime: if fData.getLastReset() < unixTime - self.__maxTime:
fData.setLastReset(unixTime) fData.setLastReset(unixTime)
fData.setRetry(0) fData.setRetry(0)
fData.inc(matches) fData.inc(matches, 1, count)
fData.setLastTime(unixTime) fData.setLastTime(unixTime)
else: else:
fData = FailData() # if already FailTicket - add it direct, otherwise create (using copy all ticket data):
fData.inc(matches) if isinstance(ticket, FailTicket):
fData.setLastReset(unixTime) fData = ticket;
fData.setLastTime(unixTime) else:
fData = FailTicket(ticket=ticket)
if count > ticket.getAttempt():
fData.setRetry(count)
self.__failList[ip] = fData self.__failList[ip] = fData
attempts = fData.getRetry()
self.__failTotal += 1 self.__failTotal += 1
if logSys.getEffectiveLevel() <= logging.DEBUG: if logSys.getEffectiveLevel() <= logging.DEBUG:
@ -118,6 +126,7 @@ class FailManager:
% (self.__failTotal, len(self.__failList), failures_summary)) % (self.__failTotal, len(self.__failList), failures_summary))
finally: finally:
self.__lock.release() self.__lock.release()
return attempts
def size(self): def size(self):
try: try:
@ -140,17 +149,14 @@ class FailManager:
if ip in self.__failList: if ip in self.__failList:
del self.__failList[ip] del self.__failList[ip]
def toBan(self): def toBan(self, ip=None):
try: try:
self.__lock.acquire() self.__lock.acquire()
for ip in self.__failList: for ip in ([ip] if ip != None and ip in self.__failList else self.__failList):
data = self.__failList[ip] data = self.__failList[ip]
if data.getRetry() >= self.__maxRetry: if data.getRetry() >= self.__maxRetry:
self.__delFailure(ip) del self.__failList[ip]
# Create a FailTicket from BanData return data
failTicket = FailTicket(ip, data.getLastTime(), data.getMatches())
failTicket.setAttempt(data.getRetry())
return failTicket
raise FailManagerEmpty raise FailManagerEmpty
finally: finally:
self.__lock.release() self.__lock.release()

View File

@ -22,6 +22,7 @@ __copyright__ = "Copyright (c) 2004 Cyril Jaquier, 2011-2013 Yaroslav Halchenko"
__license__ = "GPL" __license__ = "GPL"
import codecs import codecs
import datetime
import fcntl import fcntl
import locale import locale
import logging import logging
@ -316,13 +317,12 @@ class Filter(JailThread):
logSys.warning('Requested to manually ban an ignored IP %s. User knows best. Proceeding to ban it.' % ip) logSys.warning('Requested to manually ban an ignored IP %s. User knows best. Proceeding to ban it.' % ip)
unixTime = MyTime.time() unixTime = MyTime.time()
for i in xrange(self.failManager.getMaxRetry()): self.failManager.addFailure(FailTicket(ip, unixTime), self.failManager.getMaxRetry())
self.failManager.addFailure(FailTicket(ip, unixTime))
# Perform the banning of the IP now. # Perform the banning of the IP now.
try: # pragma: no branch - exception is the only way out try: # pragma: no branch - exception is the only way out
while True: while True:
ticket = self.failManager.toBan() ticket = self.failManager.toBan(ip)
self.jail.putFailTicket(ticket) self.jail.putFailTicket(ticket)
except FailManagerEmpty: except FailManagerEmpty:
self.failManager.cleanup(MyTime.time()) self.failManager.cleanup(MyTime.time())
@ -427,17 +427,19 @@ class Filter(JailThread):
ip = element[1] ip = element[1]
unixTime = element[2] unixTime = element[2]
lines = element[3] lines = element[3]
logSys.debug("Processing line with time:%s and ip:%s" logSys.debug("Processing line with time:%s and ip:%s",
% (unixTime, ip)) unixTime, ip)
if unixTime < MyTime.time() - self.getFindTime(): if unixTime < MyTime.time() - self.getFindTime():
logSys.debug("Ignore line since time %s < %s - %s" logSys.debug("Ignore line since time %s < %s - %s",
% (unixTime, MyTime.time(), self.getFindTime())) unixTime, MyTime.time(), self.getFindTime())
break break
if self.inIgnoreIPList(ip, log_ignore=True): if self.inIgnoreIPList(ip, log_ignore=True):
continue continue
logSys.info("[%s] Found %s" % (self.jail.name, ip)) logSys.info(
## print "D: Adding a ticket for %s" % ((ip, unixTime, [line]),) "[%s] Found %s - %s", self.jail.name, ip, datetime.datetime.fromtimestamp(unixTime).strftime("%Y-%m-%d %H:%M:%S")
self.failManager.addFailure(FailTicket(ip, unixTime, lines)) )
tick = FailTicket(ip, unixTime, lines)
self.failManager.addFailure(tick)
## ##
# Returns true if the line should be ignored. # Returns true if the line should be ignored.
@ -606,6 +608,14 @@ class FileFilter(Filter):
# to be overridden by backends # to be overridden by backends
pass pass
##
# Get the log file names
#
# @return log paths
def getLogPaths(self):
return self.__logs.keys()
## ##
# Get the log containers # Get the log containers
# #
@ -614,6 +624,14 @@ class FileFilter(Filter):
def getLogs(self): def getLogs(self):
return self.__logs.values() return self.__logs.values()
##
# Get the count of log containers
#
# @return count of log containers
def getLogCount(self):
return len(self.__logs)
## ##
# Check whether path is already monitored. # Check whether path is already monitored.
# #
@ -941,32 +959,50 @@ class JournalFilter(Filter): # pragma: systemd no cover
import socket import socket
import struct import struct
from .utils import Utils
class DNSUtils: class DNSUtils:
IP_CRE = re.compile("^(?:\d{1,3}\.){3}\d{1,3}$") IP_CRE = re.compile("^(?:\d{1,3}\.){3}\d{1,3}$")
# todo: make configurable the expired time and max count of cache entries:
CACHE_dnsToIp = Utils.Cache(maxCount=1000, maxTime=60*60)
CACHE_ipToName = Utils.Cache(maxCount=1000, maxTime=60*60)
@staticmethod @staticmethod
def dnsToIp(dns): def dnsToIp(dns):
""" Convert a DNS into an IP address using the Python socket module. """ Convert a DNS into an IP address using the Python socket module.
Thanks to Kevin Drapel. Thanks to Kevin Drapel.
""" """
# cache, also prevent long wait during retrieving of ip for wrong dns or lazy dns-system:
v = DNSUtils.CACHE_dnsToIp.get(dns)
if v is not None:
return v
# retrieve ip (todo: use AF_INET6 for IPv6) # retrieve ip (todo: use AF_INET6 for IPv6)
try: try:
return set([i[4][0] for i in socket.getaddrinfo(dns, None, socket.AF_INET, 0, socket.IPPROTO_TCP)]) v = set([i[4][0] for i in socket.getaddrinfo(dns, None, socket.AF_INET, 0, socket.IPPROTO_TCP)])
except socket.error, e: except socket.error, e:
logSys.warning("Unable to find a corresponding IP address for %s: %s" # todo: make configurable the expired time of cache entry:
% (dns, e)) logSys.warning("Unable to find a corresponding IP address for %s: %s", dns, e)
return list() v = list()
DNSUtils.CACHE_dnsToIp.set(dns, v)
return v
@staticmethod @staticmethod
def ipToName(ip): def ipToName(ip):
# cache, also prevent long wait during retrieving of name for wrong addresses, lazy dns:
v = DNSUtils.CACHE_ipToName.get(ip)
if v is not None:
return v
# retrieve name
try: try:
return socket.gethostbyaddr(ip)[0] v = socket.gethostbyaddr(ip)[0]
except socket.error, e: except socket.error, e:
logSys.debug("Unable to find a name for the IP %s: %s" % (ip, e)) logSys.debug("Unable to find a name for the IP %s: %s", ip, e)
return None v = None
DNSUtils.CACHE_ipToName.set(ip, v)
return v
@staticmethod @staticmethod
def searchIP(text): def searchIP(text):

View File

@ -31,6 +31,7 @@ import gamin
from .failmanager import FailManagerEmpty from .failmanager import FailManagerEmpty
from .filter import FileFilter from .filter import FileFilter
from .mytime import MyTime from .mytime import MyTime
from .utils import Utils
from ..helpers import getLogger from ..helpers import getLogger
# Gets the instance of the logger. # Gets the instance of the logger.
@ -102,6 +103,15 @@ class FilterGamin(FileFilter):
def _delLogPath(self, path): def _delLogPath(self, path):
self.monitor.stop_watch(path) self.monitor.stop_watch(path)
def _handleEvents(self):
ret = False
mon = self.monitor
while mon and mon.event_pending():
mon.handle_events()
mon = self.monitor
ret = True
return ret
## ##
# Main loop. # Main loop.
# #
@ -112,12 +122,10 @@ class FilterGamin(FileFilter):
def run(self): def run(self):
# Gamin needs a loop to collect and dispatch events # Gamin needs a loop to collect and dispatch events
while self.active: while self.active:
if not self.idle: if self.idle:
# We cannot block here because we want to be able to
# exit.
if self.monitor.event_pending():
self.monitor.handle_events()
time.sleep(self.sleeptime) time.sleep(self.sleeptime)
continue
Utils.wait_for(self._handleEvents, self.sleeptime)
logSys.debug(self.jail.name + ": filter terminated") logSys.debug(self.jail.name + ": filter terminated")
return True return True
@ -129,6 +137,6 @@ class FilterGamin(FileFilter):
# Desallocates the resources used by Gamin. # Desallocates the resources used by Gamin.
def __cleanup(self): def __cleanup(self):
for log in self.getLogs(): for filename in self.getLogPaths():
self.monitor.stop_watch(log.getFileName()) self.monitor.stop_watch(filename)
del self.monitor self.monitor = None

View File

@ -31,6 +31,7 @@ from .failmanager import FailManagerEmpty
from .filter import FileFilter from .filter import FileFilter
from .mytime import MyTime from .mytime import MyTime
from ..helpers import getLogger from ..helpers import getLogger
from ..server.utils import Utils
# Gets the instance of the logger. # Gets the instance of the logger.
logSys = getLogger(__name__) logSys = getLogger(__name__)
@ -78,6 +79,15 @@ class FilterPoll(FileFilter):
del self.__prevStats[path] del self.__prevStats[path]
del self.__file404Cnt[path] del self.__file404Cnt[path]
##
# Get a modified log path at once
#
def getModified(self, modlst):
for filename in self.getLogPaths():
if self.isModified(filename):
modlst.append(filename)
return modlst
## ##
# Main loop. # Main loop.
# #
@ -89,12 +99,16 @@ class FilterPoll(FileFilter):
while self.active: while self.active:
if logSys.getEffectiveLevel() <= 6: if logSys.getEffectiveLevel() <= 6:
logSys.log(6, "Woke up idle=%s with %d files monitored", logSys.log(6, "Woke up idle=%s with %d files monitored",
self.idle, len(self.getLogs())) self.idle, self.getLogCount())
if not self.idle: if self.idle:
if not Utils.wait_for(lambda: not self.idle,
self.sleeptime * 100, self.sleeptime
):
continue
# Get file modification # Get file modification
for container in self.getLogs(): modlst = []
filename = container.getFileName() Utils.wait_for(lambda: self.getModified(modlst), self.sleeptime)
if self.isModified(filename): for filename in modlst:
# set start time as now - find time for first usage only (prevent performance bug with polling of big files) # set start time as now - find time for first usage only (prevent performance bug with polling of big files)
self.getFailures(filename, self.getFailures(filename,
(MyTime.time() - self.getFindTime()) if not self.__initial.get(filename) else None (MyTime.time() - self.getFindTime()) if not self.__initial.get(filename) else None
@ -111,9 +125,6 @@ class FilterPoll(FileFilter):
self.failManager.cleanup(MyTime.time()) self.failManager.cleanup(MyTime.time())
self.dateDetector.sortTemplate() self.dateDetector.sortTemplate()
self.__modified = False self.__modified = False
time.sleep(self.sleeptime)
else:
time.sleep(self.sleeptime)
logSys.debug( logSys.debug(
(self.jail is not None and self.jail.name or "jailless") + (self.jail is not None and self.jail.name or "jailless") +
" filter terminated") " filter terminated")
@ -129,7 +140,7 @@ class FilterPoll(FileFilter):
try: try:
logStats = os.stat(filename) logStats = os.stat(filename)
stats = logStats.st_mtime, logStats.st_ino, logStats.st_size stats = logStats.st_mtime, logStats.st_ino, logStats.st_size
pstats = self.__prevStats[filename] pstats = self.__prevStats.get(filename, ())
self.__file404Cnt[filename] = 0 self.__file404Cnt[filename] = 0
if logSys.getEffectiveLevel() <= 7: if logSys.getEffectiveLevel() <= 7:
# we do not want to waste time on strftime etc if not necessary # we do not want to waste time on strftime etc if not necessary
@ -139,7 +150,6 @@ class FilterPoll(FileFilter):
# os.system("stat %s | grep Modify" % filename) # os.system("stat %s | grep Modify" % filename)
if pstats == stats: if pstats == stats:
return False return False
else:
logSys.debug("%s has been modified", filename) logSys.debug("%s has been modified", filename)
self.__prevStats[filename] = stats self.__prevStats[filename] = stats
return True return True

View File

@ -28,6 +28,7 @@ import sys
from threading import Thread from threading import Thread
from abc import abstractmethod from abc import abstractmethod
from .utils import Utils
from ..helpers import excepthook from ..helpers import excepthook
@ -48,14 +49,14 @@ class JailThread(Thread):
The time the thread sleeps for in the loop. The time the thread sleeps for in the loop.
""" """
def __init__(self): def __init__(self, name=None):
super(JailThread, self).__init__() super(JailThread, self).__init__(name=name)
## Control the state of the thread. ## Control the state of the thread.
self.active = False self.active = False
## Control the idle state of the thread. ## Control the idle state of the thread.
self.idle = False self.idle = False
## The time the thread sleeps in the loop. ## The time the thread sleeps in the loop.
self.sleeptime = 1 self.sleeptime = Utils.DEFAULT_SLEEP_TIME
# excepthook workaround for threads, derived from: # excepthook workaround for threads, derived from:
# http://bugs.python.org/issue1230540#msg91244 # http://bugs.python.org/issue1230540#msg91244

View File

@ -211,8 +211,7 @@ class Server:
def getLogPath(self, name): def getLogPath(self, name):
filter_ = self.__jails[name].filter filter_ = self.__jails[name].filter
if isinstance(filter_, FileFilter): if isinstance(filter_, FileFilter):
return [m.getFileName() return filter_.getLogPaths()
for m in filter_.getLogs()]
else: # pragma: systemd no cover else: # pragma: systemd no cover
logSys.info("Jail %s is not a FileFilter instance" % name) logSys.info("Jail %s is not a FileFilter instance" % name)
return [] return []
@ -324,6 +323,15 @@ class Server:
def getBanTime(self, name): def getBanTime(self, name):
return self.__jails[name].actions.getBanTime() return self.__jails[name].actions.getBanTime()
def is_alive(self, jailnum=None):
if jailnum is not None and len(self.__jails) != jailnum:
return 0
for j in self.__jails:
j = self.__jails[j]
if not j.is_alive():
return 0
return 1
# Status # Status
def status(self): def status(self):
try: try:

View File

@ -24,7 +24,10 @@ __author__ = "Cyril Jaquier"
__copyright__ = "Copyright (c) 2004 Cyril Jaquier" __copyright__ = "Copyright (c) 2004 Cyril Jaquier"
__license__ = "GPL" __license__ = "GPL"
import sys
from ..helpers import getLogger from ..helpers import getLogger
from .mytime import MyTime
# Gets the instance of the logger. # Gets the instance of the logger.
logSys = getLogger(__name__) logSys = getLogger(__name__)
@ -32,7 +35,7 @@ logSys = getLogger(__name__)
class Ticket: class Ticket:
def __init__(self, ip, time, matches=None): def __init__(self, ip=None, time=None, matches=None, ticket=None):
"""Ticket constructor """Ticket constructor
@param ip the IP address @param ip the IP address
@ -41,14 +44,21 @@ class Ticket:
""" """
self.setIP(ip) self.setIP(ip)
self.__time = time self._flags = 0;
self.__attempt = 0 self._banCount = 0;
self.__file = None self._banTime = None;
self.__matches = matches or [] self._time = time if time is not None else MyTime.time()
self._data = {'matches': [], 'failures': 0}
if ticket:
# ticket available - copy whole information from ticket:
self.__dict__.update(i for i in ticket.__dict__.iteritems() if i[0] in self.__dict__)
else:
self._data['matches'] = matches or []
def __str__(self): def __str__(self):
return "%s: ip=%s time=%s #attempts=%d matches=%r" % \ return "%s: ip=%s time=%s #attempts=%d matches=%r" % \
(self.__class__.__name__.split('.')[-1], self.__ip, self.__time, self.__attempt, self.__matches) (self.__class__.__name__.split('.')[-1], self.__ip, self._time,
self._data['failures'], self._data.get('matches', []))
def __repr__(self): def __repr__(self):
return str(self) return str(self)
@ -56,9 +66,8 @@ class Ticket:
def __eq__(self, other): def __eq__(self, other):
try: try:
return self.__ip == other.__ip and \ return self.__ip == other.__ip and \
round(self.__time, 2) == round(other.__time, 2) and \ round(self._time, 2) == round(other._time, 2) and \
self.__attempt == other.__attempt and \ self._data == other._data
self.__matches == other.__matches
except AttributeError: except AttributeError:
return False return False
@ -72,24 +81,128 @@ class Ticket:
return self.__ip return self.__ip
def setTime(self, value): def setTime(self, value):
self.__time = value self._time = value
def getTime(self): def getTime(self):
return self.__time return self._time
def setBanTime(self, value):
self._banTime = value;
def getBanTime(self, defaultBT = None):
return (self._banTime if not self._banTime is None else defaultBT);
def setBanCount(self, value):
self._banCount = value;
def incrBanCount(self, value = 1):
self._banCount += value;
def getBanCount(self):
return self._banCount;
def isTimedOut(self, time, defaultBT = None):
bantime = (self._banTime if not self._banTime is None else defaultBT);
# permanent
if bantime == -1:
return False
# timed out
return (time > self._time + bantime)
def setAttempt(self, value): def setAttempt(self, value):
self.__attempt = value self._data['failures'] = value
def getAttempt(self): def getAttempt(self):
return self.__attempt return self._data['failures']
def setMatches(self, matches):
self._data['matches'] = matches or []
def getMatches(self): def getMatches(self):
return self.__matches return self._data.get('matches', [])
def setData(self, *args, **argv):
# if overwrite - set data and filter None values:
if len(args) == 1:
# todo: if support >= 2.7 only:
# self._data = {k:v for k,v in args[0].iteritems() if v is not None}
self._data = dict([(k,v) for k,v in args[0].iteritems() if v is not None])
# add k,v list or dict (merge):
elif len(args) == 2:
self._data.update((args,))
elif len(args) > 2:
self._data.update((k,v) for k,v in zip(*[iter(args)]*2))
if len(argv):
self._data.update(argv)
# filter (delete) None values:
# todo: if support >= 2.7 only:
# self._data = {k:v for k,v in self._data.iteritems() if v is not None}
self._data = dict([(k,v) for k,v in self._data.iteritems() if v is not None])
def getData(self, key=None, default=None):
# return whole data dict:
if key is None:
return self._data
# return default if not exists:
if not self._data:
return default
# return filtered by lambda/function:
if callable(key):
# todo: if support >= 2.7 only:
# return {k:v for k,v in self._data.iteritems() if key(k)}
return dict([(k,v) for k,v in self._data.iteritems() if key(k)])
# return filtered by keys:
if hasattr(key, '__iter__'):
# todo: if support >= 2.7 only:
# return {k:v for k,v in self._data.iteritems() if k in key}
return dict([(k,v) for k,v in self._data.iteritems() if k in key])
# return single value of data:
return self._data.get(key, default)
class FailTicket(Ticket): class FailTicket(Ticket):
pass
def __init__(self, ip=None, time=None, matches=None, ticket=None):
# this class variables:
self.__retry = 0
self.__lastReset = None
# create/copy using default ticket constructor:
Ticket.__init__(self, ip, time, matches, ticket)
# init:
if ticket is None:
self.__lastReset = time if time is not None else self.getTime()
if not self.__retry:
self.__retry = self._data['failures'];
def setRetry(self, value):
self.__retry = value
if not self._data['failures']:
self._data['failures'] = 1
if not value:
self._data['failures'] = 0
self._data['matches'] = []
def getRetry(self):
return max(self.__retry, self._data['failures'])
def inc(self, matches=None, attempt=1, count=1):
self.__retry += count
self._data['failures'] += attempt
if matches:
self._data['matches'] += matches
def setLastTime(self, value):
if value > self._time:
self._time = value
def getLastTime(self):
return self._time
def getLastReset(self):
return self.__lastReset
def setLastReset(self, value):
self.__lastReset = value
## ##
# Ban Ticket. # Ban Ticket.

View File

@ -95,7 +95,7 @@ class Transmitter:
return None return None
elif command[0] == "sleep": elif command[0] == "sleep":
value = command[1] value = command[1]
time.sleep(int(value)) time.sleep(float(value))
return None return None
elif command[0] == "flushlogs": elif command[0] == "flushlogs":
return self.__server.flushLogs() return self.__server.flushLogs()

242
fail2ban/server/utils.py Normal file
View File

@ -0,0 +1,242 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
# vi: set ft=python sts=4 ts=4 sw=4 noet :
# This file is part of Fail2Ban.
#
# Fail2Ban is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Fail2Ban is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Fail2Ban; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
__author__ = "Serg G. Brester (sebres) and Fail2Ban Contributors"
__copyright__ = "Copyright (c) 2004 Cyril Jaquier, 2011-2012 Yaroslav Halchenko, 2012-2015 Serg G. Brester"
__license__ = "GPL"
import logging, os, fcntl, subprocess, time, signal
from ..helpers import getLogger
# Gets the instance of the logger.
logSys = getLogger(__name__)
# Some hints on common abnormal exit codes
_RETCODE_HINTS = {
127: '"Command not found". Make sure that all commands in %(realCmd)r '
'are in the PATH of fail2ban-server process '
'(grep -a PATH= /proc/`pidof -x fail2ban-server`/environ). '
'You may want to start '
'"fail2ban-server -f" separately, initiate it with '
'"fail2ban-client reload" in another shell session and observe if '
'additional informative error messages appear in the terminals.'
}
# Dictionary to lookup signal name from number
signame = dict((num, name)
for name, num in signal.__dict__.iteritems() if name.startswith("SIG"))
class Utils():
"""Utilities provide diverse static methods like executes OS shell commands, etc.
"""
DEFAULT_SLEEP_TIME = 0.1
DEFAULT_SLEEP_INTERVAL = 0.01
class Cache(dict):
def __init__(self, maxCount=1000, maxTime=60*60):
self.maxCount = maxCount
self.maxTime = maxTime
def get(self, k, defv=None):
v = dict.get(self, k)
if v:
if v[1] > time.time():
return v[0]
del self[k]
return defv
def set(self, k, v):
t = time.time()
# clean cache if max count reached:
if len(self) >= self.maxCount:
for (ck,cv) in self.items():
if cv[1] < t:
del self[ck]
# if still max count - remove any one:
if len(self) >= self.maxCount:
self.popitem()
self[k] = (v, t + self.maxTime)
@staticmethod
def setFBlockMode(fhandle, value):
flags = fcntl.fcntl(fhandle, fcntl.F_GETFL)
if not value:
flags |= os.O_NONBLOCK
else:
flags &= ~os.O_NONBLOCK
fcntl.fcntl(fhandle, fcntl.F_SETFL, flags)
return flags
@staticmethod
def executeCmd(realCmd, timeout=60, shell=True, output=False, tout_kill_tree=True):
"""Executes a command.
Parameters
----------
realCmd : str
The command to execute.
timeout : int
The time out in seconds for the command.
shell : bool
If shell is True (default), the specified command (may be a string) will be
executed through the shell.
output : bool
If output is True, the function returns tuple (success, stdoutdata, stderrdata, returncode)
Returns
-------
bool
True if the command succeeded.
Raises
------
OSError
If command fails to be executed.
RuntimeError
If command execution times out.
"""
stdout = stderr = None
retcode = None
if not callable(timeout):
stime = time.time()
timeout_expr = lambda: time.time() - stime <= timeout
else:
timeout_expr = timeout
try:
popen = subprocess.Popen(
realCmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell,
preexec_fn=os.setsid # so that killpg does not kill our process
)
retcode = popen.poll()
while retcode is None and timeout_expr():
time.sleep(Utils.DEFAULT_SLEEP_INTERVAL)
retcode = popen.poll()
if retcode is None:
logSys.error("%s -- timed out after %s seconds." %
(realCmd, timeout))
pgid = os.getpgid(popen.pid)
# if not tree - first try to terminate and then kill, otherwise - kill (-9) only:
os.killpg(pgid, signal.SIGTERM) # Terminate the process
time.sleep(Utils.DEFAULT_SLEEP_INTERVAL)
retcode = popen.poll()
#logSys.debug("%s -- terminated %s ", realCmd, retcode)
if retcode is None or tout_kill_tree: # Still going...
os.killpg(pgid, signal.SIGKILL) # Kill the process
time.sleep(Utils.DEFAULT_SLEEP_INTERVAL)
retcode = popen.poll()
#logSys.debug("%s -- killed %s ", realCmd, retcode)
if retcode is None and not Utils.pid_exists(pgid):
retcode = signal.SIGKILL
except OSError as e:
logSys.error("%s -- failed with %s" % (realCmd, e))
std_level = retcode == 0 and logging.DEBUG or logging.ERROR
# if we need output (to return or to log it):
if output or std_level >= logSys.getEffectiveLevel():
# if was timeouted (killed/terminated) - to prevent waiting, set std handles to non-blocking mode.
if popen.stdout:
try:
if retcode < 0:
Utils.setFBlockMode(popen.stdout, False)
stdout = popen.stdout.read()
except IOError as e:
logSys.error(" ... -- failed to read stdout %s", e)
if stdout is not None and stdout != '':
logSys.log(std_level, "%s -- stdout: %r", realCmd, stdout)
popen.stdout.close()
if popen.stderr:
try:
if retcode < 0:
Utils.setFBlockMode(popen.stderr, False)
stderr = popen.stderr.read()
except IOError as e:
logSys.error(" ... -- failed to read stderr %s", e)
if stderr is not None and stderr != '':
logSys.log(std_level, "%s -- stderr: %r", realCmd, stderr)
popen.stderr.close()
if retcode == 0:
logSys.debug("%s -- returned successfully", realCmd)
return True if not output else (True, stdout, stderr, retcode)
elif retcode is None:
logSys.error("%s -- unable to kill PID %i" % (realCmd, popen.pid))
elif retcode < 0 or retcode > 128:
# dash would return negative while bash 128 + n
sigcode = -retcode if retcode < 0 else retcode - 128
logSys.error("%s -- killed with %s (return code: %s)" %
(realCmd, signame.get(sigcode, "signal %i" % sigcode), retcode))
else:
msg = _RETCODE_HINTS.get(retcode, None)
logSys.error("%s -- returned %i" % (realCmd, retcode))
if msg:
logSys.info("HINT on %i: %s", retcode, msg % locals())
return False if not output else (False, stdout, stderr, retcode)
@staticmethod
def wait_for(cond, timeout, interval=None):
"""Wait until condition expression `cond` is True, up to `timeout` sec
"""
ini = 1
while True:
ret = cond()
if ret:
return ret
if ini:
ini = stm = 0
time0 = time.time() + timeout
if not interval:
interval = Utils.DEFAULT_SLEEP_INTERVAL
if time.time() > time0:
break
stm = min(stm + interval, Utils.DEFAULT_SLEEP_TIME)
time.sleep(stm)
return ret
# Solution from http://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid
# under cc by-sa 3.0
if os.name == 'posix':
@staticmethod
def pid_exists(pid):
"""Check whether pid exists in the current process table."""
import errno
if pid < 0:
return False
try:
os.kill(pid, 0)
except OSError as e:
return e.errno == errno.EPERM
else:
return True
else:
@staticmethod
def pid_exists(pid):
import ctypes
kernel32 = ctypes.windll.kernel32
SYNCHRONIZE = 0x100000
process = kernel32.OpenProcess(SYNCHRONIZE, 0, pid)
if process != 0:
kernel32.CloseHandle(process)
return True
else:
return False

View File

@ -29,6 +29,8 @@ if sys.version_info >= (2,7):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
self.jail = DummyJail() self.jail = DummyJail()
self.jail.actions.add("test") self.jail.actions.add("test")

View File

@ -46,6 +46,8 @@ class SMTPActionTest(unittest.TestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
self.jail = DummyJail() self.jail = DummyJail()
pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py") pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py")
pythonModuleName = os.path.basename(pythonModule.rstrip(".py")) pythonModuleName = os.path.basename(pythonModule.rstrip(".py"))

View File

@ -30,6 +30,7 @@ import tempfile
from ..server.actions import Actions from ..server.actions import Actions
from ..server.ticket import FailTicket from ..server.ticket import FailTicket
from ..server.utils import Utils
from .dummyjail import DummyJail from .dummyjail import DummyJail
from .utils import LogCaptureTestCase from .utils import LogCaptureTestCase
@ -81,8 +82,7 @@ class ExecuteActions(LogCaptureTestCase):
self.defaultActions() self.defaultActions()
self.__actions.start() self.__actions.start()
with open(self.__tmpfilename) as f: with open(self.__tmpfilename) as f:
time.sleep(3) self.assertTrue( Utils.wait_for(lambda: (f.read() == "ip start 64\n"), 3) )
self.assertEqual(f.read(),"ip start 64\n")
self.__actions.stop() self.__actions.stop()
self.__actions.join() self.__actions.join()
@ -97,8 +97,7 @@ class ExecuteActions(LogCaptureTestCase):
self.assertLogged("TestAction initialised") self.assertLogged("TestAction initialised")
self.__actions.start() self.__actions.start()
time.sleep(3) self.assertTrue( Utils.wait_for(lambda: self._is_logged("TestAction action start"), 3) )
self.assertLogged("TestAction action start")
self.__actions.stop() self.__actions.stop()
self.__actions.join() self.__actions.join()
@ -135,8 +134,7 @@ class ExecuteActions(LogCaptureTestCase):
"action.d/action_errors.py"), "action.d/action_errors.py"),
{}) {})
self.__actions.start() self.__actions.start()
time.sleep(3) self.assertTrue( Utils.wait_for(lambda: self._is_logged("Failed to start"), 3) )
self.assertLogged("Failed to start")
self.__actions.stop() self.__actions.stop()
self.__actions.join() self.__actions.join()
self.assertLogged("Failed to stop") self.assertLogged("Failed to stop")

View File

@ -25,10 +25,12 @@ __copyright__ = "Copyright (c) 2004 Cyril Jaquier"
__license__ = "GPL" __license__ = "GPL"
import os import os
import time
import tempfile import tempfile
import time
import unittest
from ..server.action import CommandAction, CallingMap from ..server.action import CommandAction, CallingMap
from ..server.utils import Utils
from .utils import LogCaptureTestCase from .utils import LogCaptureTestCase
from .utils import pid_exists from .utils import pid_exists
@ -194,16 +196,17 @@ class CommandActionTest(LogCaptureTestCase):
self.assertLogged('HINT on 127: "Command not found"') self.assertLogged('HINT on 127: "Command not found"')
def testExecuteTimeout(self): def testExecuteTimeout(self):
unittest.F2B.SkipIfFast()
stime = time.time() stime = time.time()
# Should take a minute # Should take a minute
self.assertFalse(CommandAction.executeCmd('sleep 60', timeout=2)) self.assertFalse(CommandAction.executeCmd('sleep 30', timeout=1))
# give a test still 1 second, because system could be too busy # give a test still 1 second, because system could be too busy
self.assertTrue(time.time() >= stime + 2 and time.time() <= stime + 3) self.assertTrue(time.time() >= stime + 1 and time.time() <= stime + 2)
self.assertLogged( self.assertLogged(
'sleep 60 -- timed out after 2 seconds', 'sleep 30 -- timed out after 1 seconds',
'sleep 60 -- timed out after 3 seconds' 'sleep 30 -- timed out after 2 seconds'
) )
self.assertLogged('sleep 60 -- killed with SIGTERM') self.assertLogged('sleep 30 -- killed with SIGTERM')
def testExecuteTimeoutWithNastyChildren(self): def testExecuteTimeoutWithNastyChildren(self):
# temporary file for a nasty kid shell script # temporary file for a nasty kid shell script
@ -215,29 +218,53 @@ class CommandActionTest(LogCaptureTestCase):
echo "$$" > %s.pid echo "$$" > %s.pid
echo "my pid $$ . sleeping lo-o-o-ong" echo "my pid $$ . sleeping lo-o-o-ong"
sleep 10000 sleep 30
""" % tmpFilename) """ % tmpFilename)
stime = 0
# timeout as long as pid-file was not created, but max 5 seconds
def getnasty_tout():
return (
getnastypid() is None
and time.time() - stime <= 5
)
def getnastypid(): def getnastypid():
cpid = None
if os.path.isfile(tmpFilename + '.pid'):
with open(tmpFilename + '.pid') as f: with open(tmpFilename + '.pid') as f:
return int(f.read()) try:
cpid = int(f.read())
except ValueError:
pass
return cpid
# First test if can kill the bastard # First test if can kill the bastard
stime = time.time()
self.assertFalse(CommandAction.executeCmd( self.assertFalse(CommandAction.executeCmd(
'bash %s' % tmpFilename, timeout=.1)) 'bash %s' % tmpFilename, timeout=getnasty_tout))
# Wait up to 3 seconds, the child got killed
cpid = getnastypid()
# Verify that the process itself got killed # Verify that the process itself got killed
self.assertFalse(pid_exists(getnastypid())) # process should have been killed self.assertTrue(Utils.wait_for(lambda: not pid_exists(cpid), 3)) # process should have been killed
self.assertLogged('my pid ')
self.assertLogged('timed out') self.assertLogged('timed out')
self.assertLogged('killed with SIGTERM') self.assertLogged('killed with SIGTERM',
'killed with SIGKILL')
os.unlink(tmpFilename + '.pid')
# A bit evolved case even though, previous test already tests killing children processes # A bit evolved case even though, previous test already tests killing children processes
stime = time.time()
self.assertFalse(CommandAction.executeCmd( self.assertFalse(CommandAction.executeCmd(
'out=`bash %s`; echo ALRIGHT' % tmpFilename, timeout=.2)) 'out=`bash %s`; echo ALRIGHT' % tmpFilename, timeout=getnasty_tout))
# Wait up to 3 seconds, the child got killed
cpid = getnastypid()
# Verify that the process itself got killed # Verify that the process itself got killed
self.assertFalse(pid_exists(getnastypid())) self.assertTrue(Utils.wait_for(lambda: not pid_exists(cpid), 3))
self.assertLogged('my pid ')
self.assertLogged('timed out') self.assertLogged('timed out')
self.assertLogged('killed with SIGTERM') self.assertLogged('killed with SIGTERM',
'killed with SIGKILL')
os.unlink(tmpFilename) os.unlink(tmpFilename)
os.unlink(tmpFilename + '.pid') os.unlink(tmpFilename + '.pid')

View File

@ -60,6 +60,7 @@ class AddFailure(unittest.TestCase):
class StatusExtendedCymruInfo(unittest.TestCase): class StatusExtendedCymruInfo(unittest.TestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
self.__ban_ip = "93.184.216.34" self.__ban_ip = "93.184.216.34"
self.__asn = "15133" self.__asn = "15133"
self.__country = "EU" self.__country = "EU"

View File

@ -35,7 +35,12 @@ from ..server.ticket import FailTicket
from ..server.actions import Actions from ..server.actions import Actions
from .dummyjail import DummyJail from .dummyjail import DummyJail
try: try:
from ..server.database import Fail2BanDb from ..server.database import Fail2BanDb as Fail2BanDb
# because of tests performance use memory instead of file:
def TestFail2BanDb(filename):
if unittest.F2B.fast:
return Fail2BanDb(':memory:')
return Fail2BanDb(filename)
except ImportError: except ImportError:
Fail2BanDb = None Fail2BanDb = None
from .utils import LogCaptureTestCase from .utils import LogCaptureTestCase
@ -55,7 +60,7 @@ class DatabaseTest(LogCaptureTestCase):
elif Fail2BanDb is None: elif Fail2BanDb is None:
return return
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_") _, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
self.db = Fail2BanDb(self.dbFilename) self.db = TestFail2BanDb(self.dbFilename)
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
@ -66,7 +71,7 @@ class DatabaseTest(LogCaptureTestCase):
os.remove(self.dbFilename) os.remove(self.dbFilename)
def testGetFilename(self): def testGetFilename(self):
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None or self.db.filename == ':memory:': # pragma: no cover
return return
self.assertEqual(self.dbFilename, self.db.filename) self.assertEqual(self.dbFilename, self.db.filename)
@ -88,7 +93,7 @@ class DatabaseTest(LogCaptureTestCase):
"/this/path/should/not/exist") "/this/path/should/not/exist")
def testCreateAndReconnect(self): def testCreateAndReconnect(self):
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None or self.db.filename == ':memory:': # pragma: no cover
return return
self.testAddJail() self.testAddJail()
# Reconnect... # Reconnect...

View File

@ -39,28 +39,27 @@ class DummyJail(Jail, object):
self.__actions = Actions(self) self.__actions = Actions(self)
def __len__(self): def __len__(self):
try: with self.lock:
self.lock.acquire()
return len(self.queue) return len(self.queue)
finally:
self.lock.release() def isEmpty(self):
with self.lock:
return not self.queue
def isFilled(self):
with self.lock:
return bool(self.queue)
def putFailTicket(self, ticket): def putFailTicket(self, ticket):
try: with self.lock:
self.lock.acquire()
self.queue.append(ticket) self.queue.append(ticket)
finally:
self.lock.release()
def getFailTicket(self): def getFailTicket(self):
try: with self.lock:
self.lock.acquire()
try: try:
return self.queue.pop() return self.queue.pop()
except IndexError: except IndexError:
return False return False
finally:
self.lock.release()
@property @property
def name(self): def name(self):

View File

@ -41,6 +41,7 @@ from ..server.filterpoll import FilterPoll
from ..server.filter import Filter, FileFilter, FileContainer, DNSUtils from ..server.filter import Filter, FileFilter, FileContainer, DNSUtils
from ..server.failmanager import FailManagerEmpty from ..server.failmanager import FailManagerEmpty
from ..server.mytime import MyTime from ..server.mytime import MyTime
from ..server.utils import Utils
from .utils import setUpMyTime, tearDownMyTime, mtimesleep, LogCaptureTestCase from .utils import setUpMyTime, tearDownMyTime, mtimesleep, LogCaptureTestCase
from .dummyjail import DummyJail from .dummyjail import DummyJail
@ -162,7 +163,7 @@ def _copy_lines_between_files(in_, fout, n=None, skip=0, mode='a', terminal_line
# Opened earlier, therefore must close it # Opened earlier, therefore must close it
fin.close() fin.close()
# to give other threads possibly some time to crunch # to give other threads possibly some time to crunch
time.sleep(0.1) time.sleep(Utils.DEFAULT_SLEEP_INTERVAL)
return fout return fout
@ -299,6 +300,11 @@ class IgnoreIP(LogCaptureTestCase):
class IgnoreIPDNS(IgnoreIP): class IgnoreIPDNS(IgnoreIP):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
IgnoreIP.setUp(self)
def testIgnoreIPDNSOK(self): def testIgnoreIPDNSOK(self):
self.filter.addIgnoreIP("www.epfl.ch") self.filter.addIgnoreIP("www.epfl.ch")
self.assertTrue(self.filter.inIgnoreIPList("128.178.50.12")) self.assertTrue(self.filter.inIgnoreIPList("128.178.50.12"))
@ -425,16 +431,11 @@ class LogFileMonitor(LogCaptureTestCase):
def isModified(self, delay=2.): def isModified(self, delay=2.):
"""Wait up to `delay` sec to assure that it was modified or not """Wait up to `delay` sec to assure that it was modified or not
""" """
time0 = time.time() return Utils.wait_for(lambda: self.filter.isModified(self.name), delay)
while time.time() < time0 + delay:
if self.filter.isModified(self.name):
return True
time.sleep(0.1)
return False
def notModified(self): def notModified(self):
# shorter wait time for not modified status # shorter wait time for not modified status
return not self.isModified(0.4) return not self.isModified(4*Utils.DEFAULT_SLEEP_TIME)
def testUnaccessibleLogFile(self): def testUnaccessibleLogFile(self):
os.chmod(self.name, 0) os.chmod(self.name, 0)
@ -583,26 +584,21 @@ def get_monitor_failures_testcase(Filter_):
#time.sleep(0.2) # Give FS time to ack the removal #time.sleep(0.2) # Give FS time to ack the removal
pass pass
def isFilled(self, delay=2.): def isFilled(self, delay=1.):
"""Wait up to `delay` sec to assure that it was modified or not """Wait up to `delay` sec to assure that it was modified or not
""" """
time0 = time.time() return Utils.wait_for(lambda: self.jail.isFilled(), delay)
while time.time() < time0 + delay:
if len(self.jail):
return True
time.sleep(0.1)
return False
def _sleep_4_poll(self): def _sleep_4_poll(self):
# Since FilterPoll relies on time stamps and some # Since FilterPoll relies on time stamps and some
# actions might be happening too fast in the tests, # actions might be happening too fast in the tests,
# sleep a bit to guarantee reliable time stamps # sleep a bit to guarantee reliable time stamps
if isinstance(self.filter, FilterPoll): if isinstance(self.filter, FilterPoll):
mtimesleep() Utils.wait_for(lambda: self.filter.is_alive(), 4*Utils.DEFAULT_SLEEP_TIME)
def isEmpty(self, delay=0.4): def isEmpty(self, delay=4*Utils.DEFAULT_SLEEP_TIME):
# shorter wait time for not modified status # shorter wait time for not modified status
return not self.isFilled(delay) return Utils.wait_for(lambda: self.jail.isEmpty(), delay)
def assert_correct_last_attempt(self, failures, count=None): def assert_correct_last_attempt(self, failures, count=None):
self.assertTrue(self.isFilled(20)) # give Filter a chance to react self.assertTrue(self.isFilled(20)) # give Filter a chance to react
@ -657,10 +653,11 @@ def get_monitor_failures_testcase(Filter_):
self.file = _copy_lines_between_files(GetFailures.FILENAME_01, self.name, self.file = _copy_lines_between_files(GetFailures.FILENAME_01, self.name,
n=14, mode='w') n=14, mode='w')
# Poll might need more time # Poll might need more time
self.assertTrue(self.isEmpty(4 + int(isinstance(self.filter, FilterPoll))*2), self.assertTrue(self.isEmpty(min(4, 100 * Utils.DEFAULT_SLEEP_TIME)),
"Queue must be empty but it is not: %s." "Queue must be empty but it is not: %s."
% (', '.join([str(x) for x in self.jail.queue]))) % (', '.join([str(x) for x in self.jail.queue])))
self.assertRaises(FailManagerEmpty, self.filter.failManager.toBan) self.assertRaises(FailManagerEmpty, self.filter.failManager.toBan)
Utils.wait_for(lambda: self.filter.failManager.getFailTotal() == 2, 50 * Utils.DEFAULT_SLEEP_TIME)
self.assertEqual(self.filter.failManager.getFailTotal(), 2) self.assertEqual(self.filter.failManager.getFailTotal(), 2)
# move aside, but leaving the handle still open... # move aside, but leaving the handle still open...
@ -685,7 +682,7 @@ def get_monitor_failures_testcase(Filter_):
if interim_kill: if interim_kill:
_killfile(None, self.name) _killfile(None, self.name)
time.sleep(0.2) # let them know time.sleep(Utils.DEFAULT_SLEEP_TIME) # let them know
# now create a new one to override old one # now create a new one to override old one
_copy_lines_between_files(GetFailures.FILENAME_01, self.name + '.new', _copy_lines_between_files(GetFailures.FILENAME_01, self.name + '.new',
@ -732,7 +729,7 @@ def get_monitor_failures_testcase(Filter_):
_copy_lines_between_files(GetFailures.FILENAME_01, self.file, n=100) _copy_lines_between_files(GetFailures.FILENAME_01, self.file, n=100)
# so we should get no more failures detected # so we should get no more failures detected
self.assertTrue(self.isEmpty(2)) self.assertTrue(self.isEmpty(200 * Utils.DEFAULT_SLEEP_TIME))
# but then if we add it back again # but then if we add it back again
self.filter.addLogPath(self.name) self.filter.addLogPath(self.name)
@ -789,19 +786,14 @@ def get_monitor_failures_journal_testcase(Filter_): # pragma: systemd no cover
return "MonitorJournalFailures%s(%s)" \ return "MonitorJournalFailures%s(%s)" \
% (Filter_, hasattr(self, 'name') and self.name or 'tempfile') % (Filter_, hasattr(self, 'name') and self.name or 'tempfile')
def isFilled(self, delay=2.): def isFilled(self, delay=1.):
"""Wait up to `delay` sec to assure that it was modified or not """Wait up to `delay` sec to assure that it was modified or not
""" """
time0 = time.time() return Utils.wait_for(lambda: self.jail.isFilled(), delay)
while time.time() < time0 + delay:
if len(self.jail):
return True
time.sleep(0.1)
return False
def isEmpty(self, delay=0.4): def isEmpty(self, delay=4*Utils.DEFAULT_SLEEP_TIME):
# shorter wait time for not modified status # shorter wait time for not modified status
return not self.isFilled(delay) return Utils.wait_for(lambda: self.jail.isEmpty(), delay)
def assert_correct_ban(self, test_ip, test_attempts): def assert_correct_ban(self, test_ip, test_attempts):
self.assertTrue(self.isFilled(10)) # give Filter a chance to react self.assertTrue(self.isFilled(10)) # give Filter a chance to react
@ -860,7 +852,7 @@ def get_monitor_failures_journal_testcase(Filter_): # pragma: systemd no cover
_copy_lines_to_journal( _copy_lines_to_journal(
self.test_file, self.journal_fields, n=5, skip=5) self.test_file, self.journal_fields, n=5, skip=5)
# so we should get no more failures detected # so we should get no more failures detected
self.assertTrue(self.isEmpty(2)) self.assertTrue(self.isEmpty(200 * Utils.DEFAULT_SLEEP_TIME))
# but then if we add it back again # but then if we add it back again
self.filter.addJournalMatch([ self.filter.addJournalMatch([
@ -905,6 +897,16 @@ class GetFailures(LogCaptureTestCase):
tearDownMyTime() tearDownMyTime()
LogCaptureTestCase.tearDown(self) LogCaptureTestCase.tearDown(self)
def testFilterAPI(self):
self.assertEqual(self.filter.getLogs(), [])
self.assertEqual(self.filter.getLogCount(), 0)
self.filter.addLogPath(GetFailures.FILENAME_01, tail=True)
self.assertEqual(self.filter.getLogCount(), 1)
self.assertEqual(self.filter.getLogPaths(), [GetFailures.FILENAME_01])
self.filter.addLogPath(GetFailures.FILENAME_02, tail=True)
self.assertEqual(self.filter.getLogCount(), 2)
self.assertEqual(sorted(self.filter.getLogPaths()), sorted([GetFailures.FILENAME_01, GetFailures.FILENAME_02]))
def testTail(self): def testTail(self):
# There must be no containters registered, otherwise [-1] indexing would be wrong # There must be no containters registered, otherwise [-1] indexing would be wrong
self.assertEqual(self.filter.getLogs(), []) self.assertEqual(self.filter.getLogs(), [])
@ -1025,6 +1027,7 @@ class GetFailures(LogCaptureTestCase):
_killfile(fout, fname) _killfile(fout, fname)
def testGetFailuresUseDNS(self): def testGetFailuresUseDNS(self):
unittest.F2B.SkipIfNoNetwork()
# We should still catch failures with usedns = no ;-) # We should still catch failures with usedns = no ;-)
output_yes = ('93.184.216.34', 2, 1124013539.0, output_yes = ('93.184.216.34', 2, 1124013539.0,
[u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2', [u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2',
@ -1126,6 +1129,10 @@ class GetFailures(LogCaptureTestCase):
class DNSUtilsTests(unittest.TestCase): class DNSUtilsTests(unittest.TestCase):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
def testUseDns(self): def testUseDns(self):
res = DNSUtils.textToIp('www.example.com', 'no') res = DNSUtils.textToIp('www.example.com', 'no')
self.assertEqual(res, []) self.assertEqual(res, [])
@ -1151,6 +1158,7 @@ class DNSUtilsTests(unittest.TestCase):
def testIpToName(self): def testIpToName(self):
res = DNSUtils.ipToName('8.8.4.4') res = DNSUtils.ipToName('8.8.4.4')
self.assertEqual(res, 'google-public-dns-b.google.com') self.assertEqual(res, 'google-public-dns-b.google.com')
unittest.F2B.SkipIfNoNetwork()
# invalid ip (TEST-NET-1 according to RFC 5737) # invalid ip (TEST-NET-1 according to RFC 5737)
res = DNSUtils.ipToName('192.0.2.0') res = DNSUtils.ipToName('192.0.2.0')
self.assertEqual(res, None) self.assertEqual(res, None)

View File

@ -36,6 +36,7 @@ from ..server.failregex import Regex, FailRegex, RegexException
from ..server.server import Server from ..server.server import Server
from ..server.jail import Jail from ..server.jail import Jail
from ..server.jailthread import JailThread from ..server.jailthread import JailThread
from ..server.utils import Utils
from .utils import LogCaptureTestCase from .utils import LogCaptureTestCase
from ..helpers import getLogger from ..helpers import getLogger
from .. import version from .. import version
@ -74,14 +75,14 @@ class TransmitterBase(unittest.TestCase):
"""Call after every test case.""" """Call after every test case."""
self.server.quit() self.server.quit()
def setGetTest(self, cmd, inValue, outValue=None, outCode=0, jail=None, repr_=False): def setGetTest(self, cmd, inValue, outValue=(None,), outCode=0, jail=None, repr_=False):
setCmd = ["set", cmd, inValue] setCmd = ["set", cmd, inValue]
getCmd = ["get", cmd] getCmd = ["get", cmd]
if jail is not None: if jail is not None:
setCmd.insert(1, jail) setCmd.insert(1, jail)
getCmd.insert(1, jail) getCmd.insert(1, jail)
if outValue is None: if outValue == (None,):
outValue = inValue outValue = inValue
def v(x): def v(x):
@ -161,15 +162,21 @@ class Transmitter(TransmitterBase):
self.assertEqual(self.transm.proceed(["version"]), (0, version.version)) self.assertEqual(self.transm.proceed(["version"]), (0, version.version))
def testSleep(self): def testSleep(self):
if not unittest.F2B.fast:
t0 = time.time() t0 = time.time()
self.assertEqual(self.transm.proceed(["sleep", "1"]), (0, None)) self.assertEqual(self.transm.proceed(["sleep", "0.1"]), (0, None))
t1 = time.time() t1 = time.time()
# Approx 1 second delay but not faster # Approx 0.1 second delay but not faster
dt = t1 - t0 dt = t1 - t0
self.assertTrue(0.99 < dt < 1.1, msg="Sleep was %g sec" % dt) self.assertTrue(0.09 < dt < 0.2, msg="Sleep was %g sec" % dt)
else: # pragma: no cover
self.assertEqual(self.transm.proceed(["sleep", "0.0001"]), (0, None))
def testDatabase(self): def testDatabase(self):
if not unittest.F2B.fast:
tmp, tmpFilename = tempfile.mkstemp(".db", "fail2ban_") tmp, tmpFilename = tempfile.mkstemp(".db", "fail2ban_")
else: # pragma: no cover
tmpFilename = ':memory:'
# Jails present, can't change database # Jails present, can't change database
self.setGetTestNOK("dbfile", tmpFilename) self.setGetTestNOK("dbfile", tmpFilename)
self.server.delJail(self.jailName) self.server.delJail(self.jailName)
@ -201,6 +208,7 @@ class Transmitter(TransmitterBase):
self.assertEqual(self.transm.proceed( self.assertEqual(self.transm.proceed(
["set", "dbfile", "None"]), ["set", "dbfile", "None"]),
(0, None)) (0, None))
if not unittest.F2B.fast:
os.close(tmp) os.close(tmp)
os.unlink(tmpFilename) os.unlink(tmpFilename)
@ -225,7 +233,11 @@ class Transmitter(TransmitterBase):
def testStartStopJail(self): def testStartStopJail(self):
self.assertEqual( self.assertEqual(
self.transm.proceed(["start", self.jailName]), (0, None)) self.transm.proceed(["start", self.jailName]), (0, None))
time.sleep(1) time.sleep(Utils.DEFAULT_SLEEP_TIME)
# wait until not started (3 seconds as long as any RuntimeError, ex.: RuntimeError('cannot join thread before it is started',)):
self.assertTrue( Utils.wait_for(
lambda: self.server.is_alive(1) and not isinstance(self.transm.proceed(["status", self.jailName]), RuntimeError),
3) )
self.assertEqual( self.assertEqual(
self.transm.proceed(["stop", self.jailName]), (0, None)) self.transm.proceed(["stop", self.jailName]), (0, None))
self.assertTrue(self.jailName not in self.server._Server__jails) self.assertTrue(self.jailName not in self.server._Server__jails)
@ -239,9 +251,12 @@ class Transmitter(TransmitterBase):
# yoh: workaround for gh-146. I still think that there is some # yoh: workaround for gh-146. I still think that there is some
# race condition and missing locking somewhere, but for now # race condition and missing locking somewhere, but for now
# giving it a small delay reliably helps to proceed with tests # giving it a small delay reliably helps to proceed with tests
time.sleep(0.1) time.sleep(Utils.DEFAULT_SLEEP_TIME)
self.assertTrue( Utils.wait_for(
lambda: self.server.is_alive(2) and not isinstance(self.transm.proceed(["status", self.jailName]), RuntimeError),
3) )
self.assertEqual(self.transm.proceed(["stop", "all"]), (0, None)) self.assertEqual(self.transm.proceed(["stop", "all"]), (0, None))
time.sleep(1) self.assertTrue( Utils.wait_for( lambda: not len(self.server._Server__jails), 3) )
self.assertTrue(self.jailName not in self.server._Server__jails) self.assertTrue(self.jailName not in self.server._Server__jails)
self.assertTrue("TestJail2" not in self.server._Server__jails) self.assertTrue("TestJail2" not in self.server._Server__jails)
@ -297,11 +312,11 @@ class Transmitter(TransmitterBase):
self.assertEqual( self.assertEqual(
self.transm.proceed(["set", self.jailName, "banip", "127.0.0.1"]), self.transm.proceed(["set", self.jailName, "banip", "127.0.0.1"]),
(0, "127.0.0.1")) (0, "127.0.0.1"))
time.sleep(1) # Give chance to ban time.sleep(Utils.DEFAULT_SLEEP_TIME) # Give chance to ban
self.assertEqual( self.assertEqual(
self.transm.proceed(["set", self.jailName, "banip", "Badger"]), self.transm.proceed(["set", self.jailName, "banip", "Badger"]),
(0, "Badger")) #NOTE: Is IP address validated? Is DNS Lookup done? (0, "Badger")) #NOTE: Is IP address validated? Is DNS Lookup done?
time.sleep(1) # Give chance to ban time.sleep(Utils.DEFAULT_SLEEP_TIME) # Give chance to ban
# Unban IP # Unban IP
self.assertEqual( self.assertEqual(
self.transm.proceed( self.transm.proceed(

View File

@ -33,6 +33,7 @@ import unittest
from .. import protocol from .. import protocol
from ..server.asyncserver import AsyncServer, AsyncServerException from ..server.asyncserver import AsyncServer, AsyncServerException
from ..server.utils import Utils
from ..client.csocket import CSocket from ..client.csocket import CSocket
@ -54,14 +55,20 @@ class Socket(unittest.TestCase):
"""Test transmitter proceed method which just returns first arg""" """Test transmitter proceed method which just returns first arg"""
return message return message
def _serverSocket(self):
try:
return CSocket(self.sock_name)
except Exception as e:
return None
def testSocket(self): def testSocket(self):
serverThread = threading.Thread( serverThread = threading.Thread(
target=self.server.start, args=(self.sock_name, False)) target=self.server.start, args=(self.sock_name, False))
serverThread.daemon = True serverThread.daemon = True
serverThread.start() serverThread.start()
time.sleep(1) time.sleep(Utils.DEFAULT_SLEEP_TIME)
client = CSocket(self.sock_name) client = Utils.wait_for(self._serverSocket, 2)
testMessage = ["A", "test", "message"] testMessage = ["A", "test", "message"]
self.assertEqual(client.send(testMessage), testMessage) self.assertEqual(client.send(testMessage), testMessage)
@ -71,7 +78,7 @@ class Socket(unittest.TestCase):
client.close() client.close()
self.server.stop() self.server.stop()
serverThread.join(1) serverThread.join(Utils.DEFAULT_SLEEP_TIME)
self.assertFalse(os.path.exists(self.sock_name)) self.assertFalse(os.path.exists(self.sock_name))
def testSocketForce(self): def testSocketForce(self):
@ -85,10 +92,10 @@ class Socket(unittest.TestCase):
target=self.server.start, args=(self.sock_name, True)) target=self.server.start, args=(self.sock_name, True))
serverThread.daemon = True serverThread.daemon = True
serverThread.start() serverThread.start()
time.sleep(1) time.sleep(Utils.DEFAULT_SLEEP_TIME)
self.server.stop() self.server.stop()
serverThread.join(1) serverThread.join(Utils.DEFAULT_SLEEP_TIME)
self.assertFalse(os.path.exists(self.sock_name)) self.assertFalse(os.path.exists(self.sock_name))

View File

@ -30,8 +30,10 @@ import time
import unittest import unittest
from StringIO import StringIO from StringIO import StringIO
from ..server.mytime import MyTime
from ..helpers import getLogger from ..helpers import getLogger
from ..server.filter import DNSUtils
from ..server.mytime import MyTime
from ..server.utils import Utils
logSys = getLogger(__name__) logSys = getLogger(__name__)
@ -45,6 +47,43 @@ if not CONFIG_DIR:
CONFIG_DIR = '/etc/fail2ban' CONFIG_DIR = '/etc/fail2ban'
class F2B():
def __init__(self, fast=False, no_network=False):
self.fast=fast
self.no_network=no_network
def SkipIfFast(self):
pass
def SkipIfNoNetwork(self):
pass
def initTests(opts):
if opts: # pragma: no cover
unittest.F2B = F2B(opts.fast, opts.no_network)
else:
unittest.F2B = F2B()
# --fast :
if unittest.F2B.fast: # pragma: no cover
# prevent long sleeping during test cases...
Utils.DEFAULT_SLEEP_TIME = 0.0025
Utils.DEFAULT_SLEEP_INTERVAL = 0.0005
def F2B_SkipIfFast():
raise unittest.SkipTest('Skip test because of "--fast"')
unittest.F2B.SkipIfFast = F2B_SkipIfFast
else:
# sleep intervals are large - use replacement for sleep to check time to sleep:
_org_sleep = time.sleep
def _new_sleep(v):
if (v > Utils.DEFAULT_SLEEP_TIME):
raise ValueError('[BAD-CODE] To long sleep interval: %s, try to use conditional Utils.wait_for instead' % v)
_org_sleep(min(v, Utils.DEFAULT_SLEEP_TIME))
time.sleep = _new_sleep
# --no-network :
if unittest.F2B.no_network: # pragma: no cover
def F2B_SkipIfNoNetwork():
raise unittest.SkipTest('Skip test because of "--no-network"')
unittest.F2B.SkipIfNoNetwork = F2B_SkipIfNoNetwork
def mtimesleep(): def mtimesleep():
# no sleep now should be necessary since polling tracks now not only # no sleep now should be necessary since polling tracks now not only
# mtime but also ino and size # mtime but also ino and size
@ -70,7 +109,8 @@ def tearDownMyTime():
MyTime.myTime = None MyTime.myTime = None
def gatherTests(regexps=None, no_network=False): def gatherTests(regexps=None, opts=None):
initTests(opts)
# Import all the test cases here instead of a module level to # Import all the test cases here instead of a module level to
# avoid circular imports # avoid circular imports
from . import banmanagertestcase from . import banmanagertestcase
@ -143,7 +183,7 @@ def gatherTests(regexps=None, no_network=False):
tests.addTest(unittest.makeSuite(filtertestcase.LogFile)) tests.addTest(unittest.makeSuite(filtertestcase.LogFile))
tests.addTest(unittest.makeSuite(filtertestcase.LogFileMonitor)) tests.addTest(unittest.makeSuite(filtertestcase.LogFileMonitor))
tests.addTest(unittest.makeSuite(filtertestcase.LogFileFilterPoll)) tests.addTest(unittest.makeSuite(filtertestcase.LogFileFilterPoll))
if not no_network: # each test case class self will check no network, and skip it (we see it in log)
tests.addTest(unittest.makeSuite(filtertestcase.IgnoreIPDNS)) tests.addTest(unittest.makeSuite(filtertestcase.IgnoreIPDNS))
tests.addTest(unittest.makeSuite(filtertestcase.GetFailures)) tests.addTest(unittest.makeSuite(filtertestcase.GetFailures))
tests.addTest(unittest.makeSuite(filtertestcase.DNSUtilsTests)) tests.addTest(unittest.makeSuite(filtertestcase.DNSUtilsTests))
@ -165,9 +205,6 @@ def gatherTests(regexps=None, no_network=False):
for file_ in os.listdir( for file_ in os.listdir(
os.path.abspath(os.path.dirname(action_d.__file__))): os.path.abspath(os.path.dirname(action_d.__file__))):
if file_.startswith("test_") and file_.endswith(".py"): if file_.startswith("test_") and file_.endswith(".py"):
if no_network and file_ in ['test_badips.py','test_smtp.py']: #pragma: no cover
# Test required network
continue
tests.addTest(testloader.loadTestsFromName( tests.addTest(testloader.loadTestsFromName(
"%s.%s" % (action_d.__name__, os.path.splitext(file_)[0]))) "%s.%s" % (action_d.__name__, os.path.splitext(file_)[0])))
@ -182,6 +219,9 @@ def gatherTests(regexps=None, no_network=False):
# yoh: Since I do not know better way for parametric tests # yoh: Since I do not know better way for parametric tests
# with good old unittest # with good old unittest
try: try:
# because gamin can be very slow on some platforms (and can produce many failures
# with fast sleep interval) - skip it by fast run:
unittest.F2B.SkipIfFast()
from ..server.filtergamin import FilterGamin from ..server.filtergamin import FilterGamin
filters.append(FilterGamin) filters.append(FilterGamin)
except Exception, e: # pragma: no cover except Exception, e: # pragma: no cover
@ -276,29 +316,4 @@ class LogCaptureTestCase(unittest.TestCase):
def printLog(self): def printLog(self):
print(self._log.getvalue()) print(self._log.getvalue())
# Solution from http://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid pid_exists = Utils.pid_exists
# under cc by-sa 3.0
if os.name == 'posix':
def pid_exists(pid):
"""Check whether pid exists in the current process table."""
import errno
if pid < 0:
return False
try:
os.kill(pid, 0)
except OSError as e:
return e.errno == errno.EPERM
else:
return True
else:
def pid_exists(pid):
import ctypes
kernel32 = ctypes.windll.kernel32
SYNCHRONIZE = 0x100000
process = kernel32.OpenProcess(SYNCHRONIZE, 0, pid)
if process != 0:
kernel32.CloseHandle(process)
return True
else:
return False