optimized BanManager: increase performance, fewer system load, try to prevent memory leakage:

- better ban/unban handling within actions (e.g. used dict instead of list)
- don't copy bans resp. its list on some operations;
- added new unbantime handling to relieve unBanList (prevent permanent searching for tickets to unban)
- prefer failure-ID as identifier of the ticket to its IP (most of the time the same, but it can be something else e.g. user name in some complex jails, as introduced in 0.10)
pull/1557/head
sebres 2016-09-08 18:27:55 +02:00
parent d2ddc59c40
commit 27f6fc083a
4 changed files with 108 additions and 94 deletions

View File

@ -204,7 +204,7 @@ class Actions(JailThread, Mapping):
if db and self._jail.database is not None:
self._jail.database.delBan(self._jail, ip)
# Find the ticket with the IP.
ticket = self.__banManager.getTicketByIP(ip)
ticket = self.__banManager.getTicketByID(ip)
if ticket is not None:
# Unban the IP.
self.__unBan(ticket)
@ -303,8 +303,11 @@ class Actions(JailThread, Mapping):
bool
True if an IP address get banned.
"""
ticket = self._jail.getFailTicket()
if ticket:
cnt = 0
while cnt < 100:
ticket = self._jail.getFailTicket()
if not ticket:
break
aInfo = CallingMap()
bTicket = BanManager.createBanTicket(ticket)
ip = bTicket.getIP()
@ -320,6 +323,7 @@ class Actions(JailThread, Mapping):
aInfo["ipfailures"] = lambda: mi4ip(True).getAttempt()
aInfo["ipjailfailures"] = lambda: mi4ip().getAttempt()
if self.__banManager.addBanTicket(bTicket):
cnt += 1
logSys.notice("[%s] %sBan %s", self._jail.name, ('' if not bTicket.getRestored() else 'Restore '), ip)
for name, action in self._actions.iteritems():
try:
@ -330,19 +334,26 @@ class Actions(JailThread, Mapping):
"info '%r': %s",
self._jail.name, name, aInfo, e,
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
return True
else:
logSys.notice("[%s] %s already banned" % (self._jail.name,
aInfo["ip"]))
return False
logSys.notice("[%s] %s already banned", self._jail.name, ip)
if cnt:
logSys.debug("Banned %s / %s, %s ticket(s) in %r", cnt,
self.__banManager.getBanTotal(), self.__banManager.size(), self._jail.name)
return cnt
def __checkUnBan(self):
"""Check for IP address to unban.
Unban IP addresses which are outdated.
"""
for ticket in self.__banManager.unBanList(MyTime.time()):
lst = self.__banManager.unBanList(MyTime.time())
for ticket in lst:
self.__unBan(ticket)
cnt = len(lst)
if cnt:
logSys.debug("Unbanned %s, %s ticket(s) in %r",
cnt, self.__banManager.size(), self._jail.name)
return cnt
def __flushBan(self, db=False):
"""Flush the ban list.
@ -358,7 +369,10 @@ class Actions(JailThread, Mapping):
self._jail.database.delBan(self._jail, ip)
# unban ip:
self.__unBan(ticket)
return len(lst)
cnt = len(lst)
logSys.debug("Unbanned %s, %s ticket(s) in %r",
cnt, self.__banManager.size(), self._jail.name)
return cnt
def __unBan(self, ticket):
"""Unbans host corresponding to the ticket.

View File

@ -51,11 +51,13 @@ class BanManager:
## Mutex used to protect the ban list.
self.__lock = Lock()
## The ban list.
self.__banList = list()
self.__banList = dict()
## The amount of time an IP address gets banned.
self.__banTime = 600
## Total number of banned IP address
self.__banTotal = 0
## The time for next unban process (for performance and load reasons):
self.__nextUnbanTime = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFL
##
# Set the ban time.
@ -64,11 +66,8 @@ class BanManager:
# @param value the time
def setBanTime(self, value):
try:
self.__lock.acquire()
with self.__lock:
self.__banTime = int(value)
finally:
self.__lock.release()
##
# Get the ban time.
@ -77,11 +76,8 @@ class BanManager:
# @return the time
def getBanTime(self):
try:
self.__lock.acquire()
with self.__lock:
return self.__banTime
finally:
self.__lock.release()
##
# Set the total number of banned address.
@ -89,11 +85,8 @@ class BanManager:
# @param value total number
def setBanTotal(self, value):
try:
self.__lock.acquire()
with self.__lock:
self.__banTotal = value
finally:
self.__lock.release()
##
# Get the total number of banned address.
@ -101,11 +94,8 @@ class BanManager:
# @return the total number
def getBanTotal(self):
try:
self.__lock.acquire()
with self.__lock:
return self.__banTotal
finally:
self.__lock.release()
##
# Returns a copy of the IP list.
@ -113,11 +103,8 @@ class BanManager:
# @return IP list
def getBanList(self):
try:
self.__lock.acquire()
return [m.getIP() for m in self.__banList]
finally:
self.__lock.release()
with self.__lock:
return self.__banList.keys()
##
# Returns normalized value
@ -149,7 +136,7 @@ class BanManager:
return return_dict
self.__lock.acquire()
try:
for banData in self.__banList:
for banData in self.__banList.values():
ip = banData.getIP()
# Reference: http://www.team-cymru.org/Services/ip-to-asn.html#dns
question = ip.getPTR(
@ -261,29 +248,31 @@ class BanManager:
# @return True if the IP address is not in the ban list
def addBanTicket(self, ticket):
try:
self.__lock.acquire()
with self.__lock:
# check already banned
for oldticket in self.__banList:
if ticket.getIP() == oldticket.getIP():
# if already permanent
btold, told = oldticket.getBanTime(self.__banTime), oldticket.getTime()
if btold == -1:
return False
# if given time is less than already banned time
btnew, tnew = ticket.getBanTime(self.__banTime), ticket.getTime()
if btnew != -1 and tnew + btnew <= told + btold:
return False
# we have longest ban - set new (increment) ban time
oldticket.setTime(tnew)
oldticket.setBanTime(btnew)
fid = ticket.getID()
oldticket = self.__banList.get(fid)
if oldticket:
# if already permanent
btold, told = oldticket.getBanTime(self.__banTime), oldticket.getTime()
if btold == -1:
return False
# if given time is less than already banned time
btnew, tnew = ticket.getBanTime(self.__banTime), ticket.getTime()
if btnew != -1 and tnew + btnew <= told + btold:
return False
# we have longest ban - set new (increment) ban time
oldticket.setTime(tnew)
oldticket.setBanTime(btnew)
return False
# not yet banned - add new
self.__banList.append(ticket)
self.__banList[fid] = ticket
self.__banTotal += 1
# correct next unban time:
eob = ticket.getEndOfBanTime(self.__banTime)
if self.__nextUnbanTime > eob:
self.__nextUnbanTime = eob
return True
finally:
self.__lock.release()
##
# Get the size of the ban list.
@ -291,11 +280,7 @@ class BanManager:
# @return the size
def size(self):
try:
self.__lock.acquire()
return len(self.__banList)
finally:
self.__lock.release()
return len(self.__banList)
##
# Check if a ticket is in the list.
@ -306,10 +291,7 @@ class BanManager:
# @return True if a ticket already exists
def _inBanList(self, ticket):
for i in self.__banList:
if ticket.getIP() == i.getIP():
return True
return False
return ticket.getID() in self.__banList
##
# Get the list of IP address to unban.
@ -319,22 +301,39 @@ class BanManager:
# @return the list of ticket to unban
def unBanList(self, time):
try:
self.__lock.acquire()
with self.__lock:
# Permanent banning
if self.__banTime < 0:
return list()
# Gets the list of ticket to remove.
unBanList = [ticket for ticket in self.__banList if ticket.isTimedOut(time, self.__banTime)]
# Check next unban time:
if self.__nextUnbanTime > time:
return list()
# Gets the list of ticket to remove (thereby correct next unban time).
unBanList = {}
self.__nextUnbanTime = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFL
for fid,ticket in self.__banList.iteritems():
# current time greater as end of ban - timed out:
eob = ticket.getEndOfBanTime(self.__banTime)
if time > eob:
unBanList[fid] = ticket
elif self.__nextUnbanTime > eob:
self.__nextUnbanTime = eob
# Removes tickets.
self.__banList = [ticket for ticket in self.__banList
if ticket not in unBanList]
if len(unBanList):
if len(unBanList) / 2.0 <= len(self.__banList) / 3.0:
# few as 2/3 should be removed - remove particular items:
for fid in unBanList.iterkeys():
del self.__banList[fid]
else:
# create new dictionary without items to be deleted:
self.__banList = dict((fid,ticket) for fid,ticket in self.__banList.iteritems() \
if fid not in unBanList)
return unBanList
finally:
self.__lock.release()
# return list of tickets:
return unBanList.values()
##
# Flush the ban list.
@ -343,28 +342,21 @@ class BanManager:
# @return the complete ban list
def flushBanList(self):
try:
self.__lock.acquire()
uBList = self.__banList
self.__banList = list()
with self.__lock:
uBList = self.__banList.values()
self.__banList = dict()
return uBList
finally:
self.__lock.release()
##
# Gets the ticket for the specified IP.
# Gets the ticket for the specified ID (most of the time it is IP-address).
#
# @return the ticket for the IP or False.
def getTicketByIP(self, ip):
try:
self.__lock.acquire()
# Find the ticket the IP goes with and return it
for i, ticket in enumerate(self.__banList):
if ticket.getIP() == ip:
# Return the ticket after removing (popping)
# if from the ban list.
return self.__banList.pop(i)
finally:
self.__lock.release()
# @return the ticket or False.
def getTicketByID(self, fid):
with self.__lock:
try:
# Return the ticket after removing (popping)
# if from the ban list.
return self.__banList.pop(fid)
except KeyError:
pass
return None # if none found

View File

@ -96,8 +96,8 @@ class Ticket:
def setBanTime(self, value):
self._banTime = value;
def getBanTime(self, defaultBT = None):
return (self._banTime if not self._banTime is None else defaultBT);
def getBanTime(self, defaultBT=None):
return (self._banTime if self._banTime is not None else defaultBT)
def setBanCount(self, value):
self._banCount = value;
@ -108,8 +108,16 @@ class Ticket:
def getBanCount(self):
return self._banCount;
def isTimedOut(self, time, defaultBT = None):
bantime = (self._banTime if not self._banTime is None else defaultBT);
def getEndOfBanTime(self, defaultBT=None):
bantime = (self._banTime if self._banTime is not None else defaultBT)
# permanent
if bantime == -1:
return 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFL
# unban time (end of ban):
return self._time + bantime
def isTimedOut(self, time, defaultBT=None):
bantime = (self._banTime if self._banTime is not None else defaultBT)
# permanent
if bantime == -1:
return False

View File

@ -70,7 +70,7 @@ class AddFailure(unittest.TestCase):
self.assertFalse(self.__banManager.addBanTicket(ticket2))
self.assertEqual(self.__banManager.size(), 1)
# pop ticket and check it was prolonged :
banticket = self.__banManager.getTicketByIP(ticket2.getIP())
banticket = self.__banManager.getTicketByID(ticket2.getID())
self.assertEqual(banticket.getTime(), ticket2.getTime())
self.assertEqual(banticket.getTime(), ticket2.getTime())
self.assertEqual(banticket.getBanTime(), ticket2.getBanTime(self.__banManager.getBanTime()))