- performance of fail2ban optimized

-- cache dnsToIp, ipToName to prevent long wait during retrieving of ip/name for wrong dns or lazy dns-system;
   -- instead of simple "sleep" used conditional wait "wait_for", that internal increases sleep interval up to sleeptime;
   -- ticket / banmanager / failmanager modules are performance optimized;
- performance of test cases optimized:
   -- added option "--fast" to decrease wait intervals, avoid passive waiting, and skip few very slow test cases;
- code review after partially cherry pick of branch 'ban-time-incr' (see gh-716)
   -- ticket module prepared to easy merge with newest version of 'ban-time-incr', now additionally holds banTime, banCount and json-data;
   -- executeCmd partially moved from action to new module utils, etc.
   -- python 2.6 compatibility;
- testExecuteTimeoutWithNastyChildren: test case repaired - wait for pid file inside bash, kill tree in any case (gh-1155);
- testSocket: test case repaired - wait for server thread starts a socket (listener)
f2b-perfom-prepare-716
sebres 2015-07-15 14:58:00 +02:00
parent d7c4df5acd
commit 4157adf5fe
29 changed files with 737 additions and 385 deletions

View File

@ -180,7 +180,6 @@ fail2ban/server/banmanager.py
fail2ban/server/database.py
fail2ban/server/datedetector.py
fail2ban/server/datetemplate.py
fail2ban/server/faildata.py
fail2ban/server/failmanager.py
fail2ban/server/failregex.py
fail2ban/server/filter.py
@ -197,6 +196,7 @@ fail2ban/server/server.py
fail2ban/server/strptime.py
fail2ban/server/ticket.py
fail2ban/server/transmitter.py
fail2ban/server/utils.py
fail2ban/tests/__init__.py
fail2ban/tests/action_d/__init__.py
fail2ban/tests/action_d/test_badips.py

View File

