mirror of https://github.com/fail2ban/fail2ban
fix corner cases by maxEntries = 0 (no matches should be saved), test cases extended to cover it + code review
parent
5ebac4fe61
commit
5df78ad11f
|
@ -541,8 +541,13 @@ class Fail2BanDb(object):
|
||||||
#TODO: Implement data parts once arbitrary match keys completed
|
#TODO: Implement data parts once arbitrary match keys completed
|
||||||
data = ticket.getData()
|
data = ticket.getData()
|
||||||
matches = data.get('matches')
|
matches = data.get('matches')
|
||||||
if matches and len(matches) > self.maxEntries:
|
if self.maxEntries:
|
||||||
data['matches'] = matches[-self.maxEntries:]
|
if matches and len(matches) > self.maxEntries:
|
||||||
|
data = data.copy()
|
||||||
|
data['matches'] = matches[-self.maxEntries:]
|
||||||
|
elif matches:
|
||||||
|
data = data.copy()
|
||||||
|
del data['matches']
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
|
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
|
||||||
(jail.name, ip, int(round(ticket.getTime())), data))
|
(jail.name, ip, int(round(ticket.getTime())), data))
|
||||||
|
@ -702,6 +707,8 @@ class Fail2BanDb(object):
|
||||||
queryArgs.append(fromtime - forbantime)
|
queryArgs.append(fromtime - forbantime)
|
||||||
if ip is None:
|
if ip is None:
|
||||||
query += " GROUP BY ip ORDER BY ip, timeofban DESC"
|
query += " GROUP BY ip ORDER BY ip, timeofban DESC"
|
||||||
|
else:
|
||||||
|
query += " ORDER BY timeofban DESC LIMIT 1"
|
||||||
cur = self._db.cursor()
|
cur = self._db.cursor()
|
||||||
return cur.execute(query, queryArgs)
|
return cur.execute(query, queryArgs)
|
||||||
|
|
||||||
|
@ -718,9 +725,10 @@ class Fail2BanDb(object):
|
||||||
# logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data)
|
# logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data)
|
||||||
ticket = FailTicket(banip, timeofban, data=data)
|
ticket = FailTicket(banip, timeofban, data=data)
|
||||||
# logSys.debug('restored ticket: %r', ticket)
|
# logSys.debug('restored ticket: %r', ticket)
|
||||||
|
if ip is not None: return ticket
|
||||||
tickets.append(ticket)
|
tickets.append(ticket)
|
||||||
|
|
||||||
return tickets if ip is None else ticket
|
return tickets
|
||||||
|
|
||||||
@commitandrollback
|
@commitandrollback
|
||||||
def purge(self, cur):
|
def purge(self, cur):
|
||||||
|
|
|
@ -87,7 +87,7 @@ class FailManager:
|
||||||
attempt = 1
|
attempt = 1
|
||||||
else:
|
else:
|
||||||
# will be incremented / extended (be sure we have at least +1 attempt):
|
# will be incremented / extended (be sure we have at least +1 attempt):
|
||||||
matches = ticket.getMatches()
|
matches = ticket.getMatches() if self.maxEntries else None
|
||||||
attempt = ticket.getAttempt()
|
attempt = ticket.getAttempt()
|
||||||
if attempt <= 0:
|
if attempt <= 0:
|
||||||
attempt += 1
|
attempt += 1
|
||||||
|
@ -98,9 +98,12 @@ class FailManager:
|
||||||
fData.setRetry(0)
|
fData.setRetry(0)
|
||||||
fData.inc(matches, attempt, count)
|
fData.inc(matches, attempt, count)
|
||||||
# truncate to maxEntries:
|
# truncate to maxEntries:
|
||||||
matches = fData.getMatches()
|
if self.maxEntries:
|
||||||
if len(matches) > self.maxEntries:
|
matches = fData.getMatches()
|
||||||
fData.setMatches(matches[-self.maxEntries:])
|
if len(matches) > self.maxEntries:
|
||||||
|
fData.setMatches(matches[-self.maxEntries:])
|
||||||
|
else:
|
||||||
|
fData.setMatches(None)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
# if already FailTicket - add it direct, otherwise create (using copy all ticket data):
|
# if already FailTicket - add it direct, otherwise create (using copy all ticket data):
|
||||||
if isinstance(ticket, FailTicket):
|
if isinstance(ticket, FailTicket):
|
||||||
|
|
|
@ -135,7 +135,13 @@ class Ticket(object):
|
||||||
return self._data['failures']
|
return self._data['failures']
|
||||||
|
|
||||||
def setMatches(self, matches):
|
def setMatches(self, matches):
|
||||||
self._data['matches'] = matches or []
|
if matches:
|
||||||
|
self._data['matches'] = matches
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
del self._data['matches']
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
def getMatches(self):
|
def getMatches(self):
|
||||||
return [(line if not isinstance(line, (list, tuple)) else "".join(line)) \
|
return [(line if not isinstance(line, (list, tuple)) else "".join(line)) \
|
||||||
|
|
|
@ -361,7 +361,6 @@ class DatabaseTest(LogCaptureTestCase):
|
||||||
ticket.setAttempt(len(failures))
|
ticket.setAttempt(len(failures))
|
||||||
self.db.addBan(self.jail, ticket)
|
self.db.addBan(self.jail, ticket)
|
||||||
# should retrieve 2 matches only, but count of all attempts:
|
# should retrieve 2 matches only, but count of all attempts:
|
||||||
self.db.maxEntries = maxEntries;
|
|
||||||
ticket = self.db.getBansMerged("127.0.0.1")
|
ticket = self.db.getBansMerged("127.0.0.1")
|
||||||
self.assertEqual(ticket.getAttempt(), 2 * len(failures))
|
self.assertEqual(ticket.getAttempt(), 2 * len(failures))
|
||||||
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
||||||
|
@ -372,6 +371,13 @@ class DatabaseTest(LogCaptureTestCase):
|
||||||
self.assertEqual(ticket.getAttempt(), len(failures))
|
self.assertEqual(ticket.getAttempt(), len(failures))
|
||||||
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
||||||
self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:])
|
self.assertEqual(ticket.getMatches(), matches2find[-maxEntries:])
|
||||||
|
# should retrieve 0 matches by last ban:
|
||||||
|
self.db.maxEntries = 0;
|
||||||
|
self.db.addBan(self.jail, ticket)
|
||||||
|
ticket = self.db.getCurrentBans(self.jail, "127.0.0.1", fromtime=MyTime.time()-100)
|
||||||
|
self.assertTrue(ticket is not None)
|
||||||
|
self.assertEqual(ticket.getAttempt(), len(failures))
|
||||||
|
self.assertEqual(len(ticket.getMatches()), 0)
|
||||||
|
|
||||||
def testGetBansMerged(self):
|
def testGetBansMerged(self):
|
||||||
self.testAddJail()
|
self.testAddJail()
|
||||||
|
|
|
@ -110,6 +110,14 @@ class AddFailure(unittest.TestCase):
|
||||||
self.assertEqual(ticket.getAttempt(), 2 * len(failures) + 1)
|
self.assertEqual(ticket.getAttempt(), 2 * len(failures) + 1)
|
||||||
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
self.assertEqual(len(ticket.getMatches()), maxEntries)
|
||||||
self.assertEqual(ticket.getMatches(), failures[len(failures) - maxEntries:])
|
self.assertEqual(ticket.getMatches(), failures[len(failures) - maxEntries:])
|
||||||
|
# no matches by maxEntries == 0 :
|
||||||
|
self.__failManager.maxEntries = 0
|
||||||
|
self.__failManager.addFailure(ticket)
|
||||||
|
manFailList = self.__failManager._FailManager__failList
|
||||||
|
ticket = manFailList["127.0.0.1"]
|
||||||
|
self.assertEqual(len(ticket.getMatches()), 0)
|
||||||
|
# test set matches None to None:
|
||||||
|
ticket.setMatches(None)
|
||||||
|
|
||||||
def testFailManagerMaxTime(self):
|
def testFailManagerMaxTime(self):
|
||||||
self._addDefItems()
|
self._addDefItems()
|
||||||
|
|
Loading…
Reference in New Issue