diff --git a/fail2ban/server/actions.py b/fail2ban/server/actions.py index 8c60e55f..244d1e68 100644 --- a/fail2ban/server/actions.py +++ b/fail2ban/server/actions.py @@ -204,7 +204,7 @@ class Actions(JailThread, Mapping): if db and self._jail.database is not None: self._jail.database.delBan(self._jail, ip) # Find the ticket with the IP. - ticket = self.__banManager.getTicketByIP(ip) + ticket = self.__banManager.getTicketByID(ip) if ticket is not None: # Unban the IP. self.__unBan(ticket) @@ -303,8 +303,11 @@ class Actions(JailThread, Mapping): bool True if an IP address get banned. """ - ticket = self._jail.getFailTicket() - if ticket: + cnt = 0 + while cnt < 100: + ticket = self._jail.getFailTicket() + if not ticket: + break aInfo = CallingMap() bTicket = BanManager.createBanTicket(ticket) ip = bTicket.getIP() @@ -320,6 +323,7 @@ class Actions(JailThread, Mapping): aInfo["ipfailures"] = lambda: mi4ip(True).getAttempt() aInfo["ipjailfailures"] = lambda: mi4ip().getAttempt() if self.__banManager.addBanTicket(bTicket): + cnt += 1 logSys.notice("[%s] %sBan %s", self._jail.name, ('' if not bTicket.getRestored() else 'Restore '), ip) for name, action in self._actions.iteritems(): try: @@ -330,19 +334,26 @@ class Actions(JailThread, Mapping): "info '%r': %s", self._jail.name, name, aInfo, e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) - return True else: - logSys.notice("[%s] %s already banned" % (self._jail.name, - aInfo["ip"])) - return False + logSys.notice("[%s] %s already banned", self._jail.name, ip) + if cnt: + logSys.debug("Banned %s / %s, %s ticket(s) in %r", cnt, + self.__banManager.getBanTotal(), self.__banManager.size(), self._jail.name) + return cnt def __checkUnBan(self): """Check for IP address to unban. Unban IP addresses which are outdated. """ - for ticket in self.__banManager.unBanList(MyTime.time()): + lst = self.__banManager.unBanList(MyTime.time()) + for ticket in lst: self.__unBan(ticket) + cnt = len(lst) + if cnt: + logSys.debug("Unbanned %s, %s ticket(s) in %r", + cnt, self.__banManager.size(), self._jail.name) + return cnt def __flushBan(self, db=False): """Flush the ban list. @@ -358,7 +369,10 @@ class Actions(JailThread, Mapping): self._jail.database.delBan(self._jail, ip) # unban ip: self.__unBan(ticket) - return len(lst) + cnt = len(lst) + logSys.debug("Unbanned %s, %s ticket(s) in %r", + cnt, self.__banManager.size(), self._jail.name) + return cnt def __unBan(self, ticket): """Unbans host corresponding to the ticket. diff --git a/fail2ban/server/banmanager.py b/fail2ban/server/banmanager.py index afa70685..e0a9e5ca 100644 --- a/fail2ban/server/banmanager.py +++ b/fail2ban/server/banmanager.py @@ -51,11 +51,13 @@ class BanManager: ## Mutex used to protect the ban list. self.__lock = Lock() ## The ban list. - self.__banList = list() + self.__banList = dict() ## The amount of time an IP address gets banned. self.__banTime = 600 ## Total number of banned IP address self.__banTotal = 0 + ## The time for next unban process (for performance and load reasons): + self.__nextUnbanTime = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFL ## # Set the ban time. @@ -64,11 +66,8 @@ class BanManager: # @param value the time def setBanTime(self, value): - try: - self.__lock.acquire() + with self.__lock: self.__banTime = int(value) - finally: - self.__lock.release() ## # Get the ban time. @@ -77,11 +76,8 @@ class BanManager: # @return the time def getBanTime(self): - try: - self.__lock.acquire() + with self.__lock: return self.__banTime - finally: - self.__lock.release() ## # Set the total number of banned address. @@ -89,11 +85,8 @@ class BanManager: # @param value total number def setBanTotal(self, value): - try: - self.__lock.acquire() + with self.__lock: self.__banTotal = value - finally: - self.__lock.release() ## # Get the total number of banned address. @@ -101,11 +94,8 @@ class BanManager: # @return the total number def getBanTotal(self): - try: - self.__lock.acquire() + with self.__lock: return self.__banTotal - finally: - self.__lock.release() ## # Returns a copy of the IP list. @@ -113,11 +103,8 @@ class BanManager: # @return IP list def getBanList(self): - try: - self.__lock.acquire() - return [m.getIP() for m in self.__banList] - finally: - self.__lock.release() + with self.__lock: + return self.__banList.keys() ## # Returns normalized value @@ -149,7 +136,7 @@ class BanManager: return return_dict self.__lock.acquire() try: - for banData in self.__banList: + for banData in self.__banList.values(): ip = banData.getIP() # Reference: http://www.team-cymru.org/Services/ip-to-asn.html#dns question = ip.getPTR( @@ -261,29 +248,31 @@ class BanManager: # @return True if the IP address is not in the ban list def addBanTicket(self, ticket): - try: - self.__lock.acquire() + with self.__lock: # check already banned - for oldticket in self.__banList: - if ticket.getIP() == oldticket.getIP(): - # if already permanent - btold, told = oldticket.getBanTime(self.__banTime), oldticket.getTime() - if btold == -1: - return False - # if given time is less than already banned time - btnew, tnew = ticket.getBanTime(self.__banTime), ticket.getTime() - if btnew != -1 and tnew + btnew <= told + btold: - return False - # we have longest ban - set new (increment) ban time - oldticket.setTime(tnew) - oldticket.setBanTime(btnew) + fid = ticket.getID() + oldticket = self.__banList.get(fid) + if oldticket: + # if already permanent + btold, told = oldticket.getBanTime(self.__banTime), oldticket.getTime() + if btold == -1: return False + # if given time is less than already banned time + btnew, tnew = ticket.getBanTime(self.__banTime), ticket.getTime() + if btnew != -1 and tnew + btnew <= told + btold: + return False + # we have longest ban - set new (increment) ban time + oldticket.setTime(tnew) + oldticket.setBanTime(btnew) + return False # not yet banned - add new - self.__banList.append(ticket) + self.__banList[fid] = ticket self.__banTotal += 1 + # correct next unban time: + eob = ticket.getEndOfBanTime(self.__banTime) + if self.__nextUnbanTime > eob: + self.__nextUnbanTime = eob return True - finally: - self.__lock.release() ## # Get the size of the ban list. @@ -291,11 +280,7 @@ class BanManager: # @return the size def size(self): - try: - self.__lock.acquire() - return len(self.__banList) - finally: - self.__lock.release() + return len(self.__banList) ## # Check if a ticket is in the list. @@ -306,10 +291,7 @@ class BanManager: # @return True if a ticket already exists def _inBanList(self, ticket): - for i in self.__banList: - if ticket.getIP() == i.getIP(): - return True - return False + return ticket.getID() in self.__banList ## # Get the list of IP address to unban. @@ -319,22 +301,39 @@ class BanManager: # @return the list of ticket to unban def unBanList(self, time): - try: - self.__lock.acquire() + with self.__lock: # Permanent banning if self.__banTime < 0: return list() - # Gets the list of ticket to remove. - unBanList = [ticket for ticket in self.__banList if ticket.isTimedOut(time, self.__banTime)] - + # Check next unban time: + if self.__nextUnbanTime > time: + return list() + + # Gets the list of ticket to remove (thereby correct next unban time). + unBanList = {} + self.__nextUnbanTime = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFL + for fid,ticket in self.__banList.iteritems(): + # current time greater as end of ban - timed out: + eob = ticket.getEndOfBanTime(self.__banTime) + if time > eob: + unBanList[fid] = ticket + elif self.__nextUnbanTime > eob: + self.__nextUnbanTime = eob + # Removes tickets. - self.__banList = [ticket for ticket in self.__banList - if ticket not in unBanList] + if len(unBanList): + if len(unBanList) / 2.0 <= len(self.__banList) / 3.0: + # few as 2/3 should be removed - remove particular items: + for fid in unBanList.iterkeys(): + del self.__banList[fid] + else: + # create new dictionary without items to be deleted: + self.__banList = dict((fid,ticket) for fid,ticket in self.__banList.iteritems() \ + if fid not in unBanList) - return unBanList - finally: - self.__lock.release() + # return list of tickets: + return unBanList.values() ## # Flush the ban list. @@ -343,28 +342,21 @@ class BanManager: # @return the complete ban list def flushBanList(self): - try: - self.__lock.acquire() - uBList = self.__banList - self.__banList = list() + with self.__lock: + uBList = self.__banList.values() + self.__banList = dict() return uBList - finally: - self.__lock.release() ## - # Gets the ticket for the specified IP. + # Gets the ticket for the specified ID (most of the time it is IP-address). # - # @return the ticket for the IP or False. - def getTicketByIP(self, ip): - try: - self.__lock.acquire() - - # Find the ticket the IP goes with and return it - for i, ticket in enumerate(self.__banList): - if ticket.getIP() == ip: - # Return the ticket after removing (popping) - # if from the ban list. - return self.__banList.pop(i) - finally: - self.__lock.release() + # @return the ticket or False. + def getTicketByID(self, fid): + with self.__lock: + try: + # Return the ticket after removing (popping) + # if from the ban list. + return self.__banList.pop(fid) + except KeyError: + pass return None # if none found diff --git a/fail2ban/server/ticket.py b/fail2ban/server/ticket.py index 8d07fcf3..4dc6ab88 100644 --- a/fail2ban/server/ticket.py +++ b/fail2ban/server/ticket.py @@ -96,8 +96,8 @@ class Ticket: def setBanTime(self, value): self._banTime = value; - def getBanTime(self, defaultBT = None): - return (self._banTime if not self._banTime is None else defaultBT); + def getBanTime(self, defaultBT=None): + return (self._banTime if self._banTime is not None else defaultBT) def setBanCount(self, value): self._banCount = value; @@ -108,8 +108,16 @@ class Ticket: def getBanCount(self): return self._banCount; - def isTimedOut(self, time, defaultBT = None): - bantime = (self._banTime if not self._banTime is None else defaultBT); + def getEndOfBanTime(self, defaultBT=None): + bantime = (self._banTime if self._banTime is not None else defaultBT) + # permanent + if bantime == -1: + return 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFL + # unban time (end of ban): + return self._time + bantime + + def isTimedOut(self, time, defaultBT=None): + bantime = (self._banTime if self._banTime is not None else defaultBT) # permanent if bantime == -1: return False diff --git a/fail2ban/tests/banmanagertestcase.py b/fail2ban/tests/banmanagertestcase.py index f47dc848..c9f41c6c 100644 --- a/fail2ban/tests/banmanagertestcase.py +++ b/fail2ban/tests/banmanagertestcase.py @@ -70,7 +70,7 @@ class AddFailure(unittest.TestCase): self.assertFalse(self.__banManager.addBanTicket(ticket2)) self.assertEqual(self.__banManager.size(), 1) # pop ticket and check it was prolonged : - banticket = self.__banManager.getTicketByIP(ticket2.getIP()) + banticket = self.__banManager.getTicketByID(ticket2.getID()) self.assertEqual(banticket.getTime(), ticket2.getTime()) self.assertEqual(banticket.getTime(), ticket2.getTime()) self.assertEqual(banticket.getBanTime(), ticket2.getBanTime(self.__banManager.getBanTime()))