@ -58,6 +58,9 @@ def get_opt_parser():
Option('-n', "--no-network", action="store_true",
dest="no_network",
help="Do not run tests that require the network"),
Option('-f', "--fast", action="store_true",
dest="fast",
help="Try to increase speed of the tests, decreasing of wait intervals, memory database"),
Option("-t", "--log-traceback", action='store_true',
help="Enrich log-messages with compressed tracebacks"),
Option("--full-traceback", action='store_true',
@ -120,7 +123,7 @@ if not opts.log_level or opts.log_level != 'critical': # pragma: no cover
print("Fail2ban %s test suite. Python %s. Please wait..." \
% (version, str(sys.version).replace('\n', '')))
tests = gatherTests(regexps, opts.no_network)
tests = gatherTests(regexps, opts)
#
# Run the tests
#

View File

@ -10,7 +10,6 @@ fail2ban.server package
fail2ban.server.database
fail2ban.server.datedetector
fail2ban.server.datetemplate
fail2ban.server.faildata
fail2ban.server.failmanager
fail2ban.server.failregex
fail2ban.server.filter
@ -26,3 +25,4 @@ fail2ban.server package
fail2ban.server.strptime
fail2ban.server.ticket
fail2ban.server.transmitter
fail2ban.server.utils

View File

@ -1,7 +1,7 @@
fail2ban.server.faildata module
fail2ban.server.utils module
===============================
.. automodule:: fail2ban.server.faildata
.. automodule:: fail2ban.server.utils
:members:
:undoc-members:
:show-inheritance:

View File

@ -32,6 +32,7 @@ import time
from abc import ABCMeta
from collections import MutableMapping
from .utils import Utils
from ..helpers import getLogger
# Gets the instance of the logger.
@ -40,21 +41,6 @@ logSys = getLogger(__name__)
# Create a lock for running system commands
_cmd_lock = threading.Lock()
# Some hints on common abnormal exit codes
_RETCODE_HINTS = {
127: '"Command not found". Make sure that all commands in %(realCmd)r '
'are in the PATH of fail2ban-server process '
'(grep -a PATH= /proc/`pidof -x fail2ban-server`/environ). '
'You may want to start '
'"fail2ban-server -f" separately, initiate it with '
'"fail2ban-client reload" in another shell session and observe if '
'additional informative error messages appear in the terminals.'
}
# Dictionary to lookup signal name from number
signame = dict((num, name)
for name, num in signal.__dict__.iteritems() if name.startswith("SIG"))
class CallingMap(MutableMapping):
"""A Mapping type which returns the result of callable values.
@ -561,61 +547,6 @@ class CommandAction(ActionBase):
_cmd_lock.acquire()
try:
retcode = None # to guarantee being defined upon early except
stdout = tempfile.TemporaryFile(suffix=".stdout", prefix="fai2ban_")
stderr = tempfile.TemporaryFile(suffix=".stderr", prefix="fai2ban_")
popen = subprocess.Popen(
realCmd, stdout=stdout, stderr=stderr, shell=True,
preexec_fn=os.setsid # so that killpg does not kill our process
)
stime = time.time()
retcode = popen.poll()
while time.time() - stime <= timeout and retcode is None:
time.sleep(0.1)
retcode = popen.poll()
if retcode is None:
logSys.error("%s -- timed out after %i seconds." %
(realCmd, timeout))
pgid = os.getpgid(popen.pid)
os.killpg(pgid, signal.SIGTERM) # Terminate the process
time.sleep(0.1)
retcode = popen.poll()
if retcode is None: # Still going...
os.killpg(pgid, signal.SIGKILL) # Kill the process
time.sleep(0.1)
retcode = popen.poll()
except OSError as e:
logSys.error("%s -- failed with %s" % (realCmd, e))
return Utils.executeCmd(realCmd, timeout, shell=True, output=False)
finally:
_cmd_lock.release()
std_level = retcode == 0 and logging.DEBUG or logging.ERROR
if std_level >= logSys.getEffectiveLevel():
stdout.seek(0); msg = stdout.read()
if msg != '':
logSys.log(std_level, "%s -- stdout: %r", realCmd, msg)
stderr.seek(0); msg = stderr.read()
if msg != '':
logSys.log(std_level, "%s -- stderr: %r", realCmd, msg)
stdout.close()
stderr.close()
if retcode == 0:
logSys.debug("%s -- returned successfully" % realCmd)
return True
elif retcode is None:
logSys.error("%s -- unable to kill PID %i" % (realCmd, popen.pid))
elif retcode < 0 or retcode > 128:
# dash would return negative while bash 128 + n
sigcode = -retcode if retcode < 0 else retcode - 128
logSys.error("%s -- killed with %s (return code: %s)" %
(realCmd, signame.get(sigcode, "signal %i" % sigcode), retcode))
else:
msg = _RETCODE_HINTS.get(retcode, None)
logSys.error("%s -- returned %i" % (realCmd, retcode))
if msg:
logSys.info("HINT on %i: %s"
% (retcode, msg % locals()))
return False

View File

@ -42,6 +42,7 @@ from .banmanager import BanManager
from .jailthread import JailThread
from .action import ActionBase, CommandAction, CallingMap
from .mytime import MyTime
from .utils import Utils
from ..helpers import getLogger
# Gets the instance of the logger.
@ -225,14 +226,11 @@ class Actions(JailThread, Mapping):
self._jail.name, name, e,
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
while self.active:
if not self.idle:
#logSys.debug(self._jail.name + ": action")
ret = self.__checkBan()
if not ret:
self.__checkUnBan()
time.sleep(self.sleeptime)
else:
if self.idle:
time.sleep(self.sleeptime)
continue
if not Utils.wait_for(self.__checkBan, self.sleeptime):
self.__checkUnBan()
self.__flushBan()
actions = self._actions.items()

View File

@ -247,12 +247,10 @@ class BanManager:
@staticmethod
def createBanTicket(ticket):
ip = ticket.getIP()
#lastTime = ticket.getTime()
lastTime = MyTime.time()
banTicket = BanTicket(ip, lastTime, ticket.getMatches())
banTicket.setAttempt(ticket.getAttempt())
return banTicket
# we should always use correct time to calculate correct end time (ban time is variable now,
# + possible double banning by restore from database and from log file)
# so use as lastTime always time from ticket.
return BanTicket(ticket=ticket)
##
# Add a ban ticket.
@ -264,11 +262,25 @@ class BanManager:
def addBanTicket(self, ticket):
try:
self.__lock.acquire()
if not self._inBanList(ticket):
self.__banList.append(ticket)
self.__banTotal += 1
return True
return False
# check already banned
for i in self.__banList:
if ticket.getIP() == i.getIP():
# if already permanent
btorg, torg = i.getBanTime(self.__banTime), i.getTime()
if btorg == -1:
return False
# if given time is less than already banned time
btnew, tnew = ticket.getBanTime(self.__banTime), ticket.getTime()
if btnew != -1 and tnew + btnew <= torg + btorg:
return False
# we have longest ban - set new (increment) ban time
i.setTime(tnew)
i.setBanTime(btnew)
return False
# not yet banned - add new
self.__banList.append(ticket)
self.__banTotal += 1
return True
finally:
self.__lock.release()
@ -313,8 +325,7 @@ class BanManager:
return list()
# Gets the list of ticket to remove.
unBanList = [ticket for ticket in self.__banList
if ticket.getTime() < time - self.__banTime]
unBanList = [ticket for ticket in self.__banList if ticket.isTimedOut(time, self.__banTime)]
# Removes tickets.
self.__banList = [ticket for ticket in self.__banList

View File

@ -418,8 +418,7 @@ class Fail2BanDb(object):
cur.execute(
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
(jail.name, ticket.getIP(), int(round(ticket.getTime())),
{"matches": ticket.getMatches(),
"failures": ticket.getAttempt()}))
ticket.getData()))
@commitandrollback
def delBan(self, cur, jail, ip):
@ -477,8 +476,8 @@ class Fail2BanDb(object):
tickets = []
for ip, timeofban, data in self._getBans(**kwargs):
#TODO: Implement data parts once arbitrary match keys completed
tickets.append(FailTicket(ip, timeofban, data.get('matches')))
tickets[-1].setAttempt(data.get('failures', 1))
tickets.append(FailTicket(ip, timeofban))
tickets[-1].setData(data)
return tickets
def getBansMerged(self, ip=None, jail=None, bantime=None):
@ -520,6 +519,7 @@ class Fail2BanDb(object):
prev_banip = results[0][0]
matches = []
failures = 0
tickdata = {}
for banip, timeofban, data in results:
#TODO: Implement data parts once arbitrary match keys completed
if banip != prev_banip:
@ -530,11 +530,14 @@ class Fail2BanDb(object):
prev_banip = banip
matches = []
failures = 0
matches.extend(data.get('matches', []))
tickdata = {}
matches.extend(data.get('matches', ()))
failures += data.get('failures', 1)
tickdata.update(data.get('data', {}))
prev_timeofban = timeofban
ticket = FailTicket(banip, prev_timeofban, matches)
ticket.setAttempt(failures)
ticket.setData(**tickdata)
tickets.append(ticket)
if cacheKey:

View File

@ -1,71 +0,0 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
# vi: set ft=python sts=4 ts=4 sw=4 noet :
# This file is part of Fail2Ban.
#
# Fail2Ban is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Fail2Ban is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Fail2Ban; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Author: Cyril Jaquier
#
__author__ = "Cyril Jaquier"
__copyright__ = "Copyright (c) 2004 Cyril Jaquier"
__license__ = "GPL"
from ..helpers import getLogger
# Gets the instance of the logger.
logSys = getLogger(__name__)
class FailData:
def __init__(self):
self.__retry = 0
self.__lastTime = 0
self.__lastReset = 0
self.__matches = []
def setRetry(self, value):
self.__retry = value
# keep only the last matches or reset entirely
# Explicit if/else for compatibility with Python 2.4
if value:
self.__matches = self.__matches[-min(len(self.__matches, value)):]
else:
self.__matches = []
def getRetry(self):
return self.__retry
def getMatches(self):
return self.__matches
def inc(self, matches=None):
self.__retry += 1
self.__matches += matches or []
def setLastTime(self, value):
if value > self.__lastTime:
self.__lastTime = value
def getLastTime(self):
return self.__lastTime
def getLastReset(self):
return self.__lastReset
def setLastReset(self, value):
self.__lastReset = value

View File

@ -27,7 +27,6 @@ __license__ = "GPL"
from threading import Lock
import logging
from .faildata import FailData
from .ticket import FailTicket
from ..helpers import getLogger
@ -86,26 +85,35 @@ class FailManager:
finally:
self.__lock.release()
def addFailure(self, ticket):
def addFailure(self, ticket, count=1):
attempts = 1
try:
self.__lock.acquire()
ip = ticket.getIP()
unixTime = ticket.getTime()
matches = ticket.getMatches()
if ip in self.__failList:
fData = self.__failList[ip]
# if the same object:
if fData is ticket:
matches = None
else:
matches = ticket.getMatches()
unixTime = ticket.getTime()
if fData.getLastReset() < unixTime - self.__maxTime:
fData.setLastReset(unixTime)
fData.setRetry(0)
fData.inc(matches)
fData.inc(matches, 1, count)
fData.setLastTime(unixTime)
else:
fData = FailData()
fData.inc(matches)
fData.setLastReset(unixTime)
fData.setLastTime(unixTime)
# if already FailTicket - add it direct, otherwise create (using copy all ticket data):
if isinstance(ticket, FailTicket):
fData = ticket;
else:
fData = FailTicket(ticket=ticket)
if count > ticket.getAttempt():
fData.setRetry(count)
self.__failList[ip] = fData
attempts = fData.getRetry()
self.__failTotal += 1
if logSys.getEffectiveLevel() <= logging.DEBUG:
@ -118,6 +126,7 @@ class FailManager:
% (self.__failTotal, len(self.__failList), failures_summary))
finally:
self.__lock.release()
return attempts
def size(self):
try:
@ -140,17 +149,14 @@ class FailManager:
if ip in self.__failList:
del self.__failList[ip]
def toBan(self):
def toBan(self, ip=None):
try:
self.__lock.acquire()
for ip in self.__failList:
for ip in ([ip] if ip != None and ip in self.__failList else self.__failList):
data = self.__failList[ip]
if data.getRetry() >= self.__maxRetry:
self.__delFailure(ip)
# Create a FailTicket from BanData
failTicket = FailTicket(ip, data.getLastTime(), data.getMatches())
failTicket.setAttempt(data.getRetry())
return failTicket
del self.__failList[ip]
return data
raise FailManagerEmpty
finally:
self.__lock.release()

View File

@ -22,6 +22,7 @@ __copyright__ = "Copyright (c) 2004 Cyril Jaquier, 2011-2013 Yaroslav Halchenko"
__license__ = "GPL"
import codecs
import datetime
import fcntl
import locale
import logging
@ -316,13 +317,12 @@ class Filter(JailThread):
logSys.warning('Requested to manually ban an ignored IP %s. User knows best. Proceeding to ban it.' % ip)
unixTime = MyTime.time()
for i in xrange(self.failManager.getMaxRetry()):
self.failManager.addFailure(FailTicket(ip, unixTime))
self.failManager.addFailure(FailTicket(ip, unixTime), self.failManager.getMaxRetry())
# Perform the banning of the IP now.
try: # pragma: no branch - exception is the only way out
while True:
ticket = self.failManager.toBan()
ticket = self.failManager.toBan(ip)
self.jail.putFailTicket(ticket)
except FailManagerEmpty:
self.failManager.cleanup(MyTime.time())
@ -427,17 +427,19 @@ class Filter(JailThread):
ip = element[1]
unixTime = element[2]
lines = element[3]
logSys.debug("Processing line with time:%s and ip:%s"
% (unixTime, ip))
logSys.debug("Processing line with time:%s and ip:%s",
unixTime, ip)
if unixTime < MyTime.time() - self.getFindTime():
logSys.debug("Ignore line since time %s < %s - %s"
% (unixTime, MyTime.time(), self.getFindTime()))
logSys.debug("Ignore line since time %s < %s - %s",
unixTime, MyTime.time(), self.getFindTime())
break
if self.inIgnoreIPList(ip, log_ignore=True):
continue
logSys.info("[%s] Found %s" % (self.jail.name, ip))
## print "D: Adding a ticket for %s" % ((ip, unixTime, [line]),)
self.failManager.addFailure(FailTicket(ip, unixTime, lines))
logSys.info(
"[%s] Found %s - %s", self.jail.name, ip, datetime.datetime.fromtimestamp(unixTime).strftime("%Y-%m-%d %H:%M:%S")
)
tick = FailTicket(ip, unixTime, lines)
self.failManager.addFailure(tick)
##
# Returns true if the line should be ignored.
@ -941,32 +943,50 @@ class JournalFilter(Filter): # pragma: systemd no cover
import socket
import struct
from .utils import Utils
class DNSUtils:
IP_CRE = re.compile("^(?:\d{1,3}\.){3}\d{1,3}$")
# todo: make configurable the expired time and max count of cache entries:
CACHE_dnsToIp = Utils.Cache(maxCount=1000, maxTime=60*60)
CACHE_ipToName = Utils.Cache(maxCount=1000, maxTime=60*60)
@staticmethod
def dnsToIp(dns):
""" Convert a DNS into an IP address using the Python socket module.
Thanks to Kevin Drapel.
"""
# cache, also prevent long wait during retrieving of ip for wrong dns or lazy dns-system:
v = DNSUtils.CACHE_dnsToIp.get(dns)
if v is not None:
return v
# retrieve ip (todo: use AF_INET6 for IPv6)
try:
return set([i[4][0] for i in socket.getaddrinfo(dns, None, socket.AF_INET, 0, socket.IPPROTO_TCP)])
v = set([i[4][0] for i in socket.getaddrinfo(dns, None, socket.AF_INET, 0, socket.IPPROTO_TCP)])
except socket.error, e:
logSys.warning("Unable to find a corresponding IP address for %s: %s"
% (dns, e))
return list()
# todo: make configurable the expired time of cache entry:
logSys.warning("Unable to find a corresponding IP address for %s: %s", dns, e)
v = list()
DNSUtils.CACHE_dnsToIp.set(dns, v)
return v
@staticmethod
def ipToName(ip):
# cache, also prevent long wait during retrieving of name for wrong addresses, lazy dns:
v = DNSUtils.CACHE_ipToName.get(ip)
if v is not None:
return v
# retrieve name
try:
return socket.gethostbyaddr(ip)[0]
v = socket.gethostbyaddr(ip)[0]
except socket.error, e:
logSys.debug("Unable to find a name for the IP %s: %s" % (ip, e))
return None
logSys.debug("Unable to find a name for the IP %s: %s", ip, e)
v = None
DNSUtils.CACHE_ipToName.set(ip, v)
return v
@staticmethod
def searchIP(text):

View File

@ -31,6 +31,7 @@ import gamin
from .failmanager import FailManagerEmpty
from .filter import FileFilter
from .mytime import MyTime
from .utils import Utils
from ..helpers import getLogger
# Gets the instance of the logger.
@ -102,6 +103,15 @@ class FilterGamin(FileFilter):
def _delLogPath(self, path):
self.monitor.stop_watch(path)
def _handleEvents(self):
ret = False
mon = self.monitor
while mon and mon.event_pending():
mon.handle_events()
mon = self.monitor
ret = True
return ret
##
# Main loop.
#
@ -112,12 +122,10 @@ class FilterGamin(FileFilter):
def run(self):
# Gamin needs a loop to collect and dispatch events
while self.active:
if not self.idle:
# We cannot block here because we want to be able to
# exit.
if self.monitor.event_pending():
self.monitor.handle_events()
time.sleep(self.sleeptime)
if self.idle:
time.sleep(self.sleeptime)
continue
Utils.wait_for(self._handleEvents, self.sleeptime)
logSys.debug(self.jail.name + ": filter terminated")
return True
@ -131,4 +139,4 @@ class FilterGamin(FileFilter):
def __cleanup(self):
for path in self.getLogPath():
self.monitor.stop_watch(path.getFileName())
del self.monitor
self.monitor = None

View File

@ -31,6 +31,7 @@ from .failmanager import FailManagerEmpty
from .filter import FileFilter
from .mytime import MyTime
from ..helpers import getLogger
from ..server.utils import Utils
# Gets the instance of the logger.
logSys = getLogger(__name__)
@ -78,6 +79,16 @@ class FilterPoll(FileFilter):
del self.__prevStats[path]
del self.__file404Cnt[path]
##
# Get a modified log path at once
#
def getModified(self, modlst):
for container in self.getLogPath():
filename = container.getFileName()
if self.isModified(filename):
modlst.append(filename)
return modlst
##
# Main loop.
#
@ -90,30 +101,29 @@ class FilterPoll(FileFilter):
if logSys.getEffectiveLevel() <= 6:
logSys.log(6, "Woke up idle=%s with %d files monitored",
self.idle, len(self.getLogPath()))
if not self.idle:
# Get file modification
for container in self.getLogPath():
filename = container.getFileName()
if self.isModified(filename):
# set start time as now - find time for first usage only (prevent performance bug with polling of big files)
self.getFailures(filename,
(MyTime.time() - self.getFindTime()) if not self.__initial.get(filename) else None
)
self.__initial[filename] = True
self.__modified = True
if self.idle:
time.sleep(self.sleeptime)
continue
# Get file modification
modlst = []
Utils.wait_for(lambda: self.getModified(modlst), self.sleeptime)
for filename in modlst:
# set start time as now - find time for first usage only (prevent performance bug with polling of big files)
self.getFailures(filename,
(MyTime.time() - self.getFindTime()) if not self.__initial.get(filename) else None
)
self.__initial[filename] = True
self.__modified = True
if self.__modified:
try:
while True:
ticket = self.failManager.toBan()
self.jail.putFailTicket(ticket)
except FailManagerEmpty:
self.failManager.cleanup(MyTime.time())
self.dateDetector.sortTemplate()
self.__modified = False
time.sleep(self.sleeptime)
else:
time.sleep(self.sleeptime)
if self.__modified:
try:
while True:
ticket = self.failManager.toBan()
self.jail.putFailTicket(ticket)
except FailManagerEmpty:
self.failManager.cleanup(MyTime.time())
self.dateDetector.sortTemplate()
self.__modified = False
logSys.debug(
(self.jail is not None and self.jail.name or "jailless") +
" filter terminated")
@ -129,7 +139,7 @@ class FilterPoll(FileFilter):
try:
logStats = os.stat(filename)
stats = logStats.st_mtime, logStats.st_ino, logStats.st_size
pstats = self.__prevStats[filename]
pstats = self.__prevStats.get(filename, ())
self.__file404Cnt[filename] = 0
if logSys.getEffectiveLevel() <= 7:
# we do not want to waste time on strftime etc if not necessary
@ -139,10 +149,9 @@ class FilterPoll(FileFilter):
# os.system("stat %s | grep Modify" % filename)
if pstats == stats:
return False
else:
logSys.debug("%s has been modified", filename)
self.__prevStats[filename] = stats
return True
logSys.debug("%s has been modified", filename)
self.__prevStats[filename] = stats
return True
except OSError, e:
logSys.error("Unable to get stat on %s because of: %s"
% (filename, e))

View File

@ -28,6 +28,7 @@ import sys
from threading import Thread
from abc import abstractmethod
from .utils import Utils
from ..helpers import excepthook
@ -48,14 +49,14 @@ class JailThread(Thread):
The time the thread sleeps for in the loop.
"""
def __init__(self):
super(JailThread, self).__init__()
def __init__(self, name=None):
super(JailThread, self).__init__(name=name)
## Control the state of the thread.
self.active = False
## Control the idle state of the thread.
self.idle = False
## The time the thread sleeps in the loop.
self.sleeptime = 1
self.sleeptime = Utils.DEFAULT_SLEEP_TIME
# excepthook workaround for threads, derived from:
# http://bugs.python.org/issue1230540#msg91244

View File

@ -324,6 +324,15 @@ class Server:
def getBanTime(self, name):
return self.__jails[name].actions.getBanTime()
def is_alive(self, jailnum=None):
if jailnum is not None and len(self.__jails) != jailnum:
return 0
for j in self.__jails:
j = self.__jails[j]
if not j.is_alive():
return 0
return 1
# Status
def status(self):
try:

View File

@ -24,7 +24,10 @@ __author__ = "Cyril Jaquier"
__copyright__ = "Copyright (c) 2004 Cyril Jaquier"
__license__ = "GPL"
import sys
from ..helpers import getLogger
from .mytime import MyTime
# Gets the instance of the logger.
logSys = getLogger(__name__)
@ -32,7 +35,7 @@ logSys = getLogger(__name__)
class Ticket:
def __init__(self, ip, time, matches=None):
def __init__(self, ip=None, time=None, matches=None, ticket=None):
"""Ticket constructor
@param ip the IP address
@ -41,14 +44,21 @@ class Ticket:
"""
self.setIP(ip)
self.__time = time
self.__attempt = 0
self.__file = None
self.__matches = matches or []
self._flags = 0;
self._banCount = 0;
self._banTime = None;
self._time = time if time is not None else MyTime.time()
self._data = {'matches': [], 'failures': 0}
if ticket:
# ticket available - copy whole information from ticket:
self.__dict__.update(i for i in ticket.__dict__.iteritems() if i[0] in self.__dict__)
else:
self._data['matches'] = matches or []
def __str__(self):
return "%s: ip=%s time=%s #attempts=%d matches=%r" % \
(self.__class__.__name__.split('.')[-1], self.__ip, self.__time, self.__attempt, self.__matches)
(self.__class__.__name__.split('.')[-1], self.__ip, self._time,
self._data['failures'], self._data.get('matches', []))
def __repr__(self):
return str(self)
@ -56,9 +66,8 @@ class Ticket:
def __eq__(self, other):
try:
return self.__ip == other.__ip and \
round(self.__time, 2) == round(other.__time, 2) and \
self.__attempt == other.__attempt and \
self.__matches == other.__matches
round(self._time, 2) == round(other._time, 2) and \
self._data == other._data
except AttributeError:
return False
@ -72,24 +81,128 @@ class Ticket:
return self.__ip
def setTime(self, value):
self.__time = value
self._time = value
def getTime(self):
return self.__time
return self._time
def setBanTime(self, value):
self._banTime = value;
def getBanTime(self, defaultBT = None):
return (self._banTime if not self._banTime is None else defaultBT);
def setBanCount(self, value):
self._banCount = value;
def incrBanCount(self, value = 1):
self._banCount += value;
def getBanCount(self):
return self._banCount;
def isTimedOut(self, time, defaultBT = None):
bantime = (self._banTime if not self._banTime is None else defaultBT);
# permanent
if bantime == -1:
return False
# timed out
return (time > self._time + bantime)
def setAttempt(self, value):
self.__attempt = value
self._data['failures'] = value
def getAttempt(self):
return self.__attempt
return self._data['failures']
def setMatches(self, matches):
self._data['matches'] = matches or []
def getMatches(self):
return self.__matches
return self._data.get('matches', [])
def setData(self, *args, **argv):
# if overwrite - set data and filter None values:
if len(args) == 1:
# todo: if support >= 2.7 only:
# self._data = {k:v for k,v in args[0].iteritems() if v is not None}
self._data = dict([(k,v) for k,v in args[0].iteritems() if v is not None])
# add k,v list or dict (merge):
elif len(args) == 2:
self._data.update((args,))
elif len(args) > 2:
self._data.update((k,v) for k,v in zip(*[iter(args)]*2))
if len(argv):
self._data.update(argv)
# filter (delete) None values:
# todo: if support >= 2.7 only:
# self._data = {k:v for k,v in self._data.iteritems() if v is not None}
self._data = dict([(k,v) for k,v in self._data.iteritems() if v is not None])
def getData(self, key=None, default=None):
# return whole data dict:
if key is None:
return self._data
# return default if not exists:
if not self._data:
return default
# return filtered by lambda/function:
if callable(key):
# todo: if support >= 2.7 only:
# return {k:v for k,v in self._data.iteritems() if key(k)}
return dict([(k,v) for k,v in self._data.iteritems() if key(k)])
# return filtered by keys:
if hasattr(key, '__iter__'):
# todo: if support >= 2.7 only:
# return {k:v for k,v in self._data.iteritems() if k in key}
return dict([(k,v) for k,v in self._data.iteritems() if k in key])
# return single value of data:
return self._data.get(key, default)
class FailTicket(Ticket):
pass
def __init__(self, ip=None, time=None, matches=None, ticket=None):
# this class variables:
self.__retry = 0
self.__lastReset = None
# create/copy using default ticket constructor:
Ticket.__init__(self, ip, time, matches, ticket)
# init:
if ticket is None:
self.__lastReset = time if time is not None else self.getTime()
if not self.__retry:
self.__retry = self._data['failures'];
def setRetry(self, value):
self.__retry = value
if not self._data['failures']:
self._data['failures'] = 1
if not value:
self._data['failures'] = 0
self._data['matches'] = []
def getRetry(self):
return max(self.__retry, self._data['failures'])
def inc(self, matches=None, attempt=1, count=1):
self.__retry += count
self._data['failures'] += attempt
if matches:
self._data['matches'] += matches
def setLastTime(self, value):
if value > self._time:
self._time = value
def getLastTime(self):
return self._time
def getLastReset(self):
return self.__lastReset
def setLastReset(self, value):
self.__lastReset = value
##
# Ban Ticket.

View File

@ -95,7 +95,7 @@ class Transmitter:
return None
elif command[0] == "sleep":
value = command[1]
time.sleep(int(value))
time.sleep(float(value))
return None
elif command[0] == "flushlogs":
return self.__server.flushLogs()

242
fail2ban/server/utils.py Normal file
View File

@ -0,0 +1,242 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
# vi: set ft=python sts=4 ts=4 sw=4 noet :
# This file is part of Fail2Ban.
#
# Fail2Ban is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Fail2Ban is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Fail2Ban; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
__author__ = "Serg G. Brester (sebres) and Fail2Ban Contributors"
__copyright__ = "Copyright (c) 2004 Cyril Jaquier, 2011-2012 Yaroslav Halchenko, 2012-2015 Serg G. Brester"
__license__ = "GPL"
import logging, os, fcntl, subprocess, time, signal
from ..helpers import getLogger
# Gets the instance of the logger.
logSys = getLogger(__name__)
# Some hints on common abnormal exit codes
_RETCODE_HINTS = {
127: '"Command not found". Make sure that all commands in %(realCmd)r '
'are in the PATH of fail2ban-server process '
'(grep -a PATH= /proc/`pidof -x fail2ban-server`/environ). '
'You may want to start '
'"fail2ban-server -f" separately, initiate it with '
'"fail2ban-client reload" in another shell session and observe if '
'additional informative error messages appear in the terminals.'
}
# Dictionary to lookup signal name from number
signame = dict((num, name)
for name, num in signal.__dict__.iteritems() if name.startswith("SIG"))
class Utils():
"""Utilities provide diverse static methods like executes OS shell commands, etc.
"""
DEFAULT_SLEEP_TIME = 0.1
DEFAULT_SLEEP_INTERVAL = 0.01
class Cache(dict):
def __init__(self, maxCount=1000, maxTime=60*60):
self.maxCount = maxCount
self.maxTime = maxTime
def get(self, k, defv=None):
v = dict.get(self, k)
if v:
if v[1] > time.time():
return v[0]
del self[k]
return defv
def set(self, k, v):
t = time.time()
# clean cache if max count reached:
if len(self) >= self.maxCount:
for (ck,cv) in self.items():
if cv[1] < t:
del self[ck]
# if still max count - remove any one:
if len(self) >= self.maxCount:
self.popitem()
self[k] = (v, t + self.maxTime)
@staticmethod
def setFBlockMode(fhandle, value):
flags = fcntl.fcntl(fhandle, fcntl.F_GETFL)
if not value:
flags |= os.O_NONBLOCK
else:
flags &= ~os.O_NONBLOCK
fcntl.fcntl(fhandle, fcntl.F_SETFL, flags)
return flags
@staticmethod
def executeCmd(realCmd, timeout=60, shell=True, output=False, tout_kill_tree=True):
"""Executes a command.
Parameters
----------
realCmd : str
The command to execute.
timeout : int
The time out in seconds for the command.
shell : bool
If shell is True (default), the specified command (may be a string) will be
executed through the shell.
output : bool
If output is True, the function returns tuple (success, stdoutdata, stderrdata, returncode)
Returns
-------
bool
True if the command succeeded.
Raises
------
OSError
If command fails to be executed.
RuntimeError
If command execution times out.
"""
stdout = stderr = None
retcode = None
if not callable(timeout):
stime = time.time()
timeout_expr = lambda: time.time() - stime <= timeout
else:
timeout_expr = timeout
try:
popen = subprocess.Popen(
realCmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell,
preexec_fn=os.setsid # so that killpg does not kill our process
)
retcode = popen.poll()
while retcode is None and timeout_expr():
time.sleep(Utils.DEFAULT_SLEEP_INTERVAL)
retcode = popen.poll()
if retcode is None:
logSys.error("%s -- timed out after %s seconds." %
(realCmd, timeout))
pgid = os.getpgid(popen.pid)
# if not tree - first try to terminate and then kill, otherwise - kill (-9) only:
os.killpg(pgid, signal.SIGTERM) # Terminate the process
time.sleep(Utils.DEFAULT_SLEEP_INTERVAL)
retcode = popen.poll()
#logSys.debug("%s -- terminated %s ", realCmd, retcode)
if retcode is None or tout_kill_tree: # Still going...
os.killpg(pgid, signal.SIGKILL) # Kill the process
time.sleep(Utils.DEFAULT_SLEEP_INTERVAL)
retcode = popen.poll()
#logSys.debug("%s -- killed %s ", realCmd, retcode)
if retcode is None and not Utils.pid_exists(pgid):
retcode = signal.SIGKILL
except OSError as e:
logSys.error("%s -- failed with %s" % (realCmd, e))
std_level = retcode == 0 and logging.DEBUG or logging.ERROR
# if we need output (to return or to log it):
if output or std_level >= logSys.getEffectiveLevel():
# if was timeouted (killed/terminated) - to prevent waiting, set std handles to non-blocking mode.
if popen.stdout:
try:
if retcode < 0:
Utils.setFBlockMode(popen.stdout, False)
stdout = popen.stdout.read()
except IOError as e:
logSys.error(" ... -- failed to read stdout %s", e)
if stdout is not None and stdout != '':
logSys.log(std_level, "%s -- stdout: %r", realCmd, stdout)
popen.stdout.close()
if popen.stderr:
try:
if retcode < 0:
Utils.setFBlockMode(popen.stderr, False)
stderr = popen.stderr.read()
except IOError as e:
logSys.error(" ... -- failed to read stderr %s", e)
if stderr is not None and stderr != '':
logSys.log(std_level, "%s -- stderr: %r", realCmd, stderr)
popen.stderr.close()
if retcode == 0:
logSys.debug("%s -- returned successfully", realCmd)
return True if not output else (True, stdout, stderr, retcode)
elif retcode is None:
logSys.error("%s -- unable to kill PID %i" % (realCmd, popen.pid))
elif retcode < 0 or retcode > 128:
# dash would return negative while bash 128 + n
sigcode = -retcode if retcode < 0 else retcode - 128
logSys.error("%s -- killed with %s (return code: %s)" %
(realCmd, signame.get(sigcode, "signal %i" % sigcode), retcode))
else:
msg = _RETCODE_HINTS.get(retcode, None)
logSys.error("%s -- returned %i" % (realCmd, retcode))
if msg:
logSys.info("HINT on %i: %s", retcode, msg % locals())
return False if not output else (False, stdout, stderr, retcode)
@staticmethod
def wait_for(cond, timeout, interval=None):
"""Wait until condition expression `cond` is True, up to `timeout` sec
"""
ini = 1
while True:
ret = cond()
if ret:
return ret
if ini:
ini = stm = 0
time0 = time.time() + timeout
if not interval:
interval = Utils.DEFAULT_SLEEP_INTERVAL
if time.time() > time0:
break
stm = min(stm + interval, Utils.DEFAULT_SLEEP_TIME)
time.sleep(stm)
return ret
# Solution from http://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid
# under cc by-sa 3.0
if os.name == 'posix':
@staticmethod
def pid_exists(pid):
"""Check whether pid exists in the current process table."""
import errno
if pid < 0:
return False
try:
os.kill(pid, 0)
except OSError as e:
return e.errno == errno.EPERM
else:
return True
else:
@staticmethod
def pid_exists(pid):
import ctypes
kernel32 = ctypes.windll.kernel32
SYNCHRONIZE = 0x100000
process = kernel32.OpenProcess(SYNCHRONIZE, 0, pid)
if process != 0:
kernel32.CloseHandle(process)
return True
else:
return False

View File

@ -29,6 +29,8 @@ if sys.version_info >= (2,7):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
self.jail = DummyJail()
self.jail.actions.add("test")

View File

@ -46,6 +46,8 @@ class SMTPActionTest(unittest.TestCase):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
self.jail = DummyJail()
pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py")
pythonModuleName = os.path.basename(pythonModule.rstrip(".py"))

View File

@ -30,6 +30,7 @@ import tempfile
from ..server.actions import Actions
from ..server.ticket import FailTicket
from ..server.utils import Utils
from .dummyjail import DummyJail
from .utils import LogCaptureTestCase
@ -81,8 +82,7 @@ class ExecuteActions(LogCaptureTestCase):
self.defaultActions()
self.__actions.start()
with open(self.__tmpfilename) as f:
time.sleep(3)
self.assertEqual(f.read(),"ip start 64\n")
self.assertTrue( Utils.wait_for(lambda: (f.read() == "ip start 64\n"), 3) )
self.__actions.stop()
self.__actions.join()
@ -97,8 +97,7 @@ class ExecuteActions(LogCaptureTestCase):
self.assertLogged("TestAction initialised")
self.__actions.start()
time.sleep(3)
self.assertLogged("TestAction action start")
self.assertTrue( Utils.wait_for(lambda: self._is_logged("TestAction action start"), 3) )
self.__actions.stop()
self.__actions.join()
@ -135,8 +134,7 @@ class ExecuteActions(LogCaptureTestCase):
"action.d/action_errors.py"),
{})
self.__actions.start()
time.sleep(3)
self.assertLogged("Failed to start")
self.assertTrue( Utils.wait_for(lambda: self._is_logged("Failed to start"), 3) )
self.__actions.stop()
self.__actions.join()
self.assertLogged("Failed to stop")

View File

@ -25,10 +25,12 @@ __copyright__ = "Copyright (c) 2004 Cyril Jaquier"
__license__ = "GPL"
import os
import time
import tempfile
import time
import unittest
from ..server.action import CommandAction, CallingMap
from ..server.utils import Utils
from .utils import LogCaptureTestCase
from .utils import pid_exists
@ -194,16 +196,17 @@ class CommandActionTest(LogCaptureTestCase):
self.assertLogged('HINT on 127: "Command not found"')
def testExecuteTimeout(self):
unittest.F2B.SkipIfFast()
stime = time.time()
# Should take a minute
self.assertFalse(CommandAction.executeCmd('sleep 60', timeout=2))
self.assertFalse(CommandAction.executeCmd('sleep 30', timeout=1))
# give a test still 1 second, because system could be too busy
self.assertTrue(time.time() >= stime + 2 and time.time() <= stime + 3)
self.assertTrue(time.time() >= stime + 1 and time.time() <= stime + 2)
self.assertLogged(
'sleep 60 -- timed out after 2 seconds',
'sleep 60 -- timed out after 3 seconds'
'sleep 30 -- timed out after 1 seconds',
'sleep 30 -- timed out after 2 seconds'
)
self.assertLogged('sleep 60 -- killed with SIGTERM')
self.assertLogged('sleep 30 -- killed with SIGTERM')
def testExecuteTimeoutWithNastyChildren(self):
# temporary file for a nasty kid shell script
@ -215,29 +218,53 @@ class CommandActionTest(LogCaptureTestCase):
echo "$$" > %s.pid
echo "my pid $$ . sleeping lo-o-o-ong"
sleep 10000
sleep 30
""" % tmpFilename)
stime = 0
# timeout as long as pid-file was not created, but max 5 seconds
def getnasty_tout():
return (
getnastypid() is None
and time.time() - stime <= 5
)
def getnastypid():
with open(tmpFilename + '.pid') as f:
return int(f.read())
cpid = None
if os.path.isfile(tmpFilename + '.pid'):
with open(tmpFilename + '.pid') as f:
try:
cpid = int(f.read())
except ValueError:
pass
return cpid
# First test if can kill the bastard
stime = time.time()
self.assertFalse(CommandAction.executeCmd(
'bash %s' % tmpFilename, timeout=.1))
'bash %s' % tmpFilename, timeout=getnasty_tout))
# Wait up to 3 seconds, the child got killed
cpid = getnastypid()
# Verify that the process itself got killed
self.assertFalse(pid_exists(getnastypid())) # process should have been killed
self.assertTrue(Utils.wait_for(lambda: not pid_exists(cpid), 3)) # process should have been killed
self.assertLogged('my pid ')
self.assertLogged('timed out')
self.assertLogged('killed with SIGTERM')
self.assertLogged('killed with SIGTERM',
'killed with SIGKILL')
os.unlink(tmpFilename + '.pid')
# A bit evolved case even though, previous test already tests killing children processes
stime = time.time()
self.assertFalse(CommandAction.executeCmd(
'out=`bash %s`; echo ALRIGHT' % tmpFilename, timeout=.2))
'out=`bash %s`; echo ALRIGHT' % tmpFilename, timeout=getnasty_tout))
# Wait up to 3 seconds, the child got killed
cpid = getnastypid()
# Verify that the process itself got killed
self.assertFalse(pid_exists(getnastypid()))
self.assertTrue(Utils.wait_for(lambda: not pid_exists(cpid), 3))
self.assertLogged('my pid ')
self.assertLogged('timed out')
self.assertLogged('killed with SIGTERM')
self.assertLogged('killed with SIGTERM',
'killed with SIGKILL')
os.unlink(tmpFilename)
os.unlink(tmpFilename + '.pid')

View File

@ -60,6 +60,7 @@ class AddFailure(unittest.TestCase):
class StatusExtendedCymruInfo(unittest.TestCase):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
self.__ban_ip = "93.184.216.34"
self.__asn = "15133"
self.__country = "EU"

View File

@ -35,7 +35,12 @@ from ..server.ticket import FailTicket
from ..server.actions import Actions
from .dummyjail import DummyJail
try:
from ..server.database import Fail2BanDb
from ..server.database import Fail2BanDb as Fail2BanDb
# because of tests performance use memory instead of file:
def TestFail2BanDb(filename):
if unittest.F2B.fast:
return Fail2BanDb(':memory:')
return Fail2BanDb(filename)
except ImportError:
Fail2BanDb = None
from .utils import LogCaptureTestCase
@ -55,7 +60,7 @@ class DatabaseTest(LogCaptureTestCase):
elif Fail2BanDb is None:
return
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
self.db = Fail2BanDb(self.dbFilename)
self.db = TestFail2BanDb(self.dbFilename)
def tearDown(self):
"""Call after every test case."""
@ -66,7 +71,7 @@ class DatabaseTest(LogCaptureTestCase):
os.remove(self.dbFilename)
def testGetFilename(self):
if Fail2BanDb is None: # pragma: no cover
if Fail2BanDb is None or self.db.filename == ':memory:': # pragma: no cover
return
self.assertEqual(self.dbFilename, self.db.filename)
@ -88,7 +93,7 @@ class DatabaseTest(LogCaptureTestCase):
"/this/path/should/not/exist")
def testCreateAndReconnect(self):
if Fail2BanDb is None: # pragma: no cover
if Fail2BanDb is None or self.db.filename == ':memory:': # pragma: no cover
return
self.testAddJail()
# Reconnect...

View File

@ -39,28 +39,27 @@ class DummyJail(Jail, object):
self.__actions = Actions(self)
def __len__(self):
try:
self.lock.acquire()
with self.lock:
return len(self.queue)
finally:
self.lock.release()
def isEmpty(self):
with self.lock:
return not self.queue
def isFilled(self):
with self.lock:
return bool(self.queue)
def putFailTicket(self, ticket):
try:
self.lock.acquire()
with self.lock:
self.queue.append(ticket)
finally:
self.lock.release()
def getFailTicket(self):
try:
self.lock.acquire()
with self.lock:
try:
return self.queue.pop()
except IndexError:
return False
finally:
self.lock.release()
@property
def name(self):

View File

@ -41,6 +41,7 @@ from ..server.filterpoll import FilterPoll
from ..server.filter import Filter, FileFilter, FileContainer, DNSUtils
from ..server.failmanager import FailManagerEmpty
from ..server.mytime import MyTime
from ..server.utils import Utils
from .utils import setUpMyTime, tearDownMyTime, mtimesleep, LogCaptureTestCase
from .dummyjail import DummyJail
@ -158,7 +159,7 @@ def _copy_lines_between_files(in_, fout, n=None, skip=0, mode='a', terminal_line
# Opened earlier, therefore must close it
fin.close()
# to give other threads possibly some time to crunch
time.sleep(0.1)
time.sleep(Utils.DEFAULT_SLEEP_INTERVAL)
return fout
@ -287,6 +288,11 @@ class IgnoreIP(LogCaptureTestCase):
class IgnoreIPDNS(IgnoreIP):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
IgnoreIP.setUp(self)
def testIgnoreIPDNSOK(self):
self.filter.addIgnoreIP("www.epfl.ch")
self.assertTrue(self.filter.inIgnoreIPList("128.178.50.12"))
@ -413,16 +419,11 @@ class LogFileMonitor(LogCaptureTestCase):
def isModified(self, delay=2.):
"""Wait up to `delay` sec to assure that it was modified or not
"""
time0 = time.time()
while time.time() < time0 + delay:
if self.filter.isModified(self.name):
return True
time.sleep(0.1)
return False
return Utils.wait_for(lambda: self.filter.isModified(self.name), delay)
def notModified(self):
# shorter wait time for not modified status
return not self.isModified(0.4)
return not self.isModified(4*Utils.DEFAULT_SLEEP_TIME)
def testUnaccessibleLogFile(self):
os.chmod(self.name, 0)
@ -571,26 +572,21 @@ def get_monitor_failures_testcase(Filter_):
#time.sleep(0.2) # Give FS time to ack the removal
pass
def isFilled(self, delay=2.):
def isFilled(self, delay=1.):
"""Wait up to `delay` sec to assure that it was modified or not
"""
time0 = time.time()
while time.time() < time0 + delay:
if len(self.jail):
return True
time.sleep(0.1)
return False
return Utils.wait_for(lambda: self.jail.isFilled(), delay)
def _sleep_4_poll(self):
# Since FilterPoll relies on time stamps and some
# actions might be happening too fast in the tests,
# sleep a bit to guarantee reliable time stamps
if isinstance(self.filter, FilterPoll):
mtimesleep()
Utils.wait_for(lambda: self.filter.is_alive(), 4*Utils.DEFAULT_SLEEP_TIME)
def isEmpty(self, delay=0.4):
def isEmpty(self, delay=4*Utils.DEFAULT_SLEEP_TIME):
# shorter wait time for not modified status
return not self.isFilled(delay)
return Utils.wait_for(lambda: self.jail.isEmpty(), delay)
def assert_correct_last_attempt(self, failures, count=None):
self.assertTrue(self.isFilled(20)) # give Filter a chance to react
@ -645,10 +641,11 @@ def get_monitor_failures_testcase(Filter_):
self.file = _copy_lines_between_files(GetFailures.FILENAME_01, self.name,
n=14, mode='w')
# Poll might need more time
self.assertTrue(self.isEmpty(4 + int(isinstance(self.filter, FilterPoll))*2),
self.assertTrue(self.isEmpty(min(4, 100 * Utils.DEFAULT_SLEEP_TIME)),
"Queue must be empty but it is not: %s."
% (', '.join([str(x) for x in self.jail.queue])))
self.assertRaises(FailManagerEmpty, self.filter.failManager.toBan)
Utils.wait_for(lambda: self.filter.failManager.getFailTotal() == 2, 50 * Utils.DEFAULT_SLEEP_TIME)
self.assertEqual(self.filter.failManager.getFailTotal(), 2)
# move aside, but leaving the handle still open...
@ -673,7 +670,7 @@ def get_monitor_failures_testcase(Filter_):
if interim_kill:
_killfile(None, self.name)
time.sleep(0.2) # let them know
time.sleep(Utils.DEFAULT_SLEEP_TIME) # let them know
# now create a new one to override old one
_copy_lines_between_files(GetFailures.FILENAME_01, self.name + '.new',
@ -720,7 +717,7 @@ def get_monitor_failures_testcase(Filter_):
_copy_lines_between_files(GetFailures.FILENAME_01, self.file, n=100)
# so we should get no more failures detected
self.assertTrue(self.isEmpty(2))
self.assertTrue(self.isEmpty(200 * Utils.DEFAULT_SLEEP_TIME))
# but then if we add it back again
self.filter.addLogPath(self.name)
@ -777,19 +774,14 @@ def get_monitor_failures_journal_testcase(Filter_): # pragma: systemd no cover
return "MonitorJournalFailures%s(%s)" \
% (Filter_, hasattr(self, 'name') and self.name or 'tempfile')
def isFilled(self, delay=2.):
def isFilled(self, delay=1.):
"""Wait up to `delay` sec to assure that it was modified or not
"""
time0 = time.time()
while time.time() < time0 + delay:
if len(self.jail):
return True
time.sleep(0.1)
return False
return Utils.wait_for(lambda: self.jail.isFilled(), delay)
def isEmpty(self, delay=0.4):
def isEmpty(self, delay=4*Utils.DEFAULT_SLEEP_TIME):
# shorter wait time for not modified status
return not self.isFilled(delay)
return Utils.wait_for(lambda: self.jail.isEmpty(), delay)
def assert_correct_ban(self, test_ip, test_attempts):
self.assertTrue(self.isFilled(10)) # give Filter a chance to react
@ -848,7 +840,7 @@ def get_monitor_failures_journal_testcase(Filter_): # pragma: systemd no cover
_copy_lines_to_journal(
self.test_file, self.journal_fields, n=5, skip=5)
# so we should get no more failures detected
self.assertTrue(self.isEmpty(2))
self.assertTrue(self.isEmpty(200 * Utils.DEFAULT_SLEEP_TIME))
# but then if we add it back again
self.filter.addJournalMatch([
@ -879,6 +871,7 @@ class GetFailures(unittest.TestCase):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
setUpMyTime()
self.jail = DummyJail()
self.filter = FileFilter(self.jail)
@ -1065,6 +1058,10 @@ class GetFailures(unittest.TestCase):
class DNSUtilsTests(unittest.TestCase):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
def testUseDns(self):
res = DNSUtils.textToIp('www.example.com', 'no')
self.assertEqual(res, [])
@ -1090,6 +1087,7 @@ class DNSUtilsTests(unittest.TestCase):
def testIpToName(self):
res = DNSUtils.ipToName('8.8.4.4')
self.assertEqual(res, 'google-public-dns-b.google.com')
unittest.F2B.SkipIfNoNetwork()
# invalid ip (TEST-NET-1 according to RFC 5737)
res = DNSUtils.ipToName('192.0.2.0')
self.assertEqual(res, None)

View File

@ -36,6 +36,7 @@ from ..server.failregex import Regex, FailRegex, RegexException
from ..server.server import Server
from ..server.jail import Jail
from ..server.jailthread import JailThread
from ..server.utils import Utils
from .utils import LogCaptureTestCase
from ..helpers import getLogger
from .. import version
@ -74,14 +75,14 @@ class TransmitterBase(unittest.TestCase):
"""Call after every test case."""
self.server.quit()
def setGetTest(self, cmd, inValue, outValue=None, outCode=0, jail=None, repr_=False):
def setGetTest(self, cmd, inValue, outValue=(None,), outCode=0, jail=None, repr_=False):
setCmd = ["set", cmd, inValue]
getCmd = ["get", cmd]
if jail is not None:
setCmd.insert(1, jail)
getCmd.insert(1, jail)
if outValue is None:
if outValue == (None,):
outValue = inValue
def v(x):
@ -165,15 +166,21 @@ class Transmitter(TransmitterBase):
self.assertEqual(self.transm.proceed(["version"]), (0, version.version))
def testSleep(self):
t0 = time.time()
self.assertEqual(self.transm.proceed(["sleep", "1"]), (0, None))
t1 = time.time()
# Approx 1 second delay but not faster
dt = t1 - t0
self.assertTrue(0.99 < dt < 1.1, msg="Sleep was %g sec" % dt)
if not unittest.F2B.fast:
t0 = time.time()
self.assertEqual(self.transm.proceed(["sleep", "0.1"]), (0, None))
t1 = time.time()
# Approx 0.1 second delay but not faster
dt = t1 - t0
self.assertTrue(0.09 < dt < 0.2, msg="Sleep was %g sec" % dt)
else: # pragma: no cover
self.assertEqual(self.transm.proceed(["sleep", "0.0001"]), (0, None))
def testDatabase(self):
tmp, tmpFilename = tempfile.mkstemp(".db", "fail2ban_")
if not unittest.F2B.fast:
tmp, tmpFilename = tempfile.mkstemp(".db", "fail2ban_")
else: # pragma: no cover
tmpFilename = ':memory:'
# Jails present, can't change database
self.setGetTestNOK("dbfile", tmpFilename)
self.server.delJail(self.jailName)
@ -205,8 +212,9 @@ class Transmitter(TransmitterBase):
self.assertEqual(self.transm.proceed(
["set", "dbfile", "None"]),
(0, None))
os.close(tmp)
os.unlink(tmpFilename)
if not unittest.F2B.fast:
os.close(tmp)
os.unlink(tmpFilename)
def testAddJail(self):
jail2 = "TestJail2"
@ -229,7 +237,11 @@ class Transmitter(TransmitterBase):
def testStartStopJail(self):
self.assertEqual(
self.transm.proceed(["start", self.jailName]), (0, None))
time.sleep(1)
time.sleep(Utils.DEFAULT_SLEEP_TIME)
# wait until not started (3 seconds as long as any RuntimeError, ex.: RuntimeError('cannot join thread before it is started',)):
self.assertTrue( Utils.wait_for(
lambda: self.server.is_alive(1) and not isinstance(self.transm.proceed(["status", self.jailName]), RuntimeError),
3) )
self.assertEqual(
self.transm.proceed(["stop", self.jailName]), (0, None))
self.assertTrue(self.jailName not in self.server._Server__jails)
@ -243,9 +255,12 @@ class Transmitter(TransmitterBase):
# yoh: workaround for gh-146. I still think that there is some
# race condition and missing locking somewhere, but for now
# giving it a small delay reliably helps to proceed with tests
time.sleep(0.1)
time.sleep(Utils.DEFAULT_SLEEP_TIME)
self.assertTrue( Utils.wait_for(
lambda: self.server.is_alive(2) and not isinstance(self.transm.proceed(["status", self.jailName]), RuntimeError),
3) )
self.assertEqual(self.transm.proceed(["stop", "all"]), (0, None))
time.sleep(1)
self.assertTrue( Utils.wait_for( lambda: not len(self.server._Server__jails), 3) )
self.assertTrue(self.jailName not in self.server._Server__jails)
self.assertTrue("TestJail2" not in self.server._Server__jails)
@ -301,11 +316,11 @@ class Transmitter(TransmitterBase):
self.assertEqual(
self.transm.proceed(["set", self.jailName, "banip", "127.0.0.1"]),
(0, "127.0.0.1"))
time.sleep(1) # Give chance to ban
time.sleep(Utils.DEFAULT_SLEEP_TIME) # Give chance to ban
self.assertEqual(
self.transm.proceed(["set", self.jailName, "banip", "Badger"]),
(0, "Badger")) #NOTE: Is IP address validated? Is DNS Lookup done?
time.sleep(1) # Give chance to ban
time.sleep(Utils.DEFAULT_SLEEP_TIME) # Give chance to ban
# Unban IP
self.assertEqual(
self.transm.proceed(

View File

@ -33,6 +33,7 @@ import unittest
from .. import protocol
from ..server.asyncserver import AsyncServer, AsyncServerException
from ..server.utils import Utils
from ..client.csocket import CSocket
@ -54,14 +55,20 @@ class Socket(unittest.TestCase):
"""Test transmitter proceed method which just returns first arg"""
return message
def _serverSocket(self):
try:
return CSocket(self.sock_name)
except Exception as e:
return None
def testSocket(self):
serverThread = threading.Thread(
target=self.server.start, args=(self.sock_name, False))
serverThread.daemon = True
serverThread.start()
time.sleep(1)
time.sleep(Utils.DEFAULT_SLEEP_TIME)
client = CSocket(self.sock_name)
client = Utils.wait_for(self._serverSocket, 2)
testMessage = ["A", "test", "message"]
self.assertEqual(client.send(testMessage), testMessage)
@ -71,7 +78,7 @@ class Socket(unittest.TestCase):
client.close()
self.server.stop()
serverThread.join(1)
serverThread.join(Utils.DEFAULT_SLEEP_TIME)
self.assertFalse(os.path.exists(self.sock_name))
def testSocketForce(self):
@ -85,10 +92,10 @@ class Socket(unittest.TestCase):
target=self.server.start, args=(self.sock_name, True))
serverThread.daemon = True
serverThread.start()
time.sleep(1)
time.sleep(Utils.DEFAULT_SLEEP_TIME)
self.server.stop()
serverThread.join(1)
serverThread.join(Utils.DEFAULT_SLEEP_TIME)
self.assertFalse(os.path.exists(self.sock_name))

View File

@ -30,8 +30,10 @@ import time
import unittest
from StringIO import StringIO
from ..server.mytime import MyTime
from ..helpers import getLogger
from ..server.filter import DNSUtils
from ..server.mytime import MyTime
from ..server.utils import Utils
logSys = getLogger(__name__)
@ -45,6 +47,43 @@ if not CONFIG_DIR:
CONFIG_DIR = '/etc/fail2ban'
class F2B():
def __init__(self, fast=False, no_network=False):
self.fast=fast
self.no_network=no_network
def SkipIfFast(self):
pass
def SkipIfNoNetwork(self):
pass
def initTests(opts):
if opts: # pragma: no cover
unittest.F2B = F2B(opts.fast, opts.no_network)
else:
unittest.F2B = F2B()
# --fast :
if unittest.F2B.fast: # pragma: no cover
# prevent long sleeping during test cases...
Utils.DEFAULT_SLEEP_TIME = 0.0025
Utils.DEFAULT_SLEEP_INTERVAL = 0.0005
def F2B_SkipIfFast():
raise unittest.SkipTest('Skip test because of "--fast"')
unittest.F2B.SkipIfFast = F2B_SkipIfFast
else:
# sleep intervals are large - use replacement for sleep to check time to sleep:
_org_sleep = time.sleep
def _new_sleep(v):
if (v > Utils.DEFAULT_SLEEP_TIME):
raise ValueError('[BAD-CODE] To long sleep interval: %s, try to use conditional Utils.wait_for instead' % v)
_org_sleep(min(v, Utils.DEFAULT_SLEEP_TIME))
time.sleep = _new_sleep
# --no-network :
if unittest.F2B.no_network: # pragma: no cover
def F2B_SkipIfNoNetwork():
raise unittest.SkipTest('Skip test because of "--no-network"')
unittest.F2B.SkipIfNoNetwork = F2B_SkipIfNoNetwork
def mtimesleep():
# no sleep now should be necessary since polling tracks now not only
# mtime but also ino and size
@ -70,7 +109,8 @@ def tearDownMyTime():
MyTime.myTime = None
def gatherTests(regexps=None, no_network=False):
def gatherTests(regexps=None, opts=None):
initTests(opts)
# Import all the test cases here instead of a module level to
# avoid circular imports
from . import banmanagertestcase
@ -142,10 +182,10 @@ def gatherTests(regexps=None, no_network=False):
tests.addTest(unittest.makeSuite(filtertestcase.LogFile))
tests.addTest(unittest.makeSuite(filtertestcase.LogFileMonitor))
tests.addTest(unittest.makeSuite(filtertestcase.LogFileFilterPoll))
if not no_network:
tests.addTest(unittest.makeSuite(filtertestcase.IgnoreIPDNS))
tests.addTest(unittest.makeSuite(filtertestcase.GetFailures))
tests.addTest(unittest.makeSuite(filtertestcase.DNSUtilsTests))
# each test case class self will check no network, and skip it (we see it in log)
tests.addTest(unittest.makeSuite(filtertestcase.IgnoreIPDNS))
tests.addTest(unittest.makeSuite(filtertestcase.GetFailures))
tests.addTest(unittest.makeSuite(filtertestcase.DNSUtilsTests))
tests.addTest(unittest.makeSuite(filtertestcase.JailTests))
# DateDetector
@ -161,9 +201,6 @@ def gatherTests(regexps=None, no_network=False):
for file_ in os.listdir(
os.path.abspath(os.path.dirname(action_d.__file__))):
if file_.startswith("test_") and file_.endswith(".py"):
if no_network and file_ in ['test_badips.py','test_smtp.py']: #pragma: no cover
# Test required network
continue
tests.addTest(testloader.loadTestsFromName(
"%s.%s" % (action_d.__name__, os.path.splitext(file_)[0])))
@ -178,6 +215,9 @@ def gatherTests(regexps=None, no_network=False):
# yoh: Since I do not know better way for parametric tests
# with good old unittest
try:
# because gamin can be very slow on some platforms (and can produce many failures
# with fast sleep interval) - skip it by fast run:
unittest.F2B.SkipIfFast()
from ..server.filtergamin import FilterGamin
filters.append(FilterGamin)
except Exception, e: # pragma: no cover
@ -272,29 +312,4 @@ class LogCaptureTestCase(unittest.TestCase):
def printLog(self):
print(self._log.getvalue())
# Solution from http://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid
# under cc by-sa 3.0
if os.name == 'posix':
def pid_exists(pid):
"""Check whether pid exists in the current process table."""
import errno
if pid < 0:
return False
try:
os.kill(pid, 0)
except OSError as e:
return e.errno == errno.EPERM
else:
return True
else:
def pid_exists(pid):
import ctypes
kernel32 = ctypes.windll.kernel32
SYNCHRONIZE = 0x100000
process = kernel32.OpenProcess(SYNCHRONIZE, 0, pid)
if process != 0:
kernel32.CloseHandle(process)
return True
else:
return False
pid_exists = Utils.pid_exists