mirror of https://github.com/fail2ban/fail2ban
actions: bug fix in lambdas in checkBan, because getBansMerged could return None (purge resp. asynchronous addBan), make the logic all around more stable;
test cases: extended with test to check action together with database functionality (ex.: to verify lambdas in checkBan); database: getBansMerged should work within lock, using reentrant lock (cause call of getBans inside of getBansMerged);pull/716/head
parent
6c2937affc
commit
2b38d46fb5
|
@ -246,6 +246,44 @@ class Actions(JailThread, Mapping):
|
||||||
logSys.debug(self._jail.name + ": action terminated")
|
logSys.debug(self._jail.name + ": action terminated")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def __getBansMerged(self, mi, idx):
|
||||||
|
"""Helper for lamda to get bans merged once
|
||||||
|
|
||||||
|
This function never returns None for ainfo lambdas - always a ticket (merged or single one)
|
||||||
|
and prevents any errors through merging (to guarantee ban actions will be executed).
|
||||||
|
[TODO] move merging to observer - here we could wait for merge and read already merged info from a database
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
mi : dict
|
||||||
|
initial for lambda should contains {ip, ticket}
|
||||||
|
idx : str
|
||||||
|
key to get a merged bans :
|
||||||
|
'all' - bans merged for all jails
|
||||||
|
'jail' - bans merged for current jail only
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BanTicket
|
||||||
|
merged or self ticket only
|
||||||
|
"""
|
||||||
|
if idx in mi:
|
||||||
|
return mi[idx] if mi[idx] is not None else mi['ticket']
|
||||||
|
try:
|
||||||
|
jail=self._jail
|
||||||
|
ip=mi['ip']
|
||||||
|
mi[idx] = None
|
||||||
|
if idx == 'all':
|
||||||
|
mi[idx] = jail.database.getBansMerged(ip=ip)
|
||||||
|
elif idx == 'jail':
|
||||||
|
mi[idx] = jail.database.getBansMerged(ip=ip, jail=jail)
|
||||||
|
except Exception as e:
|
||||||
|
logSys.error(
|
||||||
|
"Failed to get %s bans merged, jail '%s': %s",
|
||||||
|
idx, jail.name, e,
|
||||||
|
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
|
||||||
|
return mi[idx] if mi[idx] is not None else mi['ticket']
|
||||||
|
|
||||||
def __checkBan(self):
|
def __checkBan(self):
|
||||||
"""Check for IP address to ban.
|
"""Check for IP address to ban.
|
||||||
|
|
||||||
|
@ -272,16 +310,13 @@ class Actions(JailThread, Mapping):
|
||||||
aInfo["time"] = bTicket.getTime()
|
aInfo["time"] = bTicket.getTime()
|
||||||
aInfo["matches"] = "\n".join(bTicket.getMatches())
|
aInfo["matches"] = "\n".join(bTicket.getMatches())
|
||||||
btime = bTicket.getBanTime(self.__banManager.getBanTime())
|
btime = bTicket.getBanTime(self.__banManager.getBanTime())
|
||||||
# [todo] move merging to observer - here we could read already merged info from database (faster);
|
# retarded merge info via twice lambdas : once for merge, once for matches/failures:
|
||||||
if self._jail.database is not None:
|
if self._jail.database is not None:
|
||||||
aInfo["ipmatches"] = lambda jail=self._jail: "\n".join(
|
mi4ip = lambda idx, self=self, mi={'ip':ip, 'ticket':bTicket}: self.__getBansMerged(mi, idx)
|
||||||
jail.database.getBansMerged(ip=ip).getMatches())
|
aInfo["ipmatches"] = lambda: "\n".join(mi4ip('all').getMatches())
|
||||||
aInfo["ipjailmatches"] = lambda jail=self._jail: "\n".join(
|
aInfo["ipjailmatches"] = lambda: "\n".join(mi4ip('jail').getMatches())
|
||||||
jail.database.getBansMerged(ip=ip, jail=jail).getMatches())
|
aInfo["ipfailures"] = lambda: mi4ip('all').getAttempt()
|
||||||
aInfo["ipfailures"] = lambda jail=self._jail: \
|
aInfo["ipjailfailures"] = lambda: mi4ip('jail').getAttempt()
|
||||||
jail.database.getBansMerged(ip=ip).getAttempt()
|
|
||||||
aInfo["ipjailfailures"] = lambda jail=self._jail: \
|
|
||||||
jail.database.getBansMerged(ip=ip, jail=jail).getAttempt()
|
|
||||||
|
|
||||||
if btime != -1:
|
if btime != -1:
|
||||||
bendtime = aInfo["time"] + btime
|
bendtime = aInfo["time"] + btime
|
||||||
|
|
|
@ -27,7 +27,7 @@ import sqlite3
|
||||||
import json
|
import json
|
||||||
import locale
|
import locale
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from threading import Lock
|
from threading import RLock
|
||||||
|
|
||||||
from .mytime import MyTime
|
from .mytime import MyTime
|
||||||
from .ticket import FailTicket
|
from .ticket import FailTicket
|
||||||
|
@ -138,7 +138,7 @@ class Fail2BanDb(object):
|
||||||
|
|
||||||
def __init__(self, filename, purgeAge=24*60*60, outDatedFactor=3):
|
def __init__(self, filename, purgeAge=24*60*60, outDatedFactor=3):
|
||||||
try:
|
try:
|
||||||
self._lock = Lock()
|
self._lock = RLock()
|
||||||
self._db = sqlite3.connect(
|
self._db = sqlite3.connect(
|
||||||
filename, check_same_thread=False,
|
filename, check_same_thread=False,
|
||||||
detect_types=sqlite3.PARSE_DECLTYPES)
|
detect_types=sqlite3.PARSE_DECLTYPES)
|
||||||
|
@ -397,6 +397,10 @@ class Fail2BanDb(object):
|
||||||
del self._bansMergedCache[(ticket.getIP(), jail)]
|
del self._bansMergedCache[(ticket.getIP(), jail)]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
del self._bansMergedCache[(ticket.getIP(), None)]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
#TODO: Implement data parts once arbitrary match keys completed
|
#TODO: Implement data parts once arbitrary match keys completed
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"INSERT INTO bans(jail, ip, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)",
|
"INSERT INTO bans(jail, ip, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)",
|
||||||
|
@ -496,40 +500,41 @@ class Fail2BanDb(object):
|
||||||
in a list. When `ip` argument passed, a single `Ticket` is
|
in a list. When `ip` argument passed, a single `Ticket` is
|
||||||
returned.
|
returned.
|
||||||
"""
|
"""
|
||||||
cacheKey = None
|
with self._lock:
|
||||||
if bantime is None or bantime < 0:
|
cacheKey = None
|
||||||
cacheKey = (ip, jail)
|
if bantime is None or bantime < 0:
|
||||||
if cacheKey in self._bansMergedCache:
|
cacheKey = (ip, jail)
|
||||||
return self._bansMergedCache[cacheKey]
|
if cacheKey in self._bansMergedCache:
|
||||||
|
return self._bansMergedCache[cacheKey]
|
||||||
|
|
||||||
tickets = []
|
tickets = []
|
||||||
ticket = None
|
ticket = None
|
||||||
|
|
||||||
results = list(self._getBans(ip=ip, jail=jail, bantime=bantime))
|
results = list(self._getBans(ip=ip, jail=jail, bantime=bantime))
|
||||||
if results:
|
if results:
|
||||||
prev_banip = results[0][0]
|
prev_banip = results[0][0]
|
||||||
matches = []
|
matches = []
|
||||||
failures = 0
|
failures = 0
|
||||||
for banip, timeofban, data in results:
|
for banip, timeofban, data in results:
|
||||||
#TODO: Implement data parts once arbitrary match keys completed
|
#TODO: Implement data parts once arbitrary match keys completed
|
||||||
if banip != prev_banip:
|
if banip != prev_banip:
|
||||||
ticket = FailTicket(prev_banip, prev_timeofban, matches)
|
ticket = FailTicket(prev_banip, prev_timeofban, matches)
|
||||||
ticket.setAttempt(failures)
|
ticket.setAttempt(failures)
|
||||||
tickets.append(ticket)
|
tickets.append(ticket)
|
||||||
# Reset variables
|
# Reset variables
|
||||||
prev_banip = banip
|
prev_banip = banip
|
||||||
matches = []
|
matches = []
|
||||||
failures = 0
|
failures = 0
|
||||||
matches.extend(data['matches'])
|
matches.extend(data['matches'])
|
||||||
failures += data['failures']
|
failures += data['failures']
|
||||||
prev_timeofban = timeofban
|
prev_timeofban = timeofban
|
||||||
ticket = FailTicket(banip, prev_timeofban, matches)
|
ticket = FailTicket(banip, prev_timeofban, matches)
|
||||||
ticket.setAttempt(failures)
|
ticket.setAttempt(failures)
|
||||||
tickets.append(ticket)
|
tickets.append(ticket)
|
||||||
|
|
||||||
if cacheKey:
|
if cacheKey:
|
||||||
self._bansMergedCache[cacheKey] = tickets if ip is None else ticket
|
self._bansMergedCache[cacheKey] = tickets if ip is None else ticket
|
||||||
return tickets if ip is None else ticket
|
return tickets if ip is None else ticket
|
||||||
|
|
||||||
@commitandrollback
|
@commitandrollback
|
||||||
def getBan(self, cur, ip, jail=None, forbantime=None, overalljails=None, fromtime=None):
|
def getBan(self, cur, ip, jail=None, forbantime=None, overalljails=None, fromtime=None):
|
||||||
|
|
|
@ -32,18 +32,21 @@ import shutil
|
||||||
from ..server.filter import FileContainer
|
from ..server.filter import FileContainer
|
||||||
from ..server.mytime import MyTime
|
from ..server.mytime import MyTime
|
||||||
from ..server.ticket import FailTicket
|
from ..server.ticket import FailTicket
|
||||||
|
from ..server.actions import Actions
|
||||||
from .dummyjail import DummyJail
|
from .dummyjail import DummyJail
|
||||||
try:
|
try:
|
||||||
from ..server.database import Fail2BanDb
|
from ..server.database import Fail2BanDb
|
||||||
except ImportError:
|
except ImportError:
|
||||||
Fail2BanDb = None
|
Fail2BanDb = None
|
||||||
|
from .utils import LogCaptureTestCase
|
||||||
|
|
||||||
TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files")
|
TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files")
|
||||||
|
|
||||||
class DatabaseTest(unittest.TestCase):
|
class DatabaseTest(LogCaptureTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Call before every test case."""
|
"""Call before every test case."""
|
||||||
|
super(DatabaseTest, self).setUp()
|
||||||
if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover
|
if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover
|
||||||
raise unittest.SkipTest(
|
raise unittest.SkipTest(
|
||||||
"Unable to import fail2ban database module as sqlite is not "
|
"Unable to import fail2ban database module as sqlite is not "
|
||||||
|
@ -55,6 +58,7 @@ class DatabaseTest(unittest.TestCase):
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Call after every test case."""
|
"""Call after every test case."""
|
||||||
|
super(DatabaseTest, self).tearDown()
|
||||||
if Fail2BanDb is None: # pragma: no cover
|
if Fail2BanDb is None: # pragma: no cover
|
||||||
return
|
return
|
||||||
# Cleanup
|
# Cleanup
|
||||||
|
@ -267,6 +271,23 @@ class DatabaseTest(unittest.TestCase):
|
||||||
tickets = self.db.getBansMerged(bantime=-1)
|
tickets = self.db.getBansMerged(bantime=-1)
|
||||||
self.assertEqual(len(tickets), 2)
|
self.assertEqual(len(tickets), 2)
|
||||||
|
|
||||||
|
def testActionWithDB(self):
|
||||||
|
# test action together with database functionality
|
||||||
|
self.testAddJail() # Jail required
|
||||||
|
self.jail.database = self.db;
|
||||||
|
actions = Actions(self.jail)
|
||||||
|
actions.add(
|
||||||
|
"action_checkainfo",
|
||||||
|
os.path.join(TEST_FILES_DIR, "action.d/action_checkainfo.py"),
|
||||||
|
{})
|
||||||
|
ticket = FailTicket("1.2.3.4")
|
||||||
|
ticket.setAttempt(5)
|
||||||
|
ticket.setMatches(['test', 'test'])
|
||||||
|
self.jail.putFailTicket(ticket)
|
||||||
|
actions._Actions__checkBan()
|
||||||
|
self.assertTrue(self._is_logged("ban ainfo %s, %s, %s, %s" % (True, True, True, True)))
|
||||||
|
|
||||||
|
|
||||||
def testPurge(self):
|
def testPurge(self):
|
||||||
if Fail2BanDb is None: # pragma: no cover
|
if Fail2BanDb is None: # pragma: no cover
|
||||||
return
|
return
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
|
||||||
|
from fail2ban.server.action import ActionBase
|
||||||
|
|
||||||
|
class TestAction(ActionBase):
|
||||||
|
|
||||||
|
def ban(self, aInfo):
|
||||||
|
self._logSys.info("ban ainfo %s, %s, %s, %s",
|
||||||
|
aInfo["ipmatches"] != '', aInfo["ipjailmatches"] != '', aInfo["ipfailures"] > 0, aInfo["ipjailfailures"] > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
def unban(self, aInfo):
|
||||||
|
pass
|
||||||
|
|
||||||
|
Action = TestAction
|
Loading…
Reference in New Issue