From fb7511fdea47862f52ffb1d33a1ea49852e7e685 Mon Sep 17 00:00:00 2001 From: Steven Hiscocks Date: Sun, 15 Dec 2013 21:52:50 +0000 Subject: [PATCH] ENH: Add cache for database getBansMerged This is avoids duplicate queries when using the ip(jail)matches and ip(jail)failures in actions --- fail2ban/server/database.py | 8 ++++++++ fail2ban/tests/databasetestcase.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index 997f1c21..b9c2e12d 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -64,6 +64,8 @@ class Fail2BanDb(object): self._dbFilename = filename self._purgeAge = purgeAge + self._bansMergedCache = {} + logSys.info( "Connected to fail2ban persistent database '%s'", filename) except sqlite3.OperationalError, e: @@ -219,6 +221,7 @@ class Fail2BanDb(object): @commitandrollback def addBan(self, cur, jail, ticket): + self._bansMergedCache = {} #TODO: Implement data parts once arbitrary match keys completed cur.execute( "INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)", @@ -253,6 +256,9 @@ class Fail2BanDb(object): return tickets def getBansMerged(self, ip, jail=None, **kwargs): + cacheKey = ip if jail is None else "%s|%s" % (ip, jail.getName()) + if cacheKey in self._bansMergedCache: + return self._bansMergedCache[cacheKey] matches = [] failures = 0 for ip, timeofban, data in self._getBans(ip=ip, jail=jail, **kwargs): @@ -261,10 +267,12 @@ class Fail2BanDb(object): failures += data['failures'] ticket = FailTicket(ip, timeofban, matches) ticket.setAttempt(failures) + self._bansMergedCache[cacheKey] = ticket return ticket @commitandrollback def purge(self, cur): + self._bansMergedCache = {} cur.execute( "DELETE FROM bans WHERE timeofban < ?", (MyTime.time() - self._purgeAge, )) diff --git a/fail2ban/tests/databasetestcase.py b/fail2ban/tests/databasetestcase.py index 44f7a59e..969ea6ff 100644 --- a/fail2ban/tests/databasetestcase.py +++ b/fail2ban/tests/databasetestcase.py @@ -162,6 +162,20 @@ class DatabaseTest(unittest.TestCase): self.assertEqual(ticket.getAttempt(), 30) self.assertEqual(ticket.getMatches(), ["abc\n", "123\n"]) + # Should cache result if no extra bans added + ticketID = id(ticket) + self.assertEqual( + ticketID, + id(self.db.getBansMerged("127.0.0.1", jail=self.jail))) + + ticket = FailTicket("127.0.0.1", 40, ["ABC\n"]) + ticket.setAttempt(40) + self.db.addBan(jail2, ticket) + # Added ticket, so cache should have been cleared + self.assertNotEqual( + ticketID, + id(self.db.getBansMerged("127.0.0.1", jail=self.jail))) + def testPurge(self): self.testAddJail() # Add jail