From 29bedd70d533a6186e00fed9a4072785e1062c1d Mon Sep 17 00:00:00 2001 From: sebres Date: Fri, 2 Mar 2018 20:56:53 +0100 Subject: [PATCH 1/3] socket stability and coverage: cherry picked from 0.11 version (avoid many sporadic unhandled exceptions) --- fail2ban/client/csocket.py | 12 +-- fail2ban/tests/sockettestcase.py | 139 ++++++++++++++++++++++++++----- 2 files changed, 123 insertions(+), 28 deletions(-) diff --git a/fail2ban/client/csocket.py b/fail2ban/client/csocket.py index 86dd17c9..ce01ae08 100644 --- a/fail2ban/client/csocket.py +++ b/fail2ban/client/csocket.py @@ -45,13 +45,13 @@ class CSocket: def __del__(self): self.close(False) - def send(self, msg): + def send(self, msg, nonblocking=False, timeout=None): # Convert every list member to string obj = dumps(map( lambda m: str(m) if not isinstance(m, (list, dict, set)) else m, msg), HIGHEST_PROTOCOL) self.__csock.send(obj + CSPROTO.END) - return self.receive(self.__csock) + return self.receive(self.__csock, nonblocking, timeout) def settimeout(self, timeout): self.__csock.settimeout(timeout if timeout != -1 else self.__deftout) @@ -66,11 +66,13 @@ class CSocket: self.__csock = None @staticmethod - def receive(sock): + def receive(sock, nonblocking=False, timeout=None): msg = CSPROTO.EMPTY + if nonblocking: sock.setblocking(0) + if timeout: sock.settimeout(timeout) while msg.rfind(CSPROTO.END) == -1: - chunk = sock.recv(6) - if chunk == '': + chunk = sock.recv(512) + if chunk in ('', b''): # python 3.x may return b'' instead of '' raise RuntimeError("socket connection broken") msg = msg + chunk return loads(msg) diff --git a/fail2ban/tests/sockettestcase.py b/fail2ban/tests/sockettestcase.py index 4f9b9d7a..a7c3a43c 100644 --- a/fail2ban/tests/sockettestcase.py +++ b/fail2ban/tests/sockettestcase.py @@ -34,42 +34,70 @@ import unittest from .utils import LogCaptureTestCase from .. import protocol -from ..server.asyncserver import AsyncServer, AsyncServerException, loop +from ..server.asyncserver import asyncore, RequestHandler, loop, AsyncServer, AsyncServerException from ..server.utils import Utils from ..client.csocket import CSocket +from .utils import LogCaptureTestCase -class Socket(unittest.TestCase): + +def TestMsgError(*args): + raise Exception('test unpickle error') +class TestMsg(object): + def __init__(self, unpickle=(TestMsgError, ())): + self.unpickle = unpickle + def __reduce__(self): + return self.unpickle + + +class Socket(LogCaptureTestCase): def setUp(self): """Call before every test case.""" + LogCaptureTestCase.setUp(self) super(Socket, self).setUp() self.server = AsyncServer(self) sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'socket') os.close(sock_fd) os.remove(sock_name) self.sock_name = sock_name + self.serverThread = None def tearDown(self): """Call after every test case.""" + if self.serverThread: + self.server.stop(); # stop if not already stopped + self._stopServerThread() + LogCaptureTestCase.tearDown(self) @staticmethod def proceed(message): """Test transmitter proceed method which just returns first arg""" return message - def testStopPerCloseUnexpected(self): + def _createServerThread(self, force=False): # start in separate thread : - serverThread = threading.Thread( - target=self.server.start, args=(self.sock_name, False)) + self.serverThread = serverThread = threading.Thread( + target=self.server.start, args=(self.sock_name, force)) serverThread.daemon = True serverThread.start() self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) + return serverThread + + def _stopServerThread(self): + serverThread = self.serverThread + # wait for end of thread : + Utils.wait_for(lambda: not serverThread.isAlive() + or serverThread.join(Utils.DEFAULT_SLEEP_TIME), unittest.F2B.maxWaitTime(10)) + self.serverThread = None + + def testStopPerCloseUnexpected(self): + # start in separate thread : + serverThread = self._createServerThread() # unexpected stop directly after start: self.server.close() # wait for end of thread : - Utils.wait_for(lambda: not serverThread.isAlive() - or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self._stopServerThread() self.assertFalse(serverThread.isAlive()) # clean : self.server.stop() @@ -83,30 +111,99 @@ class Socket(unittest.TestCase): return None def testSocket(self): - serverThread = threading.Thread( - target=self.server.start, args=(self.sock_name, False)) - serverThread.daemon = True - serverThread.start() - self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) - time.sleep(Utils.DEFAULT_SLEEP_TIME) - + # start in separate thread : + serverThread = self._createServerThread() client = Utils.wait_for(self._serverSocket, 2) + testMessage = ["A", "test", "message"] self.assertEqual(client.send(testMessage), testMessage) + # test wrong message: + self.assertEqual(client.send([[TestMsg()]]), 'ERROR: test unpickle error') + self.assertLogged("Caught unhandled exception", "test unpickle error", all=True) + + # test good message again: + self.assertEqual(client.send(testMessage), testMessage) + # test close message client.close() # 2nd close does nothing client.close() + # force shutdown: + self.server.stop_communication() + # test send again (should get in shutdown message): + client = Utils.wait_for(self._serverSocket, 2) + self.assertEqual(client.send(testMessage), ['SHUTDOWN']) + self.server.stop() # wait for end of thread : - Utils.wait_for(lambda: not serverThread.isAlive() - or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self._stopServerThread() self.assertFalse(serverThread.isAlive()) self.assertFalse(self.server.isActive()) self.assertFalse(os.path.exists(self.sock_name)) + def testSocketConnectBroken(self): + # start in separate thread : + serverThread = self._createServerThread() + client = Utils.wait_for(self._serverSocket, 2) + # unexpected stop during message body: + testMessage = ["A", "test", "message", [protocol.CSPROTO.END]] + + org_handler = RequestHandler.found_terminator + try: + RequestHandler.found_terminator = lambda self: self.close() + self.assertRaisesRegexp(RuntimeError, r"socket connection broken", + lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10))) + finally: + RequestHandler.found_terminator = org_handler + + def testStopByCommunicate(self): + # start in separate thread : + serverThread = self._createServerThread() + client = Utils.wait_for(self._serverSocket, 2) + + testMessage = ["A", "test", "message"] + self.assertEqual(client.send(testMessage), testMessage) + + org_handler = RequestHandler.found_terminator + try: + RequestHandler.found_terminator = lambda self: TestMsgError() + #self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage) + self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error') + finally: + RequestHandler.found_terminator = org_handler + + # check errors were logged: + self.assertLogged("Unexpected communication error", "test unpickle error", all=True) + + self.server.stop() + # wait for end of thread : + self._stopServerThread() + self.assertFalse(serverThread.isAlive()) + + def testLoopErrors(self): + # replace poll handler to produce error in loop-cycle: + org_poll = asyncore.poll + err = {'cntr': 0} + def _produce_error(*args): + err['cntr'] += 1 + if err['cntr'] < 50: + raise RuntimeError('test errors in poll') + return org_poll(*args) + + try: + asyncore.poll = _produce_error + serverThread = self._createServerThread() + # wait all-cases processed: + self.assertTrue(Utils.wait_for(lambda: err['cntr'] > 50, unittest.F2B.maxWaitTime(10))) + finally: + # restore: + asyncore.poll = org_poll + # check errors were logged: + self.assertLogged("Server connection was closed: test errors in poll", + "Too many errors - stop logging connection errors", all=True) + def testSocketForce(self): open(self.sock_name, 'w').close() # Create sock file # Try to start without force @@ -114,16 +211,12 @@ class Socket(unittest.TestCase): AsyncServerException, self.server.start, self.sock_name, False) # Try again with force set - serverThread = threading.Thread( - target=self.server.start, args=(self.sock_name, True)) - serverThread.daemon = True - serverThread.start() - self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) + serverThread = self._createServerThread(True) self.server.stop() # wait for end of thread : - Utils.wait_for(lambda: not serverThread.isAlive() - or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self._stopServerThread() + self.assertFalse(serverThread.isAlive()) self.assertFalse(self.server.isActive()) self.assertFalse(os.path.exists(self.sock_name)) From 96836cb199ed076353cbfe2c02c90948e4235d77 Mon Sep 17 00:00:00 2001 From: sebres Date: Fri, 2 Mar 2018 21:30:03 +0100 Subject: [PATCH 2/3] fix several errors (shutdown in test-cases during stop communication, better error handling by unpickle/deserialization, etc) --- fail2ban/server/asyncserver.py | 43 +++++++++++++++++++------------- fail2ban/tests/sockettestcase.py | 5 ++-- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/fail2ban/server/asyncserver.py b/fail2ban/server/asyncserver.py index eb99c69a..e3400737 100644 --- a/fail2ban/server/asyncserver.py +++ b/fail2ban/server/asyncserver.py @@ -76,6 +76,10 @@ class RequestHandler(asynchat.async_chat): #logSys.debug("Received raw data: " + str(data)) self.__buffer.append(data) + # exception identifies deserialization errors (exception by load in pickle): + class LoadError(Exception): + pass + ## # Handles a new request. # @@ -93,7 +97,12 @@ class RequestHandler(asynchat.async_chat): self.close_when_done() return # Deserialize - message = loads(message) + try: + message = loads(message) + except Exception as e: + logSys.error('PROTO-error: load message failed: %s', e, + exc_info=logSys.getEffectiveLevel() 100: if ( - e.args[0] == errno.EMFILE # [Errno 24] Too many open files + (isinstance(e, socket.error) and e.args[0] == errno.EMFILE) # [Errno 24] Too many open files or sum(self.__errCount.itervalues()) > 1000 ): - logSys.critical("Too many errors - critical count reached %r", err_count) + logSys.critical("Too many errors - critical count reached %r", self.__errCount) self.stop() return - except TypeError as e: # pragma: no cover - logSys.warning("Type error: %s", e) - return if self.__errCount['accept']: self.__errCount['accept'] -= 1; AsyncServer.__markCloseOnExec(conn) @@ -265,6 +273,13 @@ class AsyncServer(asyncore.dispatcher): stopflg = False if self.__active: self.__loop = False + # shutdown socket here: + if self.socket: + try: + self.socket.shutdown(socket.SHUT_RDWR) + except socket.error: # pragma: no cover - normally unreachable + pass + # close connection: asyncore.dispatcher.close(self) # If not the loop thread (stops self in handler), wait (a little bit) # for the server leaves loop, before remove socket @@ -284,14 +299,8 @@ class AsyncServer(asyncore.dispatcher): def stop_communication(self): if self.__transmitter: - logSys.debug("Stop communication") + logSys.debug("Stop communication, shutdown") self.__transmitter = None - # shutdown socket here: - if self.socket: - try: - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error: # pragma: no cover - normally unreachable - pass ## # Stops the server. diff --git a/fail2ban/tests/sockettestcase.py b/fail2ban/tests/sockettestcase.py index a7c3a43c..4e14ece5 100644 --- a/fail2ban/tests/sockettestcase.py +++ b/fail2ban/tests/sockettestcase.py @@ -54,10 +54,9 @@ class Socket(LogCaptureTestCase): def setUp(self): """Call before every test case.""" - LogCaptureTestCase.setUp(self) super(Socket, self).setUp() self.server = AsyncServer(self) - sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'socket') + sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'f2b-socket') os.close(sock_fd) os.remove(sock_name) self.sock_name = sock_name @@ -120,7 +119,7 @@ class Socket(LogCaptureTestCase): # test wrong message: self.assertEqual(client.send([[TestMsg()]]), 'ERROR: test unpickle error') - self.assertLogged("Caught unhandled exception", "test unpickle error", all=True) + self.assertLogged("PROTO-error: load message failed:", "test unpickle error", all=True) # test good message again: self.assertEqual(client.send(testMessage), testMessage) From 1bdda6c8eb497990b0c000947d6cdd0e3aa4e5cb Mon Sep 17 00:00:00 2001 From: sebres Date: Fri, 2 Mar 2018 20:08:48 +0100 Subject: [PATCH 3/3] cache coverage --- fail2ban/server/utils.py | 2 +- fail2ban/tests/filtertestcase.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/fail2ban/server/utils.py b/fail2ban/server/utils.py index 58363ff0..8569a3f2 100644 --- a/fail2ban/server/utils.py +++ b/fail2ban/server/utils.py @@ -102,7 +102,7 @@ class Utils(): def unset(self, k): try: del self._cache[k] - except KeyError: # pragme: no cover + except KeyError: pass diff --git a/fail2ban/tests/filtertestcase.py b/fail2ban/tests/filtertestcase.py index b5877b7f..2bbfcd9d 100644 --- a/fail2ban/tests/filtertestcase.py +++ b/fail2ban/tests/filtertestcase.py @@ -1640,6 +1640,8 @@ class DNSUtilsTests(unittest.TestCase): c.set(i, i) for i in xrange(5): self.assertEqual(c.get(i), i) + # remove unavailable key: + c.unset('a'); c.unset('a') def testCacheMaxSize(self): c = Utils.Cache(maxCount=5, maxTime=60)