mirror of https://github.com/fail2ban/fail2ban
fix corner cases by maxEntries = 0 (no matches should be saved), test cases extended to cover it + code review
parent
5ebac4fe61
commit
5df78ad11f
|
@ -541,8 +541,13 @@ class Fail2BanDb(object):
|
|||
#TODO: Implement data parts once arbitrary match keys completed
|
||||
data = ticket.getData()
|
||||
matches = data.get('matches')
|
||||
if matches and len(matches) > self.maxEntries:
|
||||
data['matches'] = matches[-self.maxEntries:]
|
||||
if self.maxEntries:
|
||||
if matches and len(matches) > self.maxEntries:
|
||||
data = data.copy()
|
||||
data['matches'] = matches[-self.maxEntries:]
|
||||
elif matches:
|
||||
data = data.copy()
|
||||
del data['matches']
|
||||
cur.execute(
|
||||
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
|
||||
(jail.name, ip, int(round(ticket.getTime())), data))
|
||||
|
@ -702,6 +707,8 @@ class Fail2BanDb(object):
|
|||
queryArgs.append(fromtime - forbantime)
|
||||
if ip is None:
|
||||
query += " GROUP BY ip ORDER BY ip, timeofban DESC"
|
||||
else:
|
||||
query += " ORDER BY timeofban DESC LIMIT 1"
|
||||
cur = self._db.cursor()
|
||||
return cur.execute(query, queryArgs)
|
||||
|
||||
|
@ -718,9 +725,10 @@ class Fail2BanDb(object):
|
|||
# logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data)
|
||||
ticket = FailTicket(banip, timeofban, data=data)
|
||||
# logSys.debug('restored ticket: %r', ticket)
|
||||
if ip is not None: return ticket
|
||||
tickets.append(ticket)
|
||||
|
||||
return tickets if ip is None else ticket
|
||||
return tickets
|
||||
|
||||
@commitandrollback
|
||||
def purge(self, cur):
|
||||
|
|
|
@ -87,7 +87,7 @@ class FailManager:
|
|||
attempt = 1
|
||||
else:
|
||||
# will be incremented / extended (be sure we have at least +1 attempt):
|
||||
matches = ticket.getMatches()
|
||||
matches = ticket.getMatches() if self.maxEntries else None
|
||||
attempt = ticket.getAttempt()
|
||||
if attempt <= 0:
|
||||
attempt += 1
|
||||
|
@ -98,9 +98,12 @@ class FailManager:
|
|||
fData.setRetry(0)
|
||||
fData.inc(matches, attempt, count)
|
||||
# truncate to maxEntries:
|
||||
matches = fData.getMatches()
|
||||
if len(matches) > self.maxEntries:
|
||||
fData.setMatches(matches[-self.maxEntries:])
|
||||
if self.maxEntries:
|
||||
matches = fData.getMatches()
|
||||
if len(matches) > self.maxEntries:
|
||||
fData.setMatches(matches[-self.maxEntries:])
|
||||
else:
|
||||
fData.setMatches(None)
|
||||
except KeyError:
|
||||
# if already FailTicket - add it direct, otherwise create (using copy all ticket data):
|
||||
if isinstance(ticket, FailTicket):
|
||||
|
|
|
@ -135,7 +135,13 @@ class Ticket(object):
|
|||
return self._data['failures']
|
||||
|
||||
def setMatches(self, matches):
|
||||
self._data['matches'] = matches or []
|
||||
if matches:
|
||||
self._data['matches'] = matches
|
||||
else:
|
||||
try:
|
||||
del self._data['matches']
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def getMatches(self):
|
||||
return [(line if not isinstance(line, (list, tuple)) else "".join(line)) \
|
||||
|
|
|
@ -361,7 +361,6 @@ class DatabaseTest(LogCaptureTestCase):
|
|||
ticket.setAttempt(len(failures))
|
||||
self.db.addBan(self.jail, ticket)
|
||||
# should retrieve 2 matches only, but count of all attempts:
|
||||
self.db.maxEntries = maxEntries;
|
||||
ticket = self.db.getBansMerged("127.0.0.1")
|
||||
self.assertEqual(ticket.getAttempt(), 2 * len(failures))
|
||||
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
||||
|
@ -372,6 +371,13 @@ class DatabaseTest(LogCaptureTestCase):
|
|||
self.assertEqual(ticket.getAttempt(), len(failures))
|
||||
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
||||
self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:])
|
||||
# should retrieve 0 matches by last ban:
|
||||
self.db.maxEntries = 0;
|
||||
self.db.addBan(self.jail, ticket)
|
||||
ticket = self.db.getCurrentBans(self.jail, "127.0.0.1", fromtime=MyTime.time()-100)
|
||||
self.assertTrue(ticket is not None)
|
||||
self.assertEqual(ticket.getAttempt(), len(failures))
|
||||
self.assertEqual(len(ticket.getMatches()), 0)
|
||||
|
||||
def testGetBansMerged(self):
|
||||
self.testAddJail()
|
||||
|
|
|
@ -110,6 +110,14 @@ class AddFailure(unittest.TestCase):
|
|||
self.assertEqual(ticket.getAttempt(), 2 * len(failures) + 1)
|
||||
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
||||
self.assertEqual(ticket.getMatches(), failures[len(failures) - maxEntries:])
|
||||
# no matches by maxEntries == 0 :
|
||||
self.__failManager.maxEntries = 0
|
||||
self.__failManager.addFailure(ticket)
|
||||
manFailList = self.__failManager._FailManager__failList
|
||||
ticket = manFailList["127.0.0.1"]
|
||||
self.assertEqual(len(ticket.getMatches()), 0)
|
||||
# test set matches None to None:
|
||||
ticket.setMatches(None)
|
||||
|
||||
def testFailManagerMaxTime(self):
|
||||
self._addDefItems()
|
||||
|
|
Loading…
Reference in New Issue