actions: bug fix in lambdas in checkBan, because getBansMerged could return None (purge resp. asynchronous addBan), make the logic all around more stable;

test cases: extended with test to check action together with database functionality (ex.: to verify lambdas in checkBan);
database: getBansMerged should work within lock, using reentrant lock (cause call of getBans inside of getBansMerged);
pull/839/head
sebres 2014-09-23 19:57:55 +02:00
parent 7acddcbe4a
commit 518cc92ccc
4 changed files with 116 additions and 42 deletions

View File

@ -243,6 +243,44 @@ class Actions(JailThread, Mapping):
logSys.debug(self._jail.name + ": action terminated") logSys.debug(self._jail.name + ": action terminated")
return True return True
def __getBansMerged(self, mi, idx):
"""Helper for lamda to get bans merged once
This function never returns None for ainfo lambdas - always a ticket (merged or single one)
and prevents any errors through merging (to guarantee ban actions will be executed).
[TODO] move merging to observer - here we could wait for merge and read already merged info from a database
Parameters
----------
mi : dict
initial for lambda should contains {ip, ticket}
idx : str
key to get a merged bans :
'all' - bans merged for all jails
'jail' - bans merged for current jail only
Returns
-------
BanTicket
merged or self ticket only
"""
if idx in mi:
return mi[idx] if mi[idx] is not None else mi['ticket']
try:
jail=self._jail
ip=mi['ip']
mi[idx] = None
if idx == 'all':
mi[idx] = jail.database.getBansMerged(ip=ip)
elif idx == 'jail':
mi[idx] = jail.database.getBansMerged(ip=ip, jail=jail)
except Exception as e:
logSys.error(
"Failed to get %s bans merged, jail '%s': %s",
idx, jail.name, e,
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
return mi[idx] if mi[idx] is not None else mi['ticket']
def __checkBan(self): def __checkBan(self):
"""Check for IP address to ban. """Check for IP address to ban.
@ -264,14 +302,11 @@ class Actions(JailThread, Mapping):
aInfo["time"] = bTicket.getTime() aInfo["time"] = bTicket.getTime()
aInfo["matches"] = "\n".join(bTicket.getMatches()) aInfo["matches"] = "\n".join(bTicket.getMatches())
if self._jail.database is not None: if self._jail.database is not None:
aInfo["ipmatches"] = lambda jail=self._jail: "\n".join( mi4ip = lambda idx, self=self, mi={'ip':ip, 'ticket':bTicket}: self.__getBansMerged(mi, idx)
jail.database.getBansMerged(ip=ip).getMatches()) aInfo["ipmatches"] = lambda: "\n".join(mi4ip('all').getMatches())
aInfo["ipjailmatches"] = lambda jail=self._jail: "\n".join( aInfo["ipjailmatches"] = lambda: "\n".join(mi4ip('jail').getMatches())
jail.database.getBansMerged(ip=ip, jail=jail).getMatches()) aInfo["ipfailures"] = lambda: mi4ip('all').getAttempt()
aInfo["ipfailures"] = lambda jail=self._jail: \ aInfo["ipjailfailures"] = lambda: mi4ip('jail').getAttempt()
jail.database.getBansMerged(ip=ip).getAttempt()
aInfo["ipjailfailures"] = lambda jail=self._jail: \
jail.database.getBansMerged(ip=ip, jail=jail).getAttempt()
if self.__banManager.addBanTicket(bTicket): if self.__banManager.addBanTicket(bTicket):
logSys.notice("[%s] Ban %s" % (self._jail.name, aInfo["ip"])) logSys.notice("[%s] Ban %s" % (self._jail.name, aInfo["ip"]))
for name, action in self._actions.iteritems(): for name, action in self._actions.iteritems():

View File

@ -27,7 +27,7 @@ import sqlite3
import json import json
import locale import locale
from functools import wraps from functools import wraps
from threading import Lock from threading import RLock
from .mytime import MyTime from .mytime import MyTime
from .ticket import FailTicket from .ticket import FailTicket
@ -123,7 +123,7 @@ class Fail2BanDb(object):
def __init__(self, filename, purgeAge=24*60*60): def __init__(self, filename, purgeAge=24*60*60):
try: try:
self._lock = Lock() self._lock = RLock()
self._db = sqlite3.connect( self._db = sqlite3.connect(
filename, check_same_thread=False, filename, check_same_thread=False,
detect_types=sqlite3.PARSE_DECLTYPES) detect_types=sqlite3.PARSE_DECLTYPES)
@ -365,6 +365,10 @@ class Fail2BanDb(object):
del self._bansMergedCache[(ticket.getIP(), jail)] del self._bansMergedCache[(ticket.getIP(), jail)]
except KeyError: except KeyError:
pass pass
try:
del self._bansMergedCache[(ticket.getIP(), None)]
except KeyError:
pass
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
cur.execute( cur.execute(
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)", "INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
@ -455,6 +459,7 @@ class Fail2BanDb(object):
in a list. When `ip` argument passed, a single `Ticket` is in a list. When `ip` argument passed, a single `Ticket` is
returned. returned.
""" """
with self._lock:
cacheKey = None cacheKey = None
if bantime is None or bantime < 0: if bantime is None or bantime < 0:
cacheKey = (ip, jail) cacheKey = (ip, jail)

View File

@ -32,18 +32,21 @@ import shutil
from ..server.filter import FileContainer from ..server.filter import FileContainer
from ..server.mytime import MyTime from ..server.mytime import MyTime
from ..server.ticket import FailTicket from ..server.ticket import FailTicket
from ..server.actions import Actions
from .dummyjail import DummyJail from .dummyjail import DummyJail
try: try:
from ..server.database import Fail2BanDb from ..server.database import Fail2BanDb
except ImportError: except ImportError:
Fail2BanDb = None Fail2BanDb = None
from .utils import LogCaptureTestCase
TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files") TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files")
class DatabaseTest(unittest.TestCase): class DatabaseTest(LogCaptureTestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
super(DatabaseTest, self).setUp()
if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover
raise unittest.SkipTest( raise unittest.SkipTest(
"Unable to import fail2ban database module as sqlite is not " "Unable to import fail2ban database module as sqlite is not "
@ -55,6 +58,7 @@ class DatabaseTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
super(DatabaseTest, self).tearDown()
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return
# Cleanup # Cleanup
@ -267,6 +271,22 @@ class DatabaseTest(unittest.TestCase):
tickets = self.db.getBansMerged(bantime=-1) tickets = self.db.getBansMerged(bantime=-1)
self.assertEqual(len(tickets), 2) self.assertEqual(len(tickets), 2)
def testActionWithDB(self):
# test action together with database functionality
self.testAddJail() # Jail required
self.jail.database = self.db;
actions = Actions(self.jail)
actions.add(
"action_checkainfo",
os.path.join(TEST_FILES_DIR, "action.d/action_checkainfo.py"),
{})
ticket = FailTicket("1.2.3.4", MyTime.time(), ['test', 'test'])
ticket.setAttempt(5)
self.jail.putFailTicket(ticket)
actions._Actions__checkBan()
self.assertTrue(self._is_logged("ban ainfo %s, %s, %s, %s" % (True, True, True, True)))
def testPurge(self): def testPurge(self):
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return

View File

@ -0,0 +1,14 @@
from fail2ban.server.action import ActionBase
class TestAction(ActionBase):
def ban(self, aInfo):
self._logSys.info("ban ainfo %s, %s, %s, %s",
aInfo["ipmatches"] != '', aInfo["ipjailmatches"] != '', aInfo["ipfailures"] > 0, aInfo["ipjailfailures"] > 0
)
def unban(self, aInfo):
pass
Action = TestAction