mirror of https://github.com/fail2ban/fail2ban
rewritten CallingMap: performance optimized, immutable, self-referencing, template possibility (used in new ActionInfo objects);
new ActionInfo handling: saves content between actions, without interim copying (save original on demand, recoverable via reset); test cases extendedpull/1698/head
parent
4efcc29384
commit
d2a3d093c6
|
@ -123,7 +123,7 @@ class SMTPAction(ActionBase):
|
|||
self.message_values = CallingMap(
|
||||
jailname = self._jail.name,
|
||||
hostname = socket.gethostname,
|
||||
bantime = self._jail.actions.getBanTime,
|
||||
bantime = lambda: self._jail.actions.getBanTime(),
|
||||
)
|
||||
|
||||
# bypass ban/unban for restored tickets
|
||||
|
|
|
@ -50,9 +50,14 @@ allowed_ipv6 = True
|
|||
# capture groups from filter for map to ticket data:
|
||||
FCUSTAG_CRE = re.compile(r'<F-([A-Z0-9_\-]+)>'); # currently uppercase only
|
||||
|
||||
# New line, space
|
||||
ADD_REPL_TAGS = {
|
||||
"br": "\n",
|
||||
"sp": " "
|
||||
}
|
||||
|
||||
|
||||
class CallingMap(MutableMapping):
|
||||
class CallingMap(MutableMapping, object):
|
||||
"""A Mapping type which returns the result of callable values.
|
||||
|
||||
`CallingMap` behaves similar to a standard python dictionary,
|
||||
|
@ -69,23 +74,64 @@ class CallingMap(MutableMapping):
|
|||
The dictionary data which can be accessed to obtain items uncalled
|
||||
"""
|
||||
|
||||
# immutable=True saves content between actions, without interim copying (save original on demand, recoverable via reset)
|
||||
__slots__ = ('data', 'storage', 'immutable', '__org_data')
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.storage = dict()
|
||||
self.immutable = True
|
||||
self.data = dict(*args, **kwargs)
|
||||
|
||||
def reset(self, immutable=True):
|
||||
self.storage = dict()
|
||||
try:
|
||||
self.data = self.__org_data
|
||||
except AttributeError:
|
||||
pass
|
||||
self.immutable = immutable
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%r)" % (self.__class__.__name__, self.data)
|
||||
return "%s(%r)" % (self.__class__.__name__, self._asdict())
|
||||
|
||||
def _asdict(self):
|
||||
try:
|
||||
return dict(self)
|
||||
except:
|
||||
return dict(self.data, **self.storage)
|
||||
|
||||
def __getitem__(self, key):
|
||||
value = self.data[key]
|
||||
try:
|
||||
value = self.storage[key]
|
||||
except KeyError:
|
||||
value = self.data[key]
|
||||
if callable(value):
|
||||
value = value()
|
||||
self.data[key] = value
|
||||
# check arguments can be supplied to callable (for backwards compatibility):
|
||||
value = value(self) if hasattr(value, '__code__') and value.__code__.co_argcount else value()
|
||||
self.storage[key] = value
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[key] = value
|
||||
# mutate to copy:
|
||||
if self.immutable:
|
||||
self.storage = self.storage.copy()
|
||||
self.__org_data = self.data
|
||||
self.data = self.data.copy()
|
||||
self.immutable = False
|
||||
self.storage[key] = value
|
||||
|
||||
def __unavailable(self, key):
|
||||
raise KeyError("Key %r was deleted" % key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
# mutate to copy:
|
||||
if self.immutable:
|
||||
self.storage = self.storage.copy()
|
||||
self.__org_data = self.data
|
||||
self.data = self.data.copy()
|
||||
self.immutable = False
|
||||
try:
|
||||
del self.storage[key]
|
||||
except KeyError:
|
||||
pass
|
||||
del self.data[key]
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -94,7 +140,7 @@ class CallingMap(MutableMapping):
|
|||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def copy(self):
|
||||
def copy(self): # pargma: no cover
|
||||
return self.__class__(self.data.copy())
|
||||
|
||||
|
||||
|
@ -436,9 +482,6 @@ class CommandAction(ActionBase):
|
|||
# interpolation of dictionary:
|
||||
if subInfo is None:
|
||||
subInfo = substituteRecursiveTags(aInfo, conditional, ignore=cls._escapedTags)
|
||||
# New line, space
|
||||
for (tag, value) in (("br", '\n'), ("sp", " ")):
|
||||
if subInfo.get(tag) is None: subInfo[tag] = value
|
||||
# cache if possible:
|
||||
if csubkey is not None:
|
||||
cache[csubkey] = subInfo
|
||||
|
@ -453,7 +496,8 @@ class CommandAction(ActionBase):
|
|||
if value is None:
|
||||
value = subInfo.get(tag)
|
||||
if value is None:
|
||||
return m.group() # fallback (no replacement)
|
||||
# fallback (no or default replacement)
|
||||
return ADD_REPL_TAGS.get(tag, m.group())
|
||||
value = str(value) # assure string
|
||||
if tag in cls._escapedTags:
|
||||
# That one needs to be escaped since its content is
|
||||
|
|
|
@ -286,44 +286,86 @@ class Actions(JailThread, Mapping):
|
|||
self.stopActions()
|
||||
return True
|
||||
|
||||
def __getBansMerged(self, mi, overalljails=False):
|
||||
"""Gets bans merged once, a helper for lambda(s), prevents stop of executing action by any exception inside.
|
||||
class ActionInfo(CallingMap):
|
||||
|
||||
This function never returns None for ainfo lambdas - always a ticket (merged or single one)
|
||||
and prevents any errors through merging (to guarantee ban actions will be executed).
|
||||
[TODO] move merging to observer - here we could wait for merge and read already merged info from a database
|
||||
AI_DICT = {
|
||||
"ip": lambda self: self.__ticket.getIP(),
|
||||
"ip-rev": lambda self: self['ip'].getPTR(''),
|
||||
"fid": lambda self: self.__ticket.getID(),
|
||||
"failures": lambda self: self.__ticket.getAttempt(),
|
||||
"time": lambda self: self.__ticket.getTime(),
|
||||
"matches": lambda self: "\n".join(self.__ticket.getMatches()),
|
||||
# to bypass actions, that should not be executed for restored tickets
|
||||
"restored": lambda self: (1 if self.__ticket.restored else 0),
|
||||
# extra-interpolation - all match-tags (captured from the filter):
|
||||
"F-*": lambda self, tag=None: self.__ticket.getData(tag),
|
||||
# merged info:
|
||||
"ipmatches": lambda self: "\n".join(self._mi4ip(True).getMatches()),
|
||||
"ipjailmatches": lambda self: "\n".join(self._mi4ip().getMatches()),
|
||||
"ipfailures": lambda self: self._mi4ip(True).getAttempt(),
|
||||
"ipjailfailures": lambda self: self._mi4ip().getAttempt(),
|
||||
}
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mi : dict
|
||||
merge info, initial for lambda should contains {ip, ticket}
|
||||
overalljails : bool
|
||||
switch to get a merged bans :
|
||||
False - (default) bans merged for current jail only
|
||||
True - bans merged for all jails of current ip address
|
||||
__slots__ = CallingMap.__slots__ + ('__ticket', '__jail', '__mi4ip')
|
||||
|
||||
def __init__(self, ticket, jail=None, immutable=True, data=AI_DICT):
|
||||
self.__ticket = ticket
|
||||
self.__jail = jail
|
||||
self.storage = dict()
|
||||
self.immutable = immutable
|
||||
self.data = data
|
||||
|
||||
def copy(self): # pargma: no cover
|
||||
return self.__class__(self.__ticket, self.__jail, self.immutable, self.data.copy())
|
||||
|
||||
def _mi4ip(self, overalljails=False):
|
||||
"""Gets bans merged once, a helper for lambda(s), prevents stop of executing action by any exception inside.
|
||||
|
||||
This function never returns None for ainfo lambdas - always a ticket (merged or single one)
|
||||
and prevents any errors through merging (to guarantee ban actions will be executed).
|
||||
[TODO] move merging to observer - here we could wait for merge and read already merged info from a database
|
||||
|
||||
Parameters
|
||||
----------
|
||||
overalljails : bool
|
||||
switch to get a merged bans :
|
||||
False - (default) bans merged for current jail only
|
||||
True - bans merged for all jails of current ip address
|
||||
|
||||
Returns
|
||||
-------
|
||||
BanTicket
|
||||
merged or self ticket only
|
||||
"""
|
||||
if not hasattr(self, '__mi4ip'):
|
||||
self.__mi4ip = {}
|
||||
mi = self.__mi4ip
|
||||
idx = 'all' if overalljails else 'jail'
|
||||
if idx in mi:
|
||||
return mi[idx] if mi[idx] is not None else self.__ticket
|
||||
try:
|
||||
jail = self.__jail
|
||||
ip = self['ip']
|
||||
mi[idx] = None
|
||||
if not jail.database: # pragma: no cover
|
||||
return self.__ticket
|
||||
if overalljails:
|
||||
mi[idx] = jail.database.getBansMerged(ip=ip)
|
||||
else:
|
||||
mi[idx] = jail.database.getBansMerged(ip=ip, jail=jail)
|
||||
except Exception as e:
|
||||
logSys.error(
|
||||
"Failed to get %s bans merged, jail '%s': %s",
|
||||
idx, jail.name, e,
|
||||
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
|
||||
return mi[idx] if mi[idx] is not None else self.__ticket
|
||||
|
||||
|
||||
def __getActionInfo(self, ticket):
|
||||
ip = ticket.getIP()
|
||||
aInfo = Actions.ActionInfo(ticket, self._jail)
|
||||
return aInfo
|
||||
|
||||
Returns
|
||||
-------
|
||||
BanTicket
|
||||
merged or self ticket only
|
||||
"""
|
||||
idx = 'all' if overalljails else 'jail'
|
||||
if idx in mi:
|
||||
return mi[idx] if mi[idx] is not None else mi['ticket']
|
||||
try:
|
||||
jail=self._jail
|
||||
ip=mi['ip']
|
||||
mi[idx] = None
|
||||
if overalljails:
|
||||
mi[idx] = jail.database.getBansMerged(ip=ip)
|
||||
else:
|
||||
mi[idx] = jail.database.getBansMerged(ip=ip, jail=jail)
|
||||
except Exception as e:
|
||||
logSys.error(
|
||||
"Failed to get %s bans merged, jail '%s': %s",
|
||||
idx, jail.name, e,
|
||||
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
|
||||
return mi[idx] if mi[idx] is not None else mi['ticket']
|
||||
|
||||
def __checkBan(self):
|
||||
"""Check for IP address to ban.
|
||||
|
@ -343,24 +385,7 @@ class Actions(JailThread, Mapping):
|
|||
break
|
||||
bTicket = BanManager.createBanTicket(ticket)
|
||||
ip = bTicket.getIP()
|
||||
aInfo = CallingMap({
|
||||
"ip" : ip,
|
||||
"ip-rev" : lambda: ip.getPTR(''),
|
||||
"failures": bTicket.getAttempt(),
|
||||
"time" : bTicket.getTime(),
|
||||
"matches" : "\n".join(bTicket.getMatches()),
|
||||
# to bypass actions, that should not be executed for restored tickets
|
||||
"restored": (1 if ticket.restored else 0),
|
||||
# extra-interpolation - all match-tags (captured from the filter):
|
||||
"F-*": lambda tag=None: bTicket.getData(tag)
|
||||
})
|
||||
if self._jail.database is not None:
|
||||
mi4ip = lambda overalljails=False, self=self, \
|
||||
mi={'ip':ip, 'ticket':bTicket}: self.__getBansMerged(mi, overalljails)
|
||||
aInfo["ipmatches"] = lambda: "\n".join(mi4ip(True).getMatches())
|
||||
aInfo["ipjailmatches"] = lambda: "\n".join(mi4ip().getMatches())
|
||||
aInfo["ipfailures"] = lambda: mi4ip(True).getAttempt()
|
||||
aInfo["ipjailfailures"] = lambda: mi4ip().getAttempt()
|
||||
aInfo = self.__getActionInfo(bTicket)
|
||||
reason = {}
|
||||
if self.__banManager.addBanTicket(bTicket, reason=reason):
|
||||
cnt += 1
|
||||
|
@ -369,7 +394,8 @@ class Actions(JailThread, Mapping):
|
|||
try:
|
||||
if ticket.restored and getattr(action, 'norestored', False):
|
||||
continue
|
||||
action.ban(aInfo.copy())
|
||||
if not aInfo.immutable: aInfo.reset()
|
||||
action.ban(aInfo)
|
||||
except Exception as e:
|
||||
logSys.error(
|
||||
"Failed to execute ban jail '%s' action '%s' "
|
||||
|
@ -452,21 +478,17 @@ class Actions(JailThread, Mapping):
|
|||
unbactions = self._actions
|
||||
else:
|
||||
unbactions = actions
|
||||
aInfo = dict()
|
||||
aInfo["ip"] = ticket.getIP()
|
||||
aInfo["failures"] = ticket.getAttempt()
|
||||
aInfo["time"] = ticket.getTime()
|
||||
aInfo["matches"] = "".join(ticket.getMatches())
|
||||
# to bypass actions, that should not be executed for restored tickets
|
||||
aInfo["restored"] = 1 if ticket.restored else 0
|
||||
ip = ticket.getIP()
|
||||
aInfo = self.__getActionInfo(ticket)
|
||||
if actions is None:
|
||||
logSys.notice("[%s] Unban %s", self._jail.name, aInfo["ip"])
|
||||
for name, action in unbactions.iteritems():
|
||||
try:
|
||||
if ticket.restored and getattr(action, 'norestored', False):
|
||||
continue
|
||||
logSys.debug("[%s] action %r: unban %s", self._jail.name, name, aInfo["ip"])
|
||||
action.unban(aInfo.copy())
|
||||
logSys.debug("[%s] action %r: unban %s", self._jail.name, name, ip)
|
||||
if not aInfo.immutable: aInfo.reset()
|
||||
action.unban(aInfo)
|
||||
except Exception as e:
|
||||
logSys.error(
|
||||
"Failed to execute unban jail '%s' action '%s' "
|
||||
|
|
|
@ -194,7 +194,7 @@ class CommandActionTest(LogCaptureTestCase):
|
|||
# Callable
|
||||
self.assertEqual(
|
||||
self.__action.replaceTag("09 <matches> 11",
|
||||
CallingMap(matches=lambda: str(10))),
|
||||
CallingMap(matches=lambda self: str(10))),
|
||||
"09 10 11")
|
||||
|
||||
def testReplaceNoTag(self):
|
||||
|
@ -202,7 +202,7 @@ class CommandActionTest(LogCaptureTestCase):
|
|||
# Will raise ValueError if it is
|
||||
self.assertEqual(
|
||||
self.__action.replaceTag("abc",
|
||||
CallingMap(matches=lambda: int("a"))), "abc")
|
||||
CallingMap(matches=lambda self: int("a"))), "abc")
|
||||
|
||||
def testReplaceTagSelfRecursion(self):
|
||||
setattr(self.__action, 'a', "<a")
|
||||
|
@ -332,7 +332,7 @@ class CommandActionTest(LogCaptureTestCase):
|
|||
aInfo = CallingMap({
|
||||
'ABC': "123",
|
||||
'ip': '192.0.2.1',
|
||||
'F-*': lambda: {
|
||||
'F-*': lambda self: {
|
||||
'fid': 111,
|
||||
'fport': 222,
|
||||
'user': "tester"
|
||||
|
@ -442,7 +442,7 @@ class CommandActionTest(LogCaptureTestCase):
|
|||
"stderr: 'The rain in Spain stays mainly in the plain'\n")
|
||||
|
||||
def testCallingMap(self):
|
||||
mymap = CallingMap(callme=lambda: str(10), error=lambda: int('a'),
|
||||
mymap = CallingMap(callme=lambda self: str(10), error=lambda self: int('a'),
|
||||
dontcallme= "string", number=17)
|
||||
|
||||
# Should work fine
|
||||
|
@ -451,3 +451,43 @@ class CommandActionTest(LogCaptureTestCase):
|
|||
"10 okay string 17")
|
||||
# Error will now trip, demonstrating delayed call
|
||||
self.assertRaises(ValueError, lambda x: "%(error)i" % x, mymap)
|
||||
|
||||
def testCallingMapModify(self):
|
||||
m = CallingMap({
|
||||
'a': lambda self: 2 + 3,
|
||||
'b': lambda self: self['a'] + 6,
|
||||
'c': 'test',
|
||||
})
|
||||
# test reset (without modifications):
|
||||
m.reset()
|
||||
# do modifications:
|
||||
m['a'] = 4
|
||||
del m['c']
|
||||
# test set and delete:
|
||||
self.assertEqual(len(m), 2)
|
||||
self.assertNotIn('c', m)
|
||||
self.assertEqual((m['a'], m['b']), (4, 10))
|
||||
# reset to original and test again:
|
||||
m.reset()
|
||||
s = repr(m)
|
||||
self.assertEqual(len(m), 3)
|
||||
self.assertIn('c', m)
|
||||
self.assertEqual((m['a'], m['b'], m['c']), (5, 11, 'test'))
|
||||
|
||||
def testCallingMapRep(self):
|
||||
m = CallingMap({
|
||||
'a': lambda self: 2 + 3,
|
||||
'b': lambda self: self['a'] + 6,
|
||||
'c': ''
|
||||
})
|
||||
s = repr(m)
|
||||
self.assertIn("'a': 5", s)
|
||||
self.assertIn("'b': 11", s)
|
||||
self.assertIn("'c': ''", s)
|
||||
|
||||
m['c'] = lambda self: self['xxx'] + 7; # unresolvable
|
||||
s = repr(m)
|
||||
self.assertIn("'a': 5", s)
|
||||
self.assertIn("'b': 11", s)
|
||||
self.assertIn("'c': ", s) # presents as callable
|
||||
self.assertNotIn("'c': ''", s) # but not empty
|
||||
|
|
|
@ -1737,6 +1737,6 @@ class ServerConfigReaderTests(LogCaptureTestCase):
|
|||
if not tests.get(test): continue
|
||||
self.pruneLog('# === %s ===' % test)
|
||||
ticket = _actions.CallingMap({
|
||||
'ip': ip, 'ip-rev': lambda: ip.getPTR(''), 'failures': 100,})
|
||||
'ip': ip, 'ip-rev': lambda self: self['ip'].getPTR(''), 'failures': 100,})
|
||||
action.ban(ticket)
|
||||
self.assertLogged(*tests[test], all=True)
|
||||
|
|
Loading…
Reference in New Issue