ENH: Add getBansMerged method to Fail2BanDb

Creates a single ticket for an IP, made up of all previous bans
pull/480/head
Steven Hiscocks 2013-12-12 22:22:30 +00:00
parent e18af48e34
commit 00ecd22851
3 changed files with 61 additions and 10 deletions

View File

@ -227,24 +227,42 @@ class Fail2BanDb(object):
"failures": ticket.getAttempt()}))
@commitandrollback()
def getBans(self, cur, jail=None, bantime=None):
query = "SELECT ip, timeofban, data FROM bans"
def _getBans(self, cur, jail=None, bantime=None, ip=None):
query = "SELECT ip, timeofban, data FROM bans WHERE 1"
queryArgs = []
if jail is not None:
query += " WHERE jail=?"
query += " AND jail=?"
queryArgs.append(jail.getName())
if bantime is not None:
query += " AND timeofban > ?"
queryArgs.append(MyTime.time() - bantime)
if ip is not None:
query += " AND ip=?"
queryArgs.append(ip)
query += " ORDER BY timeofban"
return cur.execute(query, queryArgs)
def getBans(self, *args, **kwargs):
tickets = []
for ip, timeofban, data in cur.execute(query, queryArgs):
for ip, timeofban, data in self._getBans(*args, **kwargs):
#TODO: Implement data parts once arbitrary match keys completed
tickets.append(FailTicket(ip, timeofban, data['matches']))
tickets[-1].setAttempt(data['failures'])
return tickets
def getBansMerged(self, ip, *args, **kwargs):
matches = []
failures = 0
for ip, timeofban, data in self._getBans(*args, ip=ip, **kwargs):
#TODO: Implement data parts once arbitrary match keys completed
matches.extend(data['matches'])
failures += data['failures']
ticket = FailTicket(ip, timeofban, matches)
ticket.setAttempt(failures)
return ticket
@commitandrollback()
def purge(self, cur):
cur.execute(

View File

@ -144,7 +144,8 @@ class Jail:
self.__action.start()
# Restore any previous valid bans from the database
if self.__db is not None:
for ticket in self.__db.getBans(self, self.__action.getBanTime()):
for ticket in self.__db.getBans(
jail=self, bantime=self.__action.getBanTime()):
self.__queue.put(ticket)
logSys.info("Jail '%s' started" % self.__name)

View File

@ -124,12 +124,43 @@ class DatabaseTest(unittest.TestCase):
def testAddBan(self):
self.testAddJail()
ticket = FailTicket("127.0.0.1", 0, [])
ticket = FailTicket("127.0.0.1", 0, ["abc\n"])
self.db.addBan(self.jail, ticket)
self.assertEquals(len(self.db.getBans(self.jail)), 1)
self.assertTrue(
isinstance(self.db.getBans(self.jail)[0], FailTicket))
isinstance(self.db.getBans(jail=self.jail)[0], FailTicket))
def testGetBansMerged(self):
self.testAddJail()
jail2 = DummyJail()
self.db.addJail(jail2)
ticket = FailTicket("127.0.0.1", 10, ["abc\n"])
ticket.setAttempt(10)
self.db.addBan(self.jail, ticket)
ticket = FailTicket("127.0.0.1", 20, ["123\n"])
ticket.setAttempt(20)
self.db.addBan(self.jail, ticket)
ticket = FailTicket("127.0.0.2", 30, ["ABC\n"])
ticket.setAttempt(30)
self.db.addBan(self.jail, ticket)
ticket = FailTicket("127.0.0.1", 40, ["ABC\n"])
ticket.setAttempt(40)
self.db.addBan(jail2, ticket)
# All for IP 127.0.0.1
ticket = self.db.getBansMerged("127.0.0.1")
self.assertEqual(ticket.getIP(), "127.0.0.1")
self.assertEqual(ticket.getAttempt(), 70)
self.assertEqual(ticket.getMatches(), ["abc\n", "123\n", "ABC\n"])
# All for IP 127.0.0.1 for single jail
ticket = self.db.getBansMerged("127.0.0.1", jail=self.jail)
self.assertEqual(ticket.getIP(), "127.0.0.1")
self.assertEqual(ticket.getAttempt(), 30)
self.assertEqual(ticket.getMatches(), ["abc\n", "123\n"])
def testPurge(self):
self.testAddJail() # Add jail
@ -145,12 +176,13 @@ class DatabaseTest(unittest.TestCase):
self.db.delJail(self.jail)
self.db.purge() # Purge should remove all bans
self.assertEqual(len(self.db.getJailNames()), 0)
self.assertEqual(len(self.db.getBans(self.jail)), 0)
self.assertEqual(len(self.db.getBans(jail=self.jail)), 0)
# Should leave jail
self.testAddJail()
self.db.addBan(self.jail, FailTicket("127.0.0.1", MyTime.time(), []))
self.db.addBan(
self.jail, FailTicket("127.0.0.1", MyTime.time(), ["abc\n"]))
self.db.delJail(self.jail)
self.db.purge() # Should leave jail as ban present
self.assertEqual(len(self.db.getJailNames()), 1)
self.assertEqual(len(self.db.getBans(self.jail)), 1)
self.assertEqual(len(self.db.getBans(jail=self.jail)), 1)