From 4b53c6b9755717757bd97680a07ad9151c33a6bd Mon Sep 17 00:00:00 2001 From: sebres Date: Tue, 28 Feb 2017 19:52:44 +0100 Subject: [PATCH] socket, asyncserver: several fixes, python version dependency removed + test coverage extended --- fail2ban/client/csocket.py | 12 +++-- fail2ban/server/asyncserver.py | 27 +++++++---- fail2ban/tests/sockettestcase.py | 81 +++++++++++++++++++++++++------- 3 files changed, 87 insertions(+), 33 deletions(-) diff --git a/fail2ban/client/csocket.py b/fail2ban/client/csocket.py index 6b478460..e53ca1fd 100644 --- a/fail2ban/client/csocket.py +++ b/fail2ban/client/csocket.py @@ -45,13 +45,13 @@ class CSocket: def __del__(self): self.close(False) - def send(self, msg): + def send(self, msg, nonblocking=False, timeout=None): # Convert every list member to string obj = dumps(map( lambda m: str(m) if not isinstance(m, (list, dict, set)) else m, msg), HIGHEST_PROTOCOL) self.__csock.send(obj + CSPROTO.END) - return self.receive(self.__csock) + return self.receive(self.__csock, nonblocking, timeout) def settimeout(self, timeout): self.__csock.settimeout(timeout if timeout != -1 else self.__deftout) @@ -65,11 +65,13 @@ class CSocket: self.__csock = None @staticmethod - def receive(sock): + def receive(sock, nonblocking=False, timeout=None): msg = CSPROTO.EMPTY + if nonblocking: sock.setblocking(0) + if timeout: sock.settimeout(timeout) while msg.rfind(CSPROTO.END) == -1: - chunk = sock.recv(6) - if chunk == '': + chunk = sock.recv(512) + if chunk in ('', b''): # python 3.x may return b'' instead of '' raise RuntimeError("socket connection broken") msg = msg + chunk return loads(msg) diff --git a/fail2ban/server/asyncserver.py b/fail2ban/server/asyncserver.py index e18451a4..47f5c27a 100644 --- a/fail2ban/server/asyncserver.py +++ b/fail2ban/server/asyncserver.py @@ -95,13 +95,21 @@ class RequestHandler(asynchat.async_chat): message = dumps("ERROR: %s" % e, HIGHEST_PROTOCOL) self.push(message + CSPROTO.END) - + ## + # Handles an communication errors in request. + # def handle_error(self): - e1, e2 = formatExceptionInfo() - logSys.error("Unexpected communication error: %s" % str(e2)) - logSys.error(traceback.format_exc().splitlines()) - self.close() - + try: + e1, e2 = formatExceptionInfo() + logSys.error("Unexpected communication error: %s" % str(e2)) + logSys.error(traceback.format_exc().splitlines()) + # Sends the response to the client. + message = dumps("ERROR: %s" % e2, HIGHEST_PROTOCOL) + self.push(message + CSPROTO.END) + except Exception as e: # pragma: no cover - normally unreachable + pass + self.close_when_done() + def loop(active, timeout=None, use_poll=False): """Custom event loop implementation @@ -125,7 +133,7 @@ def loop(active, timeout=None, use_poll=False): poll(timeout) if errCount: errCount -= 1 - except Exception as e: # pragma: no cover + except Exception as e: if not active(): break errCount += 1 @@ -181,7 +189,7 @@ class AsyncServer(asyncore.dispatcher): # @param sock: socket file. # @param force: remove the socket file if exists. - def start(self, sock, force, use_poll=False): + def start(self, sock, force, timeout=None, use_poll=False): self.__worker = threading.current_thread() self.__sock = sock # Remove socket @@ -207,12 +215,11 @@ class AsyncServer(asyncore.dispatcher): if self.onstart: self.onstart() # Event loop as long as active: - loop(lambda: self.__loop, use_poll=use_poll) + loop(lambda: self.__loop, timeout=timeout, use_poll=use_poll) self.__active = False # Cleanup all self.stop() - def close(self): stopflg = False if self.__active: diff --git a/fail2ban/tests/sockettestcase.py b/fail2ban/tests/sockettestcase.py index f6d86582..9d9cba15 100644 --- a/fail2ban/tests/sockettestcase.py +++ b/fail2ban/tests/sockettestcase.py @@ -32,7 +32,7 @@ import time import unittest from .. import protocol -from ..server.asyncserver import RequestHandler, AsyncServer, AsyncServerException +from ..server.asyncserver import asyncore, RequestHandler, AsyncServer, AsyncServerException from ..server.utils import Utils from ..client.csocket import CSocket @@ -42,8 +42,10 @@ from .utils import LogCaptureTestCase def TestMsgError(*args): raise Exception('test unpickle error') class TestMsg(object): + def __init__(self, unpickle=(TestMsgError, ())): + self.unpickle = unpickle def __reduce__(self): - return (TestMsgError, ()) + return self.unpickle class Socket(LogCaptureTestCase): @@ -63,7 +65,7 @@ class Socket(LogCaptureTestCase): """Call after every test case.""" if self.serverThread: self.server.stop(); # stop if not already stopped - self.serverThread.join() + self._stopServerThread() LogCaptureTestCase.tearDown(self) @staticmethod @@ -71,14 +73,21 @@ class Socket(LogCaptureTestCase): """Test transmitter proceed method which just returns first arg""" return message - def _createServerThread(self): + def _createServerThread(self, force=False): # start in separate thread : self.serverThread = serverThread = threading.Thread( - target=self.server.start, args=(self.sock_name, False)) + target=self.server.start, args=(self.sock_name, force)) serverThread.daemon = True serverThread.start() self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) return serverThread + + def _stopServerThread(self): + serverThread = self.serverThread + # wait for end of thread : + Utils.wait_for(lambda: not serverThread.isAlive() + or serverThread.join(Utils.DEFAULT_SLEEP_TIME), unittest.F2B.maxWaitTime(10)) + self.serverThread = None def testStopPerCloseUnexpected(self): # start in separate thread : @@ -86,8 +95,7 @@ class Socket(LogCaptureTestCase): # unexpected stop directly after start: self.server.close() # wait for end of thread : - Utils.wait_for(lambda: not serverThread.isAlive() - or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self._stopServerThread() self.assertFalse(serverThread.isAlive()) # clean : self.server.stop() @@ -128,17 +136,30 @@ class Socket(LogCaptureTestCase): self.server.stop() # wait for end of thread : - Utils.wait_for(lambda: not serverThread.isAlive() - or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self._stopServerThread() self.assertFalse(serverThread.isAlive()) self.assertFalse(self.server.isActive()) self.assertFalse(os.path.exists(self.sock_name)) - def testSocketConnectBroken(self): # start in separate thread : serverThread = self._createServerThread() client = Utils.wait_for(self._serverSocket, 2) + # unexpected stop during message body: + testMessage = ["A", "test", "message", [protocol.CSPROTO.END]] + + org_handler = RequestHandler.found_terminator + try: + RequestHandler.found_terminator = lambda self: self.close() + self.assertRaisesRegexp(RuntimeError, r"socket connection broken", + lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10))) + finally: + RequestHandler.found_terminator = org_handler + + def testStopByCommunicate(self): + # start in separate thread : + serverThread = self._createServerThread() + client = Utils.wait_for(self._serverSocket, 2) testMessage = ["A", "test", "message"] self.assertEqual(client.send(testMessage), testMessage) @@ -146,12 +167,40 @@ class Socket(LogCaptureTestCase): org_handler = RequestHandler.found_terminator try: RequestHandler.found_terminator = lambda self: TestMsgError() - self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage) + #self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage) + self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error') finally: RequestHandler.found_terminator = org_handler + # check errors were logged: self.assertLogged("Unexpected communication error", "test unpickle error", all=True) + self.server.stop() + # wait for end of thread : + self._stopServerThread() + self.assertFalse(serverThread.isAlive()) + + def testLoopErrors(self): + # replace poll handler to produce error in loop-cycle: + org_poll = asyncore.poll + err = {'cntr': 0} + def _produce_error(*args): + err['cntr'] += 1 + if err['cntr'] < 50: + raise RuntimeError('test errors in poll') + return org_poll(*args) + + try: + asyncore.poll = _produce_error + serverThread = self._createServerThread() + # wait all-cases processed: + self.assertTrue(Utils.wait_for(lambda: err['cntr'] > 50, unittest.F2B.maxWaitTime(10))) + finally: + # restore: + asyncore.poll = org_poll + # check errors were logged: + self.assertLogged("Server connection was closed: test errors in poll", + "Too many errors - stop logging connection errors", all=True) def testSocketForce(self): open(self.sock_name, 'w').close() # Create sock file @@ -160,16 +209,12 @@ class Socket(LogCaptureTestCase): AsyncServerException, self.server.start, self.sock_name, False) # Try again with force set - serverThread = threading.Thread( - target=self.server.start, args=(self.sock_name, True)) - serverThread.daemon = True - serverThread.start() - self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) + serverThread = self._createServerThread(True) self.server.stop() # wait for end of thread : - Utils.wait_for(lambda: not serverThread.isAlive() - or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self._stopServerThread() + self.assertFalse(serverThread.isAlive()) self.assertFalse(self.server.isActive()) self.assertFalse(os.path.exists(self.sock_name))