diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index 073b426c..6b22334e 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -227,24 +227,42 @@ class Fail2BanDb(object): "failures": ticket.getAttempt()})) @commitandrollback() - def getBans(self, cur, jail=None, bantime=None): - query = "SELECT ip, timeofban, data FROM bans" + def _getBans(self, cur, jail=None, bantime=None, ip=None): + query = "SELECT ip, timeofban, data FROM bans WHERE 1" queryArgs = [] if jail is not None: - query += " WHERE jail=?" + query += " AND jail=?" queryArgs.append(jail.getName()) if bantime is not None: query += " AND timeofban > ?" queryArgs.append(MyTime.time() - bantime) + if ip is not None: + query += " AND ip=?" + queryArgs.append(ip) + query += " ORDER BY timeofban" + return cur.execute(query, queryArgs) + + def getBans(self, *args, **kwargs): tickets = [] - for ip, timeofban, data in cur.execute(query, queryArgs): + for ip, timeofban, data in self._getBans(*args, **kwargs): #TODO: Implement data parts once arbitrary match keys completed tickets.append(FailTicket(ip, timeofban, data['matches'])) tickets[-1].setAttempt(data['failures']) return tickets + def getBansMerged(self, ip, *args, **kwargs): + matches = [] + failures = 0 + for ip, timeofban, data in self._getBans(*args, ip=ip, **kwargs): + #TODO: Implement data parts once arbitrary match keys completed + matches.extend(data['matches']) + failures += data['failures'] + ticket = FailTicket(ip, timeofban, matches) + ticket.setAttempt(failures) + return ticket + @commitandrollback() def purge(self, cur): cur.execute( diff --git a/fail2ban/server/jail.py b/fail2ban/server/jail.py index a24810ed..ff780fe9 100644 --- a/fail2ban/server/jail.py +++ b/fail2ban/server/jail.py @@ -144,7 +144,8 @@ class Jail: self.__action.start() # Restore any previous valid bans from the database if self.__db is not None: - for ticket in self.__db.getBans(self, self.__action.getBanTime()): + for ticket in self.__db.getBans( + jail=self, bantime=self.__action.getBanTime()): self.__queue.put(ticket) logSys.info("Jail '%s' started" % self.__name) diff --git a/fail2ban/tests/databasetestcase.py b/fail2ban/tests/databasetestcase.py index e001b865..2c7422b2 100644 --- a/fail2ban/tests/databasetestcase.py +++ b/fail2ban/tests/databasetestcase.py @@ -124,12 +124,43 @@ class DatabaseTest(unittest.TestCase): def testAddBan(self): self.testAddJail() - ticket = FailTicket("127.0.0.1", 0, []) + ticket = FailTicket("127.0.0.1", 0, ["abc\n"]) self.db.addBan(self.jail, ticket) self.assertEquals(len(self.db.getBans(self.jail)), 1) self.assertTrue( - isinstance(self.db.getBans(self.jail)[0], FailTicket)) + isinstance(self.db.getBans(jail=self.jail)[0], FailTicket)) + + def testGetBansMerged(self): + self.testAddJail() + + jail2 = DummyJail() + self.db.addJail(jail2) + + ticket = FailTicket("127.0.0.1", 10, ["abc\n"]) + ticket.setAttempt(10) + self.db.addBan(self.jail, ticket) + ticket = FailTicket("127.0.0.1", 20, ["123\n"]) + ticket.setAttempt(20) + self.db.addBan(self.jail, ticket) + ticket = FailTicket("127.0.0.2", 30, ["ABC\n"]) + ticket.setAttempt(30) + self.db.addBan(self.jail, ticket) + ticket = FailTicket("127.0.0.1", 40, ["ABC\n"]) + ticket.setAttempt(40) + self.db.addBan(jail2, ticket) + + # All for IP 127.0.0.1 + ticket = self.db.getBansMerged("127.0.0.1") + self.assertEqual(ticket.getIP(), "127.0.0.1") + self.assertEqual(ticket.getAttempt(), 70) + self.assertEqual(ticket.getMatches(), ["abc\n", "123\n", "ABC\n"]) + + # All for IP 127.0.0.1 for single jail + ticket = self.db.getBansMerged("127.0.0.1", jail=self.jail) + self.assertEqual(ticket.getIP(), "127.0.0.1") + self.assertEqual(ticket.getAttempt(), 30) + self.assertEqual(ticket.getMatches(), ["abc\n", "123\n"]) def testPurge(self): self.testAddJail() # Add jail @@ -145,12 +176,13 @@ class DatabaseTest(unittest.TestCase): self.db.delJail(self.jail) self.db.purge() # Purge should remove all bans self.assertEqual(len(self.db.getJailNames()), 0) - self.assertEqual(len(self.db.getBans(self.jail)), 0) + self.assertEqual(len(self.db.getBans(jail=self.jail)), 0) # Should leave jail self.testAddJail() - self.db.addBan(self.jail, FailTicket("127.0.0.1", MyTime.time(), [])) + self.db.addBan( + self.jail, FailTicket("127.0.0.1", MyTime.time(), ["abc\n"])) self.db.delJail(self.jail) self.db.purge() # Should leave jail as ban present self.assertEqual(len(self.db.getJailNames()), 1) - self.assertEqual(len(self.db.getBans(self.jail)), 1) + self.assertEqual(len(self.db.getBans(jail=self.jail)), 1)