diff --git a/fail2ban/server/asyncserver.py b/fail2ban/server/asyncserver.py index ea6cc326..e18451a4 100644 --- a/fail2ban/server/asyncserver.py +++ b/fail2ban/server/asyncserver.py @@ -37,7 +37,7 @@ import traceback from .utils import Utils from ..protocol import CSPROTO -from ..helpers import getLogger,formatExceptionInfo +from ..helpers import logging, getLogger, formatExceptionInfo # Gets the instance of the logger. logSys = getLogger(__name__) @@ -88,9 +88,12 @@ class RequestHandler(asynchat.async_chat): message = dumps(message, HIGHEST_PROTOCOL) # Sends the response to the client. self.push(message + CSPROTO.END) - except Exception as e: # pragma: no cover + except Exception as e: logSys.error("Caught unhandled exception: %r", e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) + # Sends the response to the client. + message = dumps("ERROR: %s" % e, HIGHEST_PROTOCOL) + self.push(message + CSPROTO.END) def handle_error(self): @@ -161,10 +164,10 @@ class AsyncServer(asyncore.dispatcher): def handle_accept(self): try: conn, addr = self.accept() - except socket.error: + except socket.error: # pragma: no cover logSys.warning("Socket error") return - except TypeError: + except TypeError: # pragma: no cover logSys.warning("Type error") return AsyncServer.__markCloseOnExec(conn) @@ -194,7 +197,7 @@ class AsyncServer(asyncore.dispatcher): self.set_reuse_addr() try: self.bind(sock) - except Exception: + except Exception: # pragma: no cover raise AsyncServerException("Unable to bind socket %s" % self.__sock) AsyncServer.__markCloseOnExec(self.socket) self.listen(1) diff --git a/fail2ban/tests/sockettestcase.py b/fail2ban/tests/sockettestcase.py index 1a94a952..f6d86582 100644 --- a/fail2ban/tests/sockettestcase.py +++ b/fail2ban/tests/sockettestcase.py @@ -32,37 +32,57 @@ import time import unittest from .. import protocol -from ..server.asyncserver import AsyncServer, AsyncServerException +from ..server.asyncserver import RequestHandler, AsyncServer, AsyncServerException from ..server.utils import Utils from ..client.csocket import CSocket +from .utils import LogCaptureTestCase -class Socket(unittest.TestCase): + +def TestMsgError(*args): + raise Exception('test unpickle error') +class TestMsg(object): + def __reduce__(self): + return (TestMsgError, ()) + + +class Socket(LogCaptureTestCase): def setUp(self): """Call before every test case.""" + LogCaptureTestCase.setUp(self) super(Socket, self).setUp() self.server = AsyncServer(self) sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'socket') os.close(sock_fd) os.remove(sock_name) self.sock_name = sock_name + self.serverThread = None def tearDown(self): """Call after every test case.""" + if self.serverThread: + self.server.stop(); # stop if not already stopped + self.serverThread.join() + LogCaptureTestCase.tearDown(self) @staticmethod def proceed(message): """Test transmitter proceed method which just returns first arg""" return message - def testStopPerCloseUnexpected(self): + def _createServerThread(self): # start in separate thread : - serverThread = threading.Thread( + self.serverThread = serverThread = threading.Thread( target=self.server.start, args=(self.sock_name, False)) serverThread.daemon = True serverThread.start() self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) + return serverThread + + def testStopPerCloseUnexpected(self): + # start in separate thread : + serverThread = self._createServerThread() # unexpected stop directly after start: self.server.close() # wait for end of thread : @@ -81,22 +101,31 @@ class Socket(unittest.TestCase): return None def testSocket(self): - serverThread = threading.Thread( - target=self.server.start, args=(self.sock_name, False)) - serverThread.daemon = True - serverThread.start() - self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) - time.sleep(Utils.DEFAULT_SLEEP_TIME) - + # start in separate thread : + serverThread = self._createServerThread() client = Utils.wait_for(self._serverSocket, 2) + testMessage = ["A", "test", "message"] self.assertEqual(client.send(testMessage), testMessage) + # test wrong message: + self.assertEqual(client.send([[TestMsg()]]), 'ERROR: test unpickle error') + self.assertLogged("Caught unhandled exception", "test unpickle error", all=True) + + # test good message again: + self.assertEqual(client.send(testMessage), testMessage) + # test close message client.close() # 2nd close does nothing client.close() + # force shutdown: + self.server.stop_communication() + # test send again (should get in shutdown message): + client = Utils.wait_for(self._serverSocket, 2) + self.assertEqual(client.send(testMessage), ['SHUTDOWN']) + self.server.stop() # wait for end of thread : Utils.wait_for(lambda: not serverThread.isAlive() @@ -105,6 +134,25 @@ class Socket(unittest.TestCase): self.assertFalse(self.server.isActive()) self.assertFalse(os.path.exists(self.sock_name)) + + def testSocketConnectBroken(self): + # start in separate thread : + serverThread = self._createServerThread() + client = Utils.wait_for(self._serverSocket, 2) + + testMessage = ["A", "test", "message"] + self.assertEqual(client.send(testMessage), testMessage) + + org_handler = RequestHandler.found_terminator + try: + RequestHandler.found_terminator = lambda self: TestMsgError() + self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage) + finally: + RequestHandler.found_terminator = org_handler + + self.assertLogged("Unexpected communication error", "test unpickle error", all=True) + self.server.stop() + def testSocketForce(self): open(self.sock_name, 'w').close() # Create sock file # Try to start without force