mirror of https://github.com/fail2ban/fail2ban
ENH: Add getBansMerged method to Fail2BanDb
Creates a single ticket for an IP, made up of all previous banspull/480/head
parent
e18af48e34
commit
00ecd22851
|
@ -227,24 +227,42 @@ class Fail2BanDb(object):
|
|||
"failures": ticket.getAttempt()}))
|
||||
|
||||
@commitandrollback()
|
||||
def getBans(self, cur, jail=None, bantime=None):
|
||||
query = "SELECT ip, timeofban, data FROM bans"
|
||||
def _getBans(self, cur, jail=None, bantime=None, ip=None):
|
||||
query = "SELECT ip, timeofban, data FROM bans WHERE 1"
|
||||
queryArgs = []
|
||||
|
||||
if jail is not None:
|
||||
query += " WHERE jail=?"
|
||||
query += " AND jail=?"
|
||||
queryArgs.append(jail.getName())
|
||||
if bantime is not None:
|
||||
query += " AND timeofban > ?"
|
||||
queryArgs.append(MyTime.time() - bantime)
|
||||
if ip is not None:
|
||||
query += " AND ip=?"
|
||||
queryArgs.append(ip)
|
||||
query += " ORDER BY timeofban"
|
||||
|
||||
return cur.execute(query, queryArgs)
|
||||
|
||||
def getBans(self, *args, **kwargs):
|
||||
tickets = []
|
||||
for ip, timeofban, data in cur.execute(query, queryArgs):
|
||||
for ip, timeofban, data in self._getBans(*args, **kwargs):
|
||||
#TODO: Implement data parts once arbitrary match keys completed
|
||||
tickets.append(FailTicket(ip, timeofban, data['matches']))
|
||||
tickets[-1].setAttempt(data['failures'])
|
||||
return tickets
|
||||
|
||||
def getBansMerged(self, ip, *args, **kwargs):
|
||||
matches = []
|
||||
failures = 0
|
||||
for ip, timeofban, data in self._getBans(*args, ip=ip, **kwargs):
|
||||
#TODO: Implement data parts once arbitrary match keys completed
|
||||
matches.extend(data['matches'])
|
||||
failures += data['failures']
|
||||
ticket = FailTicket(ip, timeofban, matches)
|
||||
ticket.setAttempt(failures)
|
||||
return ticket
|
||||
|
||||
@commitandrollback()
|
||||
def purge(self, cur):
|
||||
cur.execute(
|
||||
|
|
|
@ -144,7 +144,8 @@ class Jail:
|
|||
self.__action.start()
|
||||
# Restore any previous valid bans from the database
|
||||
if self.__db is not None:
|
||||
for ticket in self.__db.getBans(self, self.__action.getBanTime()):
|
||||
for ticket in self.__db.getBans(
|
||||
jail=self, bantime=self.__action.getBanTime()):
|
||||
self.__queue.put(ticket)
|
||||
logSys.info("Jail '%s' started" % self.__name)
|
||||
|
||||
|
|
|
@ -124,12 +124,43 @@ class DatabaseTest(unittest.TestCase):
|
|||
|
||||
def testAddBan(self):
|
||||
self.testAddJail()
|
||||
ticket = FailTicket("127.0.0.1", 0, [])
|
||||
ticket = FailTicket("127.0.0.1", 0, ["abc\n"])
|
||||
self.db.addBan(self.jail, ticket)
|
||||
|
||||
self.assertEquals(len(self.db.getBans(self.jail)), 1)
|
||||
self.assertTrue(
|
||||
isinstance(self.db.getBans(self.jail)[0], FailTicket))
|
||||
isinstance(self.db.getBans(jail=self.jail)[0], FailTicket))
|
||||
|
||||
def testGetBansMerged(self):
|
||||
self.testAddJail()
|
||||
|
||||
jail2 = DummyJail()
|
||||
self.db.addJail(jail2)
|
||||
|
||||
ticket = FailTicket("127.0.0.1", 10, ["abc\n"])
|
||||
ticket.setAttempt(10)
|
||||
self.db.addBan(self.jail, ticket)
|
||||
ticket = FailTicket("127.0.0.1", 20, ["123\n"])
|
||||
ticket.setAttempt(20)
|
||||
self.db.addBan(self.jail, ticket)
|
||||
ticket = FailTicket("127.0.0.2", 30, ["ABC\n"])
|
||||
ticket.setAttempt(30)
|
||||
self.db.addBan(self.jail, ticket)
|
||||
ticket = FailTicket("127.0.0.1", 40, ["ABC\n"])
|
||||
ticket.setAttempt(40)
|
||||
self.db.addBan(jail2, ticket)
|
||||
|
||||
# All for IP 127.0.0.1
|
||||
ticket = self.db.getBansMerged("127.0.0.1")
|
||||
self.assertEqual(ticket.getIP(), "127.0.0.1")
|
||||
self.assertEqual(ticket.getAttempt(), 70)
|
||||
self.assertEqual(ticket.getMatches(), ["abc\n", "123\n", "ABC\n"])
|
||||
|
||||
# All for IP 127.0.0.1 for single jail
|
||||
ticket = self.db.getBansMerged("127.0.0.1", jail=self.jail)
|
||||
self.assertEqual(ticket.getIP(), "127.0.0.1")
|
||||
self.assertEqual(ticket.getAttempt(), 30)
|
||||
self.assertEqual(ticket.getMatches(), ["abc\n", "123\n"])
|
||||
|
||||
def testPurge(self):
|
||||
self.testAddJail() # Add jail
|
||||
|
@ -145,12 +176,13 @@ class DatabaseTest(unittest.TestCase):
|
|||
self.db.delJail(self.jail)
|
||||
self.db.purge() # Purge should remove all bans
|
||||
self.assertEqual(len(self.db.getJailNames()), 0)
|
||||
self.assertEqual(len(self.db.getBans(self.jail)), 0)
|
||||
self.assertEqual(len(self.db.getBans(jail=self.jail)), 0)
|
||||
|
||||
# Should leave jail
|
||||
self.testAddJail()
|
||||
self.db.addBan(self.jail, FailTicket("127.0.0.1", MyTime.time(), []))
|
||||
self.db.addBan(
|
||||
self.jail, FailTicket("127.0.0.1", MyTime.time(), ["abc\n"]))
|
||||
self.db.delJail(self.jail)
|
||||
self.db.purge() # Should leave jail as ban present
|
||||
self.assertEqual(len(self.db.getJailNames()), 1)
|
||||
self.assertEqual(len(self.db.getBans(self.jail)), 1)
|
||||
self.assertEqual(len(self.db.getBans(jail=self.jail)), 1)
|
||||
|
|
Loading…
Reference in New Issue