diff --git a/fail2ban/server/utils.py b/fail2ban/server/utils.py index b59fb4e1..198d5d4c 100644 --- a/fail2ban/server/utils.py +++ b/fail2ban/server/utils.py @@ -27,9 +27,15 @@ import os import signal import subprocess import sys +from threading import Lock import time from ..helpers import getLogger, _merge_dicts, uni_decode +try: + from collections import OrderedDict +except ImportError: # pragma: 3.x no cover + OrderedDict = dict + if sys.version_info >= (3, 3): import importlib.machinery else: @@ -69,7 +75,8 @@ class Utils(): def __init__(self, *args, **kwargs): self.setOptions(*args, **kwargs) - self._cache = {} + self._cache = OrderedDict() + self.__lock = Lock() def setOptions(self, maxCount=1000, maxTime=60): self.maxCount = maxCount @@ -83,7 +90,7 @@ class Utils(): if v: if v[1] > time.time(): return v[0] - del self._cache[k] + self.unset(k) return defv def set(self, k, v): @@ -91,12 +98,21 @@ class Utils(): cache = self._cache # for shorter local access # clean cache if max count reached: if len(cache) >= self.maxCount: - for (ck, cv) in cache.items(): - if cv[1] < t: - del cache[ck] - # if still max count - remove any one: - if len(cache) >= self.maxCount: - cache.popitem() + # avoid multiple modification of list multi-threaded: + with self.__lock: + if len(cache) >= self.maxCount: + for (ck, cv) in cache.items(): + # if expired: + if cv[1] <= t: + self.unset(ck) + elif OrderedDict is not dict: + break + # if still max count - remove any one: + if len(cache) >= self.maxCount: + if OrderedDict is not dict: # first (older): + cache.popitem(False) + else: + cache.popitem() cache[k] = (v, t + self.maxTime) def unset(self, k): diff --git a/fail2ban/tests/servertestcase.py b/fail2ban/tests/servertestcase.py index af08bd86..8b616abc 100644 --- a/fail2ban/tests/servertestcase.py +++ b/fail2ban/tests/servertestcase.py @@ -64,7 +64,7 @@ class TestServer(Server): pass -class TransmitterBase(unittest.TestCase): +class TransmitterBase(LogCaptureTestCase): def setUp(self): """Call before every test case.""" @@ -332,11 +332,11 @@ class Transmitter(TransmitterBase): self.assertEqual( self.transm.proceed(["set", self.jailName, "banip", "127.0.0.1"]), (0, "127.0.0.1")) - time.sleep(Utils.DEFAULT_SLEEP_TIME) # Give chance to ban + self.assertLogged("Ban 127.0.0.1", wait=True) # Give chance to ban self.assertEqual( self.transm.proceed(["set", self.jailName, "banip", "Badger"]), (0, "Badger")) #NOTE: Is IP address validated? Is DNS Lookup done? - time.sleep(Utils.DEFAULT_SLEEP_TIME) # Give chance to ban + self.assertLogged("Ban Badger", wait=True) # Give chance to ban # Unban IP self.assertEqual( self.transm.proceed(