Merge pull request #839 from sebres/fix-none-getattempt-lambda

Fix none getattempt lambda (close #838,  close #848)
pull/868/head^2
Yaroslav Halchenko 10 years ago
commit a170afcb76

@ -243,6 +243,45 @@ class Actions(JailThread, Mapping):
logSys.debug(self._jail.name + ": action terminated") logSys.debug(self._jail.name + ": action terminated")
return True return True
def __getBansMerged(self, mi, overalljails=False):
"""Gets bans merged once, a helper for lambda(s), prevents stop of executing action by any exception inside.
This function never returns None for ainfo lambdas - always a ticket (merged or single one)
and prevents any errors through merging (to guarantee ban actions will be executed).
[TODO] move merging to observer - here we could wait for merge and read already merged info from a database
Parameters
----------
mi : dict
merge info, initial for lambda should contains {ip, ticket}
overalljails : bool
switch to get a merged bans :
False - (default) bans merged for current jail only
True - bans merged for all jails of current ip address
Returns
-------
BanTicket
merged or self ticket only
"""
idx = 'all' if overalljails else 'jail'
if idx in mi:
return mi[idx] if mi[idx] is not None else mi['ticket']
try:
jail=self._jail
ip=mi['ip']
mi[idx] = None
if overalljails:
mi[idx] = jail.database.getBansMerged(ip=ip)
else:
mi[idx] = jail.database.getBansMerged(ip=ip, jail=jail)
except Exception as e:
logSys.error(
"Failed to get %s bans merged, jail '%s': %s",
idx, jail.name, e,
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
return mi[idx] if mi[idx] is not None else mi['ticket']
def __checkBan(self): def __checkBan(self):
"""Check for IP address to ban. """Check for IP address to ban.
@ -264,14 +303,12 @@ class Actions(JailThread, Mapping):
aInfo["time"] = bTicket.getTime() aInfo["time"] = bTicket.getTime()
aInfo["matches"] = "\n".join(bTicket.getMatches()) aInfo["matches"] = "\n".join(bTicket.getMatches())
if self._jail.database is not None: if self._jail.database is not None:
aInfo["ipmatches"] = lambda jail=self._jail: "\n".join( mi4ip = lambda overalljails=False, self=self, \
jail.database.getBansMerged(ip=ip).getMatches()) mi={'ip':ip, 'ticket':bTicket}: self.__getBansMerged(mi, overalljails)
aInfo["ipjailmatches"] = lambda jail=self._jail: "\n".join( aInfo["ipmatches"] = lambda: "\n".join(mi4ip(True).getMatches())
jail.database.getBansMerged(ip=ip, jail=jail).getMatches()) aInfo["ipjailmatches"] = lambda: "\n".join(mi4ip().getMatches())
aInfo["ipfailures"] = lambda jail=self._jail: \ aInfo["ipfailures"] = lambda: mi4ip(True).getAttempt()
jail.database.getBansMerged(ip=ip).getAttempt() aInfo["ipjailfailures"] = lambda: mi4ip().getAttempt()
aInfo["ipjailfailures"] = lambda jail=self._jail: \
jail.database.getBansMerged(ip=ip, jail=jail).getAttempt()
if self.__banManager.addBanTicket(bTicket): if self.__banManager.addBanTicket(bTicket):
logSys.notice("[%s] Ban %s" % (self._jail.name, aInfo["ip"])) logSys.notice("[%s] Ban %s" % (self._jail.name, aInfo["ip"]))
for name, action in self._actions.iteritems(): for name, action in self._actions.iteritems():

@ -27,7 +27,7 @@ import sqlite3
import json import json
import locale import locale
from functools import wraps from functools import wraps
from threading import Lock from threading import RLock
from .mytime import MyTime from .mytime import MyTime
from .ticket import FailTicket from .ticket import FailTicket
@ -123,7 +123,7 @@ class Fail2BanDb(object):
def __init__(self, filename, purgeAge=24*60*60): def __init__(self, filename, purgeAge=24*60*60):
try: try:
self._lock = Lock() self._lock = RLock()
self._db = sqlite3.connect( self._db = sqlite3.connect(
filename, check_same_thread=False, filename, check_same_thread=False,
detect_types=sqlite3.PARSE_DECLTYPES) detect_types=sqlite3.PARSE_DECLTYPES)
@ -365,6 +365,10 @@ class Fail2BanDb(object):
del self._bansMergedCache[(ticket.getIP(), jail)] del self._bansMergedCache[(ticket.getIP(), jail)]
except KeyError: except KeyError:
pass pass
try:
del self._bansMergedCache[(ticket.getIP(), None)]
except KeyError:
pass
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
cur.execute( cur.execute(
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)", "INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
@ -455,6 +459,7 @@ class Fail2BanDb(object):
in a list. When `ip` argument passed, a single `Ticket` is in a list. When `ip` argument passed, a single `Ticket` is
returned. returned.
""" """
with self._lock:
cacheKey = None cacheKey = None
if bantime is None or bantime < 0: if bantime is None or bantime < 0:
cacheKey = (ip, jail) cacheKey = (ip, jail)

@ -32,18 +32,21 @@ import shutil
from ..server.filter import FileContainer from ..server.filter import FileContainer
from ..server.mytime import MyTime from ..server.mytime import MyTime
from ..server.ticket import FailTicket from ..server.ticket import FailTicket
from ..server.actions import Actions
from .dummyjail import DummyJail from .dummyjail import DummyJail
try: try:
from ..server.database import Fail2BanDb from ..server.database import Fail2BanDb
except ImportError: except ImportError:
Fail2BanDb = None Fail2BanDb = None
from .utils import LogCaptureTestCase
TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files") TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files")
class DatabaseTest(unittest.TestCase): class DatabaseTest(LogCaptureTestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
super(DatabaseTest, self).setUp()
if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover
raise unittest.SkipTest( raise unittest.SkipTest(
"Unable to import fail2ban database module as sqlite is not " "Unable to import fail2ban database module as sqlite is not "
@ -55,6 +58,7 @@ class DatabaseTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
super(DatabaseTest, self).tearDown()
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return
# Cleanup # Cleanup
@ -267,6 +271,22 @@ class DatabaseTest(unittest.TestCase):
tickets = self.db.getBansMerged(bantime=-1) tickets = self.db.getBansMerged(bantime=-1)
self.assertEqual(len(tickets), 2) self.assertEqual(len(tickets), 2)
def testActionWithDB(self):
# test action together with database functionality
self.testAddJail() # Jail required
self.jail.database = self.db;
actions = Actions(self.jail)
actions.add(
"action_checkainfo",
os.path.join(TEST_FILES_DIR, "action.d/action_checkainfo.py"),
{})
ticket = FailTicket("1.2.3.4", MyTime.time(), ['test', 'test'])
ticket.setAttempt(5)
self.jail.putFailTicket(ticket)
actions._Actions__checkBan()
self.assertTrue(self._is_logged("ban ainfo %s, %s, %s, %s" % (True, True, True, True)))
def testPurge(self): def testPurge(self):
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return

@ -0,0 +1,14 @@
from fail2ban.server.action import ActionBase
class TestAction(ActionBase):
def ban(self, aInfo):
self._logSys.info("ban ainfo %s, %s, %s, %s",
aInfo["ipmatches"] != '', aInfo["ipjailmatches"] != '', aInfo["ipfailures"] > 0, aInfo["ipjailfailures"] > 0
)
def unban(self, aInfo):
pass
Action = TestAction

@ -804,7 +804,7 @@ class RegexTests(unittest.TestCase):
class _BadThread(JailThread): class _BadThread(JailThread):
def run(self): def run(self):
int("ignore this exception -- raised for testing") raise RuntimeError('run bad thread exception')
class LoggingTests(LogCaptureTestCase): class LoggingTests(LogCaptureTestCase):
@ -814,7 +814,15 @@ class LoggingTests(LogCaptureTestCase):
self.assertEqual(testLogSys.name, "fail2ban.name") self.assertEqual(testLogSys.name, "fail2ban.name")
def testFail2BanExceptHook(self): def testFail2BanExceptHook(self):
prev_exchook = sys.__excepthook__
x = []
sys.__excepthook__ = lambda *args: x.append(args)
try:
badThread = _BadThread() badThread = _BadThread()
badThread.start() badThread.start()
badThread.join() badThread.join()
self.assertTrue(self._is_logged("Unhandled exception")) self.assertTrue(self._is_logged("Unhandled exception"))
finally:
sys.__excepthook__ = prev_exchook
self.assertEqual(len(x), 1)
self.assertEqual(x[0][0], RuntimeError)

Loading…
Cancel
Save