mirror of https://github.com/fail2ban/fail2ban
fix for JSON serialization bug for set object (gh-2103): currently there are only users, so simply serialized as a list.
parent
0a50f2e19e
commit
34b586b51e
|
@ -39,9 +39,14 @@ from ..helpers import getLogger, PREFER_ENC
|
|||
logSys = getLogger(__name__)
|
||||
|
||||
if sys.version_info >= (3,):
|
||||
def _json_default(x):
|
||||
if isinstance(x, set):
|
||||
x = list(x)
|
||||
return x
|
||||
|
||||
def _json_dumps_safe(x):
|
||||
try:
|
||||
x = json.dumps(x, ensure_ascii=False).encode(
|
||||
x = json.dumps(x, ensure_ascii=False, default=_json_default).encode(
|
||||
PREFER_ENC, 'replace')
|
||||
except Exception as e: # pragma: no cover
|
||||
logSys.error('json dumps failed: %s', e)
|
||||
|
@ -60,7 +65,7 @@ else:
|
|||
def _normalize(x):
|
||||
if isinstance(x, dict):
|
||||
return dict((_normalize(k), _normalize(v)) for k, v in x.iteritems())
|
||||
elif isinstance(x, list):
|
||||
elif isinstance(x, (list, set)):
|
||||
return [_normalize(element) for element in x]
|
||||
elif isinstance(x, unicode):
|
||||
return x.encode(PREFER_ENC)
|
||||
|
@ -527,10 +532,13 @@ class Fail2BanDb(object):
|
|||
except KeyError:
|
||||
pass
|
||||
#TODO: Implement data parts once arbitrary match keys completed
|
||||
data = ticket.getData()
|
||||
matches = data.get('matches')
|
||||
if matches and len(matches) > self.maxEntries:
|
||||
data['matches'] = matches[-self.maxEntries:]
|
||||
cur.execute(
|
||||
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
|
||||
(jail.name, ip, int(round(ticket.getTime())),
|
||||
ticket.getData()))
|
||||
(jail.name, ip, int(round(ticket.getTime())), data))
|
||||
|
||||
@commitandrollback
|
||||
def delBan(self, cur, jail, *args):
|
||||
|
@ -659,11 +667,11 @@ class Fail2BanDb(object):
|
|||
else:
|
||||
matches = m[-maxadd:] + matches
|
||||
failures += data.get('failures', 1)
|
||||
tickdata.update(data.get('data', {}))
|
||||
data['failures'] = failures
|
||||
data['matches'] = matches
|
||||
tickdata.update(data)
|
||||
prev_timeofban = timeofban
|
||||
ticket = FailTicket(banip, prev_timeofban, matches)
|
||||
ticket.setAttempt(failures)
|
||||
ticket.setData(**tickdata)
|
||||
ticket = FailTicket(banip, prev_timeofban, data=tickdata)
|
||||
tickets.append(ticket)
|
||||
|
||||
if cacheKey:
|
||||
|
|
|
@ -302,12 +302,18 @@ class DatabaseTest(LogCaptureTestCase):
|
|||
def testGetBansMerged_MaxEntries(self):
|
||||
self.testAddJail()
|
||||
maxEntries = 2
|
||||
failures = ["abc\n", "123\n", "ABC\n", "1234\n"]
|
||||
failures = [
|
||||
{"matches": ["abc\n"], "user": set(['test'])},
|
||||
{"matches": ["123\n"], "user": set(['test'])},
|
||||
{"matches": ["ABC\n"], "user": set(['test', 'root'])},
|
||||
{"matches": ["1234\n"], "user": set(['test', 'root'])},
|
||||
]
|
||||
matches2find = [f["matches"][0] for f in failures]
|
||||
# add failures sequential:
|
||||
i = 80
|
||||
for f in failures:
|
||||
i -= 10
|
||||
ticket = FailTicket("127.0.0.1", MyTime.time() - i, [f])
|
||||
ticket = FailTicket("127.0.0.1", MyTime.time() - i, data=f)
|
||||
ticket.setAttempt(1)
|
||||
self.db.addBan(self.jail, ticket)
|
||||
# should retrieve 2 matches only, but count of all attempts:
|
||||
|
@ -316,9 +322,10 @@ class DatabaseTest(LogCaptureTestCase):
|
|||
self.assertEqual(ticket.getIP(), "127.0.0.1")
|
||||
self.assertEqual(ticket.getAttempt(), len(failures))
|
||||
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
||||
self.assertEqual(ticket.getMatches(), failures[len(failures) - maxEntries:])
|
||||
self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:])
|
||||
# add more failures at once:
|
||||
ticket = FailTicket("127.0.0.1", MyTime.time() - 10, failures)
|
||||
ticket = FailTicket("127.0.0.1", MyTime.time() - 10, matches2find,
|
||||
data={"user": set(['test', 'root'])})
|
||||
ticket.setAttempt(len(failures))
|
||||
self.db.addBan(self.jail, ticket)
|
||||
# should retrieve 2 matches only, but count of all attempts:
|
||||
|
@ -326,7 +333,13 @@ class DatabaseTest(LogCaptureTestCase):
|
|||
ticket = self.db.getBansMerged("127.0.0.1")
|
||||
self.assertEqual(ticket.getAttempt(), 2 * len(failures))
|
||||
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
||||
self.assertEqual(ticket.getMatches(), failures[len(failures) - maxEntries:])
|
||||
self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:])
|
||||
# also using getCurrentBans:
|
||||
ticket = self.db.getCurrentBans(self.jail, "127.0.0.1", fromtime=MyTime.time()-100)
|
||||
self.assertTrue(ticket is not None)
|
||||
self.assertEqual(ticket.getAttempt(), len(failures))
|
||||
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
||||
self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:])
|
||||
|
||||
def testGetBansMerged(self):
|
||||
self.testAddJail()
|
||||
|
|
Loading…
Reference in New Issue