diff --git a/fail2ban/server/actions.py b/fail2ban/server/actions.py index 3a7e21da..cee5bf21 100644 --- a/fail2ban/server/actions.py +++ b/fail2ban/server/actions.py @@ -490,7 +490,7 @@ class Actions(JailThread, Mapping): """ log = True if actions is None: - logSys.debug("Flush ban list") + logSys.debug(" Flush ban list") lst = self.__banManager.flushBanList() else: log = False # don't log "[jail] Unban ..." if removing actions only. @@ -505,16 +505,16 @@ class Actions(JailThread, Mapping): else: unbactions[name] = action actions = unbactions + # flush the database also: + if db and self._jail.database is not None: + logSys.debug(" Flush jail in database") + self._jail.database.delBan(self._jail) # unban each ticket with non-flasheable actions: for ticket in lst: - # delete ip from database also: - if db and self._jail.database is not None: - ip = str(ticket.getIP()) - self._jail.database.delBan(self._jail, ip) # unban ip: self.__unBan(ticket, actions=actions, log=log) cnt += 1 - logSys.debug("Unbanned %s, %s ticket(s) in %r", + logSys.debug(" Unbanned %s, %s ticket(s) in %r", cnt, self.__banManager.size(), self._jail.name) return cnt diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index 5112bf54..c132fb10 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -566,23 +566,30 @@ class Fail2BanDb(object): ticket.getData())) @commitandrollback - def delBan(self, cur, jail, ip): - """Delete a ban from the database. + def delBan(self, cur, jail, *args): + """Delete a single or multiple tickets from the database. Parameters ---------- jail : Jail - Jail in which the ban has occurred. - ip : str - IP to be removed. + Jail in which the ticket(s) should be removed. + args : list of IP + IPs to be removed, if not given all tickets of jail will be removed. """ - queryArgs = (jail.name, str(ip)); - cur.execute( - "DELETE FROM bips WHERE jail = ? AND ip = ?", - queryArgs) - cur.execute( - "DELETE FROM bans WHERE jail = ? AND ip = ?", - queryArgs); + query1 = "DELETE FROM bips WHERE jail = ?" + query2 = "DELETE FROM bans WHERE jail = ?" + queryArgs = [jail.name]; + if not len(args): + cur.execute(query1, queryArgs); + cur.execute(query2, queryArgs); + return + query1 += " AND ip = ?" + query2 += " AND ip = ?" + queryArgs.append(''); + for ip in args: + queryArgs[1] = str(ip); + cur.execute(query1, queryArgs); + cur.execute(query2, queryArgs); @commitandrollback def _getBans(self, cur, jail=None, bantime=None, ip=None): diff --git a/fail2ban/tests/databasetestcase.py b/fail2ban/tests/databasetestcase.py index 9c60033f..78a15637 100644 --- a/fail2ban/tests/databasetestcase.py +++ b/fail2ban/tests/databasetestcase.py @@ -270,9 +270,10 @@ class DatabaseTest(LogCaptureTestCase): ticket = FailTicket("127.0.0.1", 0, ["abc\n"]) self.db.addBan(self.jail, ticket) - self.assertEqual(len(self.db.getBans(jail=self.jail)), 1) + tickets = self.db.getBans(jail=self.jail) + self.assertEqual(len(tickets), 1) self.assertTrue( - isinstance(self.db.getBans(jail=self.jail)[0], FailTicket)) + isinstance(tickets[0], FailTicket)) def testAddBanInvalidEncoded(self): if Fail2BanDb is None: # pragma: no cover @@ -305,10 +306,28 @@ class DatabaseTest(LogCaptureTestCase): or readtickets[2] == tickets[2] ) + def _testAdd3Bans(self): + self.testAddJail() + for i in (1, 2, 3): + ticket = FailTicket(("192.0.2.%d" % i), 0, ["test\n"]) + self.db.addBan(self.jail, ticket) + tickets = self.db.getBans(jail=self.jail) + self.assertEqual(len(tickets), 3) + return tickets + def testDelBan(self): - self.testAddBan() - ticket = self.db.getBans(jail=self.jail)[0] - self.db.delBan(self.jail, ticket.getIP()) + tickets = self._testAdd3Bans() + # delete single IP: + self.db.delBan(self.jail, tickets[0].getIP()) + self.assertEqual(len(self.db.getBans(jail=self.jail)), 2) + # delete two IPs: + self.db.delBan(self.jail, tickets[1].getIP(), tickets[2].getIP()) + self.assertEqual(len(self.db.getBans(jail=self.jail)), 0) + + def testFlushBans(self): + self._testAdd3Bans() + # flush all bans: + self.db.delBan(self.jail) self.assertEqual(len(self.db.getBans(jail=self.jail)), 0) def testGetBansWithTime(self):