mirror of https://github.com/fail2ban/fail2ban
small fix (missing import of logging) + test coverage
parent
28b5262976
commit
9129a414e3
|
@ -37,7 +37,7 @@ import traceback
|
|||
|
||||
from .utils import Utils
|
||||
from ..protocol import CSPROTO
|
||||
from ..helpers import getLogger,formatExceptionInfo
|
||||
from ..helpers import logging, getLogger, formatExceptionInfo
|
||||
|
||||
# Gets the instance of the logger.
|
||||
logSys = getLogger(__name__)
|
||||
|
@ -88,9 +88,12 @@ class RequestHandler(asynchat.async_chat):
|
|||
message = dumps(message, HIGHEST_PROTOCOL)
|
||||
# Sends the response to the client.
|
||||
self.push(message + CSPROTO.END)
|
||||
except Exception as e: # pragma: no cover
|
||||
except Exception as e:
|
||||
logSys.error("Caught unhandled exception: %r", e,
|
||||
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
|
||||
# Sends the response to the client.
|
||||
message = dumps("ERROR: %s" % e, HIGHEST_PROTOCOL)
|
||||
self.push(message + CSPROTO.END)
|
||||
|
||||
|
||||
def handle_error(self):
|
||||
|
@ -161,10 +164,10 @@ class AsyncServer(asyncore.dispatcher):
|
|||
def handle_accept(self):
|
||||
try:
|
||||
conn, addr = self.accept()
|
||||
except socket.error:
|
||||
except socket.error: # pragma: no cover
|
||||
logSys.warning("Socket error")
|
||||
return
|
||||
except TypeError:
|
||||
except TypeError: # pragma: no cover
|
||||
logSys.warning("Type error")
|
||||
return
|
||||
AsyncServer.__markCloseOnExec(conn)
|
||||
|
@ -194,7 +197,7 @@ class AsyncServer(asyncore.dispatcher):
|
|||
self.set_reuse_addr()
|
||||
try:
|
||||
self.bind(sock)
|
||||
except Exception:
|
||||
except Exception: # pragma: no cover
|
||||
raise AsyncServerException("Unable to bind socket %s" % self.__sock)
|
||||
AsyncServer.__markCloseOnExec(self.socket)
|
||||
self.listen(1)
|
||||
|
|
|
@ -32,37 +32,57 @@ import time
|
|||
import unittest
|
||||
|
||||
from .. import protocol
|
||||
from ..server.asyncserver import AsyncServer, AsyncServerException
|
||||
from ..server.asyncserver import RequestHandler, AsyncServer, AsyncServerException
|
||||
from ..server.utils import Utils
|
||||
from ..client.csocket import CSocket
|
||||
|
||||
from .utils import LogCaptureTestCase
|
||||
|
||||
class Socket(unittest.TestCase):
|
||||
|
||||
def TestMsgError(*args):
|
||||
raise Exception('test unpickle error')
|
||||
class TestMsg(object):
|
||||
def __reduce__(self):
|
||||
return (TestMsgError, ())
|
||||
|
||||
|
||||
class Socket(LogCaptureTestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Call before every test case."""
|
||||
LogCaptureTestCase.setUp(self)
|
||||
super(Socket, self).setUp()
|
||||
self.server = AsyncServer(self)
|
||||
sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'socket')
|
||||
os.close(sock_fd)
|
||||
os.remove(sock_name)
|
||||
self.sock_name = sock_name
|
||||
self.serverThread = None
|
||||
|
||||
def tearDown(self):
|
||||
"""Call after every test case."""
|
||||
if self.serverThread:
|
||||
self.server.stop(); # stop if not already stopped
|
||||
self.serverThread.join()
|
||||
LogCaptureTestCase.tearDown(self)
|
||||
|
||||
@staticmethod
|
||||
def proceed(message):
|
||||
"""Test transmitter proceed method which just returns first arg"""
|
||||
return message
|
||||
|
||||
def testStopPerCloseUnexpected(self):
|
||||
def _createServerThread(self):
|
||||
# start in separate thread :
|
||||
serverThread = threading.Thread(
|
||||
self.serverThread = serverThread = threading.Thread(
|
||||
target=self.server.start, args=(self.sock_name, False))
|
||||
serverThread.daemon = True
|
||||
serverThread.start()
|
||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||
return serverThread
|
||||
|
||||
def testStopPerCloseUnexpected(self):
|
||||
# start in separate thread :
|
||||
serverThread = self._createServerThread()
|
||||
# unexpected stop directly after start:
|
||||
self.server.close()
|
||||
# wait for end of thread :
|
||||
|
@ -81,22 +101,31 @@ class Socket(unittest.TestCase):
|
|||
return None
|
||||
|
||||
def testSocket(self):
|
||||
serverThread = threading.Thread(
|
||||
target=self.server.start, args=(self.sock_name, False))
|
||||
serverThread.daemon = True
|
||||
serverThread.start()
|
||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||
time.sleep(Utils.DEFAULT_SLEEP_TIME)
|
||||
|
||||
# start in separate thread :
|
||||
serverThread = self._createServerThread()
|
||||
client = Utils.wait_for(self._serverSocket, 2)
|
||||
|
||||
testMessage = ["A", "test", "message"]
|
||||
self.assertEqual(client.send(testMessage), testMessage)
|
||||
|
||||
# test wrong message:
|
||||
self.assertEqual(client.send([[TestMsg()]]), 'ERROR: test unpickle error')
|
||||
self.assertLogged("Caught unhandled exception", "test unpickle error", all=True)
|
||||
|
||||
# test good message again:
|
||||
self.assertEqual(client.send(testMessage), testMessage)
|
||||
|
||||
# test close message
|
||||
client.close()
|
||||
# 2nd close does nothing
|
||||
client.close()
|
||||
|
||||
# force shutdown:
|
||||
self.server.stop_communication()
|
||||
# test send again (should get in shutdown message):
|
||||
client = Utils.wait_for(self._serverSocket, 2)
|
||||
self.assertEqual(client.send(testMessage), ['SHUTDOWN'])
|
||||
|
||||
self.server.stop()
|
||||
# wait for end of thread :
|
||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||
|
@ -105,6 +134,25 @@ class Socket(unittest.TestCase):
|
|||
self.assertFalse(self.server.isActive())
|
||||
self.assertFalse(os.path.exists(self.sock_name))
|
||||
|
||||
|
||||
def testSocketConnectBroken(self):
|
||||
# start in separate thread :
|
||||
serverThread = self._createServerThread()
|
||||
client = Utils.wait_for(self._serverSocket, 2)
|
||||
|
||||
testMessage = ["A", "test", "message"]
|
||||
self.assertEqual(client.send(testMessage), testMessage)
|
||||
|
||||
org_handler = RequestHandler.found_terminator
|
||||
try:
|
||||
RequestHandler.found_terminator = lambda self: TestMsgError()
|
||||
self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage)
|
||||
finally:
|
||||
RequestHandler.found_terminator = org_handler
|
||||
|
||||
self.assertLogged("Unexpected communication error", "test unpickle error", all=True)
|
||||
self.server.stop()
|
||||
|
||||
def testSocketForce(self):
|
||||
open(self.sock_name, 'w').close() # Create sock file
|
||||
# Try to start without force
|
||||
|
|
Loading…
Reference in New Issue