mirror of https://github.com/fail2ban/fail2ban
socket stability and coverage: cherry picked from 0.11 version (avoid many sporadic unhandled exceptions)
parent
5f021aa648
commit
29bedd70d5
|
@ -45,13 +45,13 @@ class CSocket:
|
|||
def __del__(self):
|
||||
self.close(False)
|
||||
|
||||
def send(self, msg):
|
||||
def send(self, msg, nonblocking=False, timeout=None):
|
||||
# Convert every list member to string
|
||||
obj = dumps(map(
|
||||
lambda m: str(m) if not isinstance(m, (list, dict, set)) else m, msg),
|
||||
HIGHEST_PROTOCOL)
|
||||
self.__csock.send(obj + CSPROTO.END)
|
||||
return self.receive(self.__csock)
|
||||
return self.receive(self.__csock, nonblocking, timeout)
|
||||
|
||||
def settimeout(self, timeout):
|
||||
self.__csock.settimeout(timeout if timeout != -1 else self.__deftout)
|
||||
|
@ -66,11 +66,13 @@ class CSocket:
|
|||
self.__csock = None
|
||||
|
||||
@staticmethod
|
||||
def receive(sock):
|
||||
def receive(sock, nonblocking=False, timeout=None):
|
||||
msg = CSPROTO.EMPTY
|
||||
if nonblocking: sock.setblocking(0)
|
||||
if timeout: sock.settimeout(timeout)
|
||||
while msg.rfind(CSPROTO.END) == -1:
|
||||
chunk = sock.recv(6)
|
||||
if chunk == '':
|
||||
chunk = sock.recv(512)
|
||||
if chunk in ('', b''): # python 3.x may return b'' instead of ''
|
||||
raise RuntimeError("socket connection broken")
|
||||
msg = msg + chunk
|
||||
return loads(msg)
|
||||
|
|
|
@ -34,42 +34,70 @@ import unittest
|
|||
from .utils import LogCaptureTestCase
|
||||
|
||||
from .. import protocol
|
||||
from ..server.asyncserver import AsyncServer, AsyncServerException, loop
|
||||
from ..server.asyncserver import asyncore, RequestHandler, loop, AsyncServer, AsyncServerException
|
||||
from ..server.utils import Utils
|
||||
from ..client.csocket import CSocket
|
||||
|
||||
from .utils import LogCaptureTestCase
|
||||
|
||||
class Socket(unittest.TestCase):
|
||||
|
||||
def TestMsgError(*args):
|
||||
raise Exception('test unpickle error')
|
||||
class TestMsg(object):
|
||||
def __init__(self, unpickle=(TestMsgError, ())):
|
||||
self.unpickle = unpickle
|
||||
def __reduce__(self):
|
||||
return self.unpickle
|
||||
|
||||
|
||||
class Socket(LogCaptureTestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Call before every test case."""
|
||||
LogCaptureTestCase.setUp(self)
|
||||
super(Socket, self).setUp()
|
||||
self.server = AsyncServer(self)
|
||||
sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'socket')
|
||||
os.close(sock_fd)
|
||||
os.remove(sock_name)
|
||||
self.sock_name = sock_name
|
||||
self.serverThread = None
|
||||
|
||||
def tearDown(self):
|
||||
"""Call after every test case."""
|
||||
if self.serverThread:
|
||||
self.server.stop(); # stop if not already stopped
|
||||
self._stopServerThread()
|
||||
LogCaptureTestCase.tearDown(self)
|
||||
|
||||
@staticmethod
|
||||
def proceed(message):
|
||||
"""Test transmitter proceed method which just returns first arg"""
|
||||
return message
|
||||
|
||||
def testStopPerCloseUnexpected(self):
|
||||
def _createServerThread(self, force=False):
|
||||
# start in separate thread :
|
||||
serverThread = threading.Thread(
|
||||
target=self.server.start, args=(self.sock_name, False))
|
||||
self.serverThread = serverThread = threading.Thread(
|
||||
target=self.server.start, args=(self.sock_name, force))
|
||||
serverThread.daemon = True
|
||||
serverThread.start()
|
||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||
return serverThread
|
||||
|
||||
def _stopServerThread(self):
|
||||
serverThread = self.serverThread
|
||||
# wait for end of thread :
|
||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||
or serverThread.join(Utils.DEFAULT_SLEEP_TIME), unittest.F2B.maxWaitTime(10))
|
||||
self.serverThread = None
|
||||
|
||||
def testStopPerCloseUnexpected(self):
|
||||
# start in separate thread :
|
||||
serverThread = self._createServerThread()
|
||||
# unexpected stop directly after start:
|
||||
self.server.close()
|
||||
# wait for end of thread :
|
||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
||||
self._stopServerThread()
|
||||
self.assertFalse(serverThread.isAlive())
|
||||
# clean :
|
||||
self.server.stop()
|
||||
|
@ -83,30 +111,99 @@ class Socket(unittest.TestCase):
|
|||
return None
|
||||
|
||||
def testSocket(self):
|
||||
serverThread = threading.Thread(
|
||||
target=self.server.start, args=(self.sock_name, False))
|
||||
serverThread.daemon = True
|
||||
serverThread.start()
|
||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||
time.sleep(Utils.DEFAULT_SLEEP_TIME)
|
||||
|
||||
# start in separate thread :
|
||||
serverThread = self._createServerThread()
|
||||
client = Utils.wait_for(self._serverSocket, 2)
|
||||
|
||||
testMessage = ["A", "test", "message"]
|
||||
self.assertEqual(client.send(testMessage), testMessage)
|
||||
|
||||
# test wrong message:
|
||||
self.assertEqual(client.send([[TestMsg()]]), 'ERROR: test unpickle error')
|
||||
self.assertLogged("Caught unhandled exception", "test unpickle error", all=True)
|
||||
|
||||
# test good message again:
|
||||
self.assertEqual(client.send(testMessage), testMessage)
|
||||
|
||||
# test close message
|
||||
client.close()
|
||||
# 2nd close does nothing
|
||||
client.close()
|
||||
|
||||
# force shutdown:
|
||||
self.server.stop_communication()
|
||||
# test send again (should get in shutdown message):
|
||||
client = Utils.wait_for(self._serverSocket, 2)
|
||||
self.assertEqual(client.send(testMessage), ['SHUTDOWN'])
|
||||
|
||||
self.server.stop()
|
||||
# wait for end of thread :
|
||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
||||
self._stopServerThread()
|
||||
self.assertFalse(serverThread.isAlive())
|
||||
self.assertFalse(self.server.isActive())
|
||||
self.assertFalse(os.path.exists(self.sock_name))
|
||||
|
||||
def testSocketConnectBroken(self):
|
||||
# start in separate thread :
|
||||
serverThread = self._createServerThread()
|
||||
client = Utils.wait_for(self._serverSocket, 2)
|
||||
# unexpected stop during message body:
|
||||
testMessage = ["A", "test", "message", [protocol.CSPROTO.END]]
|
||||
|
||||
org_handler = RequestHandler.found_terminator
|
||||
try:
|
||||
RequestHandler.found_terminator = lambda self: self.close()
|
||||
self.assertRaisesRegexp(RuntimeError, r"socket connection broken",
|
||||
lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10)))
|
||||
finally:
|
||||
RequestHandler.found_terminator = org_handler
|
||||
|
||||
def testStopByCommunicate(self):
|
||||
# start in separate thread :
|
||||
serverThread = self._createServerThread()
|
||||
client = Utils.wait_for(self._serverSocket, 2)
|
||||
|
||||
testMessage = ["A", "test", "message"]
|
||||
self.assertEqual(client.send(testMessage), testMessage)
|
||||
|
||||
org_handler = RequestHandler.found_terminator
|
||||
try:
|
||||
RequestHandler.found_terminator = lambda self: TestMsgError()
|
||||
#self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage)
|
||||
self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error')
|
||||
finally:
|
||||
RequestHandler.found_terminator = org_handler
|
||||
|
||||
# check errors were logged:
|
||||
self.assertLogged("Unexpected communication error", "test unpickle error", all=True)
|
||||
|
||||
self.server.stop()
|
||||
# wait for end of thread :
|
||||
self._stopServerThread()
|
||||
self.assertFalse(serverThread.isAlive())
|
||||
|
||||
def testLoopErrors(self):
|
||||
# replace poll handler to produce error in loop-cycle:
|
||||
org_poll = asyncore.poll
|
||||
err = {'cntr': 0}
|
||||
def _produce_error(*args):
|
||||
err['cntr'] += 1
|
||||
if err['cntr'] < 50:
|
||||
raise RuntimeError('test errors in poll')
|
||||
return org_poll(*args)
|
||||
|
||||
try:
|
||||
asyncore.poll = _produce_error
|
||||
serverThread = self._createServerThread()
|
||||
# wait all-cases processed:
|
||||
self.assertTrue(Utils.wait_for(lambda: err['cntr'] > 50, unittest.F2B.maxWaitTime(10)))
|
||||
finally:
|
||||
# restore:
|
||||
asyncore.poll = org_poll
|
||||
# check errors were logged:
|
||||
self.assertLogged("Server connection was closed: test errors in poll",
|
||||
"Too many errors - stop logging connection errors", all=True)
|
||||
|
||||
def testSocketForce(self):
|
||||
open(self.sock_name, 'w').close() # Create sock file
|
||||
# Try to start without force
|
||||
|
@ -114,16 +211,12 @@ class Socket(unittest.TestCase):
|
|||
AsyncServerException, self.server.start, self.sock_name, False)
|
||||
|
||||
# Try again with force set
|
||||
serverThread = threading.Thread(
|
||||
target=self.server.start, args=(self.sock_name, True))
|
||||
serverThread.daemon = True
|
||||
serverThread.start()
|
||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||
serverThread = self._createServerThread(True)
|
||||
|
||||
self.server.stop()
|
||||
# wait for end of thread :
|
||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
||||
self._stopServerThread()
|
||||
self.assertFalse(serverThread.isAlive())
|
||||
self.assertFalse(self.server.isActive())
|
||||
self.assertFalse(os.path.exists(self.sock_name))
|
||||
|
||||
|
|
Loading…
Reference in New Issue