diff --git a/fail2ban/client/csocket.py b/fail2ban/client/csocket.py index ab3e294b..88795674 100644 --- a/fail2ban/client/csocket.py +++ b/fail2ban/client/csocket.py @@ -48,7 +48,8 @@ class CSocket: def send(self, msg, nonblocking=False, timeout=None): # Convert every list member to string obj = dumps(map(CSocket.convert, msg), HIGHEST_PROTOCOL) - self.__csock.send(obj + CSPROTO.END) + self.__csock.send(obj) + self.__csock.send(CSPROTO.END) return self.receive(self.__csock, nonblocking, timeout) def settimeout(self, timeout): @@ -81,9 +82,12 @@ class CSocket: msg = CSPROTO.EMPTY if nonblocking: sock.setblocking(0) if timeout: sock.settimeout(timeout) - while msg.rfind(CSPROTO.END) == -1: - chunk = sock.recv(512) - if chunk in ('', b''): # python 3.x may return b'' instead of '' - raise RuntimeError("socket connection broken") + bufsize = 1024 + while msg.rfind(CSPROTO.END, -32) == -1: + chunk = sock.recv(bufsize) + if not len(chunk): + raise socket.error(104, 'Connection reset by peer') + if chunk == CSPROTO.END: break msg = msg + chunk + if bufsize < 32768: bufsize <<= 1 return loads(msg) diff --git a/fail2ban/client/fail2bancmdline.py b/fail2ban/client/fail2bancmdline.py index 8936e03f..03683cad 100644 --- a/fail2ban/client/fail2bancmdline.py +++ b/fail2ban/client/fail2bancmdline.py @@ -27,13 +27,17 @@ import sys from ..version import version, normVersion from ..protocol import printFormatted -from ..helpers import getLogger, str2LogLevel, getVerbosityFormat +from ..helpers import getLogger, str2LogLevel, getVerbosityFormat, BrokenPipeError # Gets the instance of the logger. logSys = getLogger("fail2ban") def output(s): # pragma: no cover - print(s) + try: + print(s) + except (BrokenPipeError, IOError) as e: # pragma: no cover + if e.errno != 32: # closed / broken pipe + raise # Config parameters required to start fail2ban which can be also set via command line (overwrite fail2ban.conf), CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket") @@ -310,12 +314,16 @@ class Fail2banCmdLine(): def _exit(code=0): # implicit flush without to produce broken pipe error (32): sys.stderr.close() - sys.stdout.close() - # exit: - if hasattr(os, '_exit') and os._exit: - os._exit(code) - else: - sys.exit(code) + try: + sys.stdout.flush() + # exit: + if hasattr(sys, 'exit') and sys.exit: + sys.exit(code) + else: + os._exit(code) + except (BrokenPipeError, IOError) as e: # pragma: no cover + if e.errno != 32: # closed / broken pipe + raise @staticmethod def exit(code=0): diff --git a/fail2ban/helpers.py b/fail2ban/helpers.py index f381576e..c45be849 100644 --- a/fail2ban/helpers.py +++ b/fail2ban/helpers.py @@ -224,9 +224,10 @@ def __stopOnIOError(logSys=None, logHndlr=None): # pragma: no cover sys.exit(0) try: - BrokenPipeError + BrokenPipeError = BrokenPipeError except NameError: # pragma: 3.x no cover - BrokenPipeError = IOError + BrokenPipeError = IOError + __origLog = logging.Logger._log def __safeLog(self, level, msg, args, **kwargs): """Safe log inject to avoid possible errors by unsafe log-handlers, diff --git a/fail2ban/server/actions.py b/fail2ban/server/actions.py index 3308d4b2..f14d8d7b 100644 --- a/fail2ban/server/actions.py +++ b/fail2ban/server/actions.py @@ -661,13 +661,19 @@ class Actions(JailThread, Mapping): """Status of current and total ban counts and current banned IP list. """ # TODO: Allow this list to be printed as 'status' output - supported_flavors = ["basic", "cymru"] + supported_flavors = ["short", "basic", "cymru"] if flavor is None or flavor not in supported_flavors: logSys.warning("Unsupported extended jail status flavor %r. Supported: %s" % (flavor, supported_flavors)) # Always print this information (basic) - ret = [("Currently banned", self.__banManager.size()), - ("Total banned", self.__banManager.getBanTotal()), - ("Banned IP list", self.__banManager.getBanList())] + if flavor != "short": + banned = self.__banManager.getBanList() + cnt = len(banned) + else: + cnt = self.__banManager.size() + ret = [("Currently banned", cnt), + ("Total banned", self.__banManager.getBanTotal())] + if flavor != "short": + ret += [("Banned IP list", banned)] if flavor == "cymru": cymru_info = self.__banManager.getBanListExtendedCymruInfo() ret += \ diff --git a/fail2ban/server/banmanager.py b/fail2ban/server/banmanager.py index 479ba26f..fe38dec6 100644 --- a/fail2ban/server/banmanager.py +++ b/fail2ban/server/banmanager.py @@ -66,8 +66,7 @@ class BanManager: # @param value the time def setBanTime(self, value): - with self.__lock: - self.__banTime = int(value) + self.__banTime = int(value) ## # Get the ban time. @@ -76,8 +75,7 @@ class BanManager: # @return the time def getBanTime(self): - with self.__lock: - return self.__banTime + return self.__banTime ## # Set the total number of banned address. @@ -85,8 +83,7 @@ class BanManager: # @param value total number def setBanTotal(self, value): - with self.__lock: - self.__banTotal = value + self.__banTotal = value ## # Get the total number of banned address. @@ -94,8 +91,7 @@ class BanManager: # @return the total number def getBanTotal(self): - with self.__lock: - return self.__banTotal + return self.__banTotal ## # Returns a copy of the IP list. @@ -103,8 +99,7 @@ class BanManager: # @return IP list def getBanList(self): - with self.__lock: - return list(self.__banList.keys()) + return list(self.__banList.keys()) ## # Returns a iterator to ban list (used in reload, so idle). @@ -112,9 +107,8 @@ class BanManager: # @return ban list iterator def __iter__(self): - # ensure iterator is safe (traverse over the list in snapshot created within lock): - with self.__lock: - return iter(list(self.__banList.values())) + # ensure iterator is safe - traverse over the list in snapshot created within lock (GIL): + return iter(list(self.__banList.values())) ## # Returns normalized value diff --git a/fail2ban/server/failmanager.py b/fail2ban/server/failmanager.py index 3458aed5..3e81e8b5 100644 --- a/fail2ban/server/failmanager.py +++ b/fail2ban/server/failmanager.py @@ -47,12 +47,10 @@ class FailManager: self.__bgSvc = BgService() def setFailTotal(self, value): - with self.__lock: - self.__failTotal = value + self.__failTotal = value def getFailTotal(self): - with self.__lock: - return self.__failTotal + return self.__failTotal def getFailCount(self): # may be slow on large list of failures, should be used for test purposes only... @@ -123,8 +121,7 @@ class FailManager: return attempts def size(self): - with self.__lock: - return len(self.__failList) + return len(self.__failList) def cleanup(self, time): with self.__lock: diff --git a/fail2ban/tests/actionstestcase.py b/fail2ban/tests/actionstestcase.py index d97d9921..532fe6ed 100644 --- a/fail2ban/tests/actionstestcase.py +++ b/fail2ban/tests/actionstestcase.py @@ -96,6 +96,8 @@ class ExecuteActions(LogCaptureTestCase): self.assertLogged("stdout: %r" % 'ip flush', "stdout: %r" % 'ip stop') self.assertEqual(self.__actions.status(),[("Currently banned", 0 ), ("Total banned", 0 ), ("Banned IP list", [] )]) + self.assertEqual(self.__actions.status('short'),[("Currently banned", 0 ), + ("Total banned", 0 )]) def testAddActionPython(self): self.__actions.add( diff --git a/fail2ban/tests/sockettestcase.py b/fail2ban/tests/sockettestcase.py index 8cd22a41..2d414e5c 100644 --- a/fail2ban/tests/sockettestcase.py +++ b/fail2ban/tests/sockettestcase.py @@ -152,7 +152,7 @@ class Socket(LogCaptureTestCase): org_handler = RequestHandler.found_terminator try: RequestHandler.found_terminator = lambda self: self.close() - self.assertRaisesRegexp(RuntimeError, r"socket connection broken", + self.assertRaisesRegexp(Exception, r"reset by peer|Broken pipe", lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10))) finally: RequestHandler.found_terminator = org_handler @@ -168,7 +168,7 @@ class Socket(LogCaptureTestCase): org_handler = RequestHandler.found_terminator try: RequestHandler.found_terminator = lambda self: TestMsgError() - #self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage) + #self.assertRaisesRegexp(Exception, r"reset by peer|Broken pipe", client.send, testMessage) self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error') finally: RequestHandler.found_terminator = org_handler