fix corner cases by maxEntries = 0 (no matches should be saved), test cases extended to cover it + code review

pull/2402/head
sebres 2019-04-18 19:37:42 +02:00
parent 5ebac4fe61
commit 5df78ad11f
5 changed files with 40 additions and 9 deletions

View File

@ -541,8 +541,13 @@ class Fail2BanDb(object):
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
data = ticket.getData() data = ticket.getData()
matches = data.get('matches') matches = data.get('matches')
if matches and len(matches) > self.maxEntries: if self.maxEntries:
data['matches'] = matches[-self.maxEntries:] if matches and len(matches) > self.maxEntries:
data = data.copy()
data['matches'] = matches[-self.maxEntries:]
elif matches:
data = data.copy()
del data['matches']
cur.execute( cur.execute(
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)", "INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
(jail.name, ip, int(round(ticket.getTime())), data)) (jail.name, ip, int(round(ticket.getTime())), data))
@ -702,6 +707,8 @@ class Fail2BanDb(object):
queryArgs.append(fromtime - forbantime) queryArgs.append(fromtime - forbantime)
if ip is None: if ip is None:
query += " GROUP BY ip ORDER BY ip, timeofban DESC" query += " GROUP BY ip ORDER BY ip, timeofban DESC"
else:
query += " ORDER BY timeofban DESC LIMIT 1"
cur = self._db.cursor() cur = self._db.cursor()
return cur.execute(query, queryArgs) return cur.execute(query, queryArgs)
@ -718,9 +725,10 @@ class Fail2BanDb(object):
# logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data) # logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data)
ticket = FailTicket(banip, timeofban, data=data) ticket = FailTicket(banip, timeofban, data=data)
# logSys.debug('restored ticket: %r', ticket) # logSys.debug('restored ticket: %r', ticket)
if ip is not None: return ticket
tickets.append(ticket) tickets.append(ticket)
return tickets if ip is None else ticket return tickets
@commitandrollback @commitandrollback
def purge(self, cur): def purge(self, cur):

View File

@ -87,7 +87,7 @@ class FailManager:
attempt = 1 attempt = 1
else: else:
# will be incremented / extended (be sure we have at least +1 attempt): # will be incremented / extended (be sure we have at least +1 attempt):
matches = ticket.getMatches() matches = ticket.getMatches() if self.maxEntries else None
attempt = ticket.getAttempt() attempt = ticket.getAttempt()
if attempt <= 0: if attempt <= 0:
attempt += 1 attempt += 1
@ -98,9 +98,12 @@ class FailManager:
fData.setRetry(0) fData.setRetry(0)
fData.inc(matches, attempt, count) fData.inc(matches, attempt, count)
# truncate to maxEntries: # truncate to maxEntries:
matches = fData.getMatches() if self.maxEntries:
if len(matches) > self.maxEntries: matches = fData.getMatches()
fData.setMatches(matches[-self.maxEntries:]) if len(matches) > self.maxEntries:
fData.setMatches(matches[-self.maxEntries:])
else:
fData.setMatches(None)
except KeyError: except KeyError:
# if already FailTicket - add it direct, otherwise create (using copy all ticket data): # if already FailTicket - add it direct, otherwise create (using copy all ticket data):
if isinstance(ticket, FailTicket): if isinstance(ticket, FailTicket):

View File

@ -135,7 +135,13 @@ class Ticket(object):
return self._data['failures'] return self._data['failures']
def setMatches(self, matches): def setMatches(self, matches):
self._data['matches'] = matches or [] if matches:
self._data['matches'] = matches
else:
try:
del self._data['matches']
except KeyError:
pass
def getMatches(self): def getMatches(self):
return [(line if not isinstance(line, (list, tuple)) else "".join(line)) \ return [(line if not isinstance(line, (list, tuple)) else "".join(line)) \

View File

@ -361,7 +361,6 @@ class DatabaseTest(LogCaptureTestCase):
ticket.setAttempt(len(failures)) ticket.setAttempt(len(failures))
self.db.addBan(self.jail, ticket) self.db.addBan(self.jail, ticket)
# should retrieve 2 matches only, but count of all attempts: # should retrieve 2 matches only, but count of all attempts:
self.db.maxEntries = maxEntries;
ticket = self.db.getBansMerged("127.0.0.1") ticket = self.db.getBansMerged("127.0.0.1")
self.assertEqual(ticket.getAttempt(), 2 * len(failures)) self.assertEqual(ticket.getAttempt(), 2 * len(failures))
self.assertEqual(len(ticket.getMatches()), maxEntries) self.assertEqual(len(ticket.getMatches()), maxEntries)
@ -372,6 +371,13 @@ class DatabaseTest(LogCaptureTestCase):
self.assertEqual(ticket.getAttempt(), len(failures)) self.assertEqual(ticket.getAttempt(), len(failures))
self.assertEqual(len(ticket.getMatches()), maxEntries) self.assertEqual(len(ticket.getMatches()), maxEntries)
self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:]) self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:])
# should retrieve 0 matches by last ban:
self.db.maxEntries = 0;
self.db.addBan(self.jail, ticket)
ticket = self.db.getCurrentBans(self.jail, "127.0.0.1", fromtime=MyTime.time()-100)
self.assertTrue(ticket is not None)
self.assertEqual(ticket.getAttempt(), len(failures))
self.assertEqual(len(ticket.getMatches()), 0)
def testGetBansMerged(self): def testGetBansMerged(self):
self.testAddJail() self.testAddJail()

View File

@ -110,6 +110,14 @@ class AddFailure(unittest.TestCase):
self.assertEqual(ticket.getAttempt(), 2 * len(failures) + 1) self.assertEqual(ticket.getAttempt(), 2 * len(failures) + 1)
self.assertEqual(len(ticket.getMatches()), maxEntries) self.assertEqual(len(ticket.getMatches()), maxEntries)
self.assertEqual(ticket.getMatches(), failures[len(failures) - maxEntries:]) self.assertEqual(ticket.getMatches(), failures[len(failures) - maxEntries:])
# no matches by maxEntries == 0 :
self.__failManager.maxEntries = 0
self.__failManager.addFailure(ticket)
manFailList = self.__failManager._FailManager__failList
ticket = manFailList["127.0.0.1"]
self.assertEqual(len(ticket.getMatches()), 0)
# test set matches None to None:
ticket.setMatches(None)
def testFailManagerMaxTime(self): def testFailManagerMaxTime(self):
self._addDefItems() self._addDefItems()