actions: bug fix in lambdas in checkBan, because getBansMerged could return None (purge resp. asynchronous addBan), make the logic all around more stable;

test cases: extended with test to check action together with database functionality (ex.: to verify lambdas in checkBan);
database: getBansMerged should work within lock, using reentrant lock (cause call of getBans inside of getBansMerged);
pull/716/head
sebres 2014-09-23 19:57:55 +02:00
parent 6c2937affc
commit 2b38d46fb5
4 changed files with 118 additions and 43 deletions

View File

@ -246,6 +246,44 @@ class Actions(JailThread, Mapping):
logSys.debug(self._jail.name + ": action terminated") logSys.debug(self._jail.name + ": action terminated")
return True return True
def __getBansMerged(self, mi, idx):
"""Helper for lamda to get bans merged once
This function never returns None for ainfo lambdas - always a ticket (merged or single one)
and prevents any errors through merging (to guarantee ban actions will be executed).
[TODO] move merging to observer - here we could wait for merge and read already merged info from a database
Parameters
----------
mi : dict
initial for lambda should contains {ip, ticket}
idx : str
key to get a merged bans :
'all' - bans merged for all jails
'jail' - bans merged for current jail only
Returns
-------
BanTicket
merged or self ticket only
"""
if idx in mi:
return mi[idx] if mi[idx] is not None else mi['ticket']
try:
jail=self._jail
ip=mi['ip']
mi[idx] = None
if idx == 'all':
mi[idx] = jail.database.getBansMerged(ip=ip)
elif idx == 'jail':
mi[idx] = jail.database.getBansMerged(ip=ip, jail=jail)
except Exception as e:
logSys.error(
"Failed to get %s bans merged, jail '%s': %s",
idx, jail.name, e,
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
return mi[idx] if mi[idx] is not None else mi['ticket']
def __checkBan(self): def __checkBan(self):
"""Check for IP address to ban. """Check for IP address to ban.
@ -272,16 +310,13 @@ class Actions(JailThread, Mapping):
aInfo["time"] = bTicket.getTime() aInfo["time"] = bTicket.getTime()
aInfo["matches"] = "\n".join(bTicket.getMatches()) aInfo["matches"] = "\n".join(bTicket.getMatches())
btime = bTicket.getBanTime(self.__banManager.getBanTime()) btime = bTicket.getBanTime(self.__banManager.getBanTime())
# [todo] move merging to observer - here we could read already merged info from database (faster); # retarded merge info via twice lambdas : once for merge, once for matches/failures:
if self._jail.database is not None: if self._jail.database is not None:
aInfo["ipmatches"] = lambda jail=self._jail: "\n".join( mi4ip = lambda idx, self=self, mi={'ip':ip, 'ticket':bTicket}: self.__getBansMerged(mi, idx)
jail.database.getBansMerged(ip=ip).getMatches()) aInfo["ipmatches"] = lambda: "\n".join(mi4ip('all').getMatches())
aInfo["ipjailmatches"] = lambda jail=self._jail: "\n".join( aInfo["ipjailmatches"] = lambda: "\n".join(mi4ip('jail').getMatches())
jail.database.getBansMerged(ip=ip, jail=jail).getMatches()) aInfo["ipfailures"] = lambda: mi4ip('all').getAttempt()
aInfo["ipfailures"] = lambda jail=self._jail: \ aInfo["ipjailfailures"] = lambda: mi4ip('jail').getAttempt()
jail.database.getBansMerged(ip=ip).getAttempt()
aInfo["ipjailfailures"] = lambda jail=self._jail: \
jail.database.getBansMerged(ip=ip, jail=jail).getAttempt()
if btime != -1: if btime != -1:
bendtime = aInfo["time"] + btime bendtime = aInfo["time"] + btime

View File

@ -27,7 +27,7 @@ import sqlite3
import json import json
import locale import locale
from functools import wraps from functools import wraps
from threading import Lock from threading import RLock
from .mytime import MyTime from .mytime import MyTime
from .ticket import FailTicket from .ticket import FailTicket
@ -138,7 +138,7 @@ class Fail2BanDb(object):
def __init__(self, filename, purgeAge=24*60*60, outDatedFactor=3): def __init__(self, filename, purgeAge=24*60*60, outDatedFactor=3):
try: try:
self._lock = Lock() self._lock = RLock()
self._db = sqlite3.connect( self._db = sqlite3.connect(
filename, check_same_thread=False, filename, check_same_thread=False,
detect_types=sqlite3.PARSE_DECLTYPES) detect_types=sqlite3.PARSE_DECLTYPES)
@ -397,6 +397,10 @@ class Fail2BanDb(object):
del self._bansMergedCache[(ticket.getIP(), jail)] del self._bansMergedCache[(ticket.getIP(), jail)]
except KeyError: except KeyError:
pass pass
try:
del self._bansMergedCache[(ticket.getIP(), None)]
except KeyError:
pass
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
cur.execute( cur.execute(
"INSERT INTO bans(jail, ip, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)", "INSERT INTO bans(jail, ip, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)",
@ -496,40 +500,41 @@ class Fail2BanDb(object):
in a list. When `ip` argument passed, a single `Ticket` is in a list. When `ip` argument passed, a single `Ticket` is
returned. returned.
""" """
cacheKey = None with self._lock:
if bantime is None or bantime < 0: cacheKey = None
cacheKey = (ip, jail) if bantime is None or bantime < 0:
if cacheKey in self._bansMergedCache: cacheKey = (ip, jail)
return self._bansMergedCache[cacheKey] if cacheKey in self._bansMergedCache:
return self._bansMergedCache[cacheKey]
tickets = [] tickets = []
ticket = None ticket = None
results = list(self._getBans(ip=ip, jail=jail, bantime=bantime)) results = list(self._getBans(ip=ip, jail=jail, bantime=bantime))
if results: if results:
prev_banip = results[0][0] prev_banip = results[0][0]
matches = [] matches = []
failures = 0 failures = 0
for banip, timeofban, data in results: for banip, timeofban, data in results:
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
if banip != prev_banip: if banip != prev_banip:
ticket = FailTicket(prev_banip, prev_timeofban, matches) ticket = FailTicket(prev_banip, prev_timeofban, matches)
ticket.setAttempt(failures) ticket.setAttempt(failures)
tickets.append(ticket) tickets.append(ticket)
# Reset variables # Reset variables
prev_banip = banip prev_banip = banip
matches = [] matches = []
failures = 0 failures = 0
matches.extend(data['matches']) matches.extend(data['matches'])
failures += data['failures'] failures += data['failures']
prev_timeofban = timeofban prev_timeofban = timeofban
ticket = FailTicket(banip, prev_timeofban, matches) ticket = FailTicket(banip, prev_timeofban, matches)
ticket.setAttempt(failures) ticket.setAttempt(failures)
tickets.append(ticket) tickets.append(ticket)
if cacheKey: if cacheKey:
self._bansMergedCache[cacheKey] = tickets if ip is None else ticket self._bansMergedCache[cacheKey] = tickets if ip is None else ticket
return tickets if ip is None else ticket return tickets if ip is None else ticket
@commitandrollback @commitandrollback
def getBan(self, cur, ip, jail=None, forbantime=None, overalljails=None, fromtime=None): def getBan(self, cur, ip, jail=None, forbantime=None, overalljails=None, fromtime=None):

View File

@ -32,18 +32,21 @@ import shutil
from ..server.filter import FileContainer from ..server.filter import FileContainer
from ..server.mytime import MyTime from ..server.mytime import MyTime
from ..server.ticket import FailTicket from ..server.ticket import FailTicket
from ..server.actions import Actions
from .dummyjail import DummyJail from .dummyjail import DummyJail
try: try:
from ..server.database import Fail2BanDb from ..server.database import Fail2BanDb
except ImportError: except ImportError:
Fail2BanDb = None Fail2BanDb = None
from .utils import LogCaptureTestCase
TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files") TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files")
class DatabaseTest(unittest.TestCase): class DatabaseTest(LogCaptureTestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
super(DatabaseTest, self).setUp()
if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover
raise unittest.SkipTest( raise unittest.SkipTest(
"Unable to import fail2ban database module as sqlite is not " "Unable to import fail2ban database module as sqlite is not "
@ -55,6 +58,7 @@ class DatabaseTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
super(DatabaseTest, self).tearDown()
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return
# Cleanup # Cleanup
@ -267,6 +271,23 @@ class DatabaseTest(unittest.TestCase):
tickets = self.db.getBansMerged(bantime=-1) tickets = self.db.getBansMerged(bantime=-1)
self.assertEqual(len(tickets), 2) self.assertEqual(len(tickets), 2)
def testActionWithDB(self):
# test action together with database functionality
self.testAddJail() # Jail required
self.jail.database = self.db;
actions = Actions(self.jail)
actions.add(
"action_checkainfo",
os.path.join(TEST_FILES_DIR, "action.d/action_checkainfo.py"),
{})
ticket = FailTicket("1.2.3.4")
ticket.setAttempt(5)
ticket.setMatches(['test', 'test'])
self.jail.putFailTicket(ticket)
actions._Actions__checkBan()
self.assertTrue(self._is_logged("ban ainfo %s, %s, %s, %s" % (True, True, True, True)))
def testPurge(self): def testPurge(self):
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return

View File

@ -0,0 +1,14 @@
from fail2ban.server.action import ActionBase
class TestAction(ActionBase):
def ban(self, aInfo):
self._logSys.info("ban ainfo %s, %s, %s, %s",
aInfo["ipmatches"] != '', aInfo["ipjailmatches"] != '', aInfo["ipfailures"] > 0, aInfo["ipjailfailures"] > 0
)
def unban(self, aInfo):
pass
Action = TestAction