diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index d7ba12c1..a06db21c 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -39,9 +39,14 @@ from ..helpers import getLogger, PREFER_ENC logSys = getLogger(__name__) if sys.version_info >= (3,): + def _json_default(x): + if isinstance(x, set): + x = list(x) + return x + def _json_dumps_safe(x): try: - x = json.dumps(x, ensure_ascii=False).encode( + x = json.dumps(x, ensure_ascii=False, default=_json_default).encode( PREFER_ENC, 'replace') except Exception as e: # pragma: no cover logSys.error('json dumps failed: %s', e) @@ -60,7 +65,7 @@ else: def _normalize(x): if isinstance(x, dict): return dict((_normalize(k), _normalize(v)) for k, v in x.iteritems()) - elif isinstance(x, list): + elif isinstance(x, (list, set)): return [_normalize(element) for element in x] elif isinstance(x, unicode): return x.encode(PREFER_ENC) @@ -527,10 +532,13 @@ class Fail2BanDb(object): except KeyError: pass #TODO: Implement data parts once arbitrary match keys completed + data = ticket.getData() + matches = data.get('matches') + if matches and len(matches) > self.maxEntries: + data['matches'] = matches[-self.maxEntries:] cur.execute( "INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)", - (jail.name, ip, int(round(ticket.getTime())), - ticket.getData())) + (jail.name, ip, int(round(ticket.getTime())), data)) @commitandrollback def delBan(self, cur, jail, *args): @@ -659,11 +667,11 @@ class Fail2BanDb(object): else: matches = m[-maxadd:] + matches failures += data.get('failures', 1) - tickdata.update(data.get('data', {})) + data['failures'] = failures + data['matches'] = matches + tickdata.update(data) prev_timeofban = timeofban - ticket = FailTicket(banip, prev_timeofban, matches) - ticket.setAttempt(failures) - ticket.setData(**tickdata) + ticket = FailTicket(banip, prev_timeofban, data=tickdata) tickets.append(ticket) if cacheKey: diff --git a/fail2ban/tests/databasetestcase.py b/fail2ban/tests/databasetestcase.py index 1ee523d9..7690525e 100644 --- a/fail2ban/tests/databasetestcase.py +++ b/fail2ban/tests/databasetestcase.py @@ -302,12 +302,18 @@ class DatabaseTest(LogCaptureTestCase): def testGetBansMerged_MaxEntries(self): self.testAddJail() maxEntries = 2 - failures = ["abc\n", "123\n", "ABC\n", "1234\n"] + failures = [ + {"matches": ["abc\n"], "user": set(['test'])}, + {"matches": ["123\n"], "user": set(['test'])}, + {"matches": ["ABC\n"], "user": set(['test', 'root'])}, + {"matches": ["1234\n"], "user": set(['test', 'root'])}, + ] + matches2find = [f["matches"][0] for f in failures] # add failures sequential: i = 80 for f in failures: i -= 10 - ticket = FailTicket("127.0.0.1", MyTime.time() - i, [f]) + ticket = FailTicket("127.0.0.1", MyTime.time() - i, data=f) ticket.setAttempt(1) self.db.addBan(self.jail, ticket) # should retrieve 2 matches only, but count of all attempts: @@ -316,9 +322,10 @@ class DatabaseTest(LogCaptureTestCase): self.assertEqual(ticket.getIP(), "127.0.0.1") self.assertEqual(ticket.getAttempt(), len(failures)) self.assertEqual(len(ticket.getMatches()), maxEntries) - self.assertEqual(ticket.getMatches(), failures[len(failures) - maxEntries:]) + self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:]) # add more failures at once: - ticket = FailTicket("127.0.0.1", MyTime.time() - 10, failures) + ticket = FailTicket("127.0.0.1", MyTime.time() - 10, matches2find, + data={"user": set(['test', 'root'])}) ticket.setAttempt(len(failures)) self.db.addBan(self.jail, ticket) # should retrieve 2 matches only, but count of all attempts: @@ -326,7 +333,13 @@ class DatabaseTest(LogCaptureTestCase): ticket = self.db.getBansMerged("127.0.0.1") self.assertEqual(ticket.getAttempt(), 2 * len(failures)) self.assertEqual(len(ticket.getMatches()), maxEntries) - self.assertEqual(ticket.getMatches(), failures[len(failures) - maxEntries:]) + self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:]) + # also using getCurrentBans: + ticket = self.db.getCurrentBans(self.jail, "127.0.0.1", fromtime=MyTime.time()-100) + self.assertTrue(ticket is not None) + self.assertEqual(ticket.getAttempt(), len(failures)) + self.assertEqual(len(ticket.getMatches()), maxEntries) + self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:]) def testGetBansMerged(self): self.testAddJail()