socket, asyncserver: several fixes, python version dependency removed + test coverage extended

pull/1460/head
sebres 2017-02-28 19:52:44 +01:00
parent 9129a414e3
commit 4b53c6b975
3 changed files with 87 additions and 33 deletions

View File

@ -45,13 +45,13 @@ class CSocket:
def __del__(self): def __del__(self):
self.close(False) self.close(False)
def send(self, msg): def send(self, msg, nonblocking=False, timeout=None):
# Convert every list member to string # Convert every list member to string
obj = dumps(map( obj = dumps(map(
lambda m: str(m) if not isinstance(m, (list, dict, set)) else m, msg), lambda m: str(m) if not isinstance(m, (list, dict, set)) else m, msg),
HIGHEST_PROTOCOL) HIGHEST_PROTOCOL)
self.__csock.send(obj + CSPROTO.END) self.__csock.send(obj + CSPROTO.END)
return self.receive(self.__csock) return self.receive(self.__csock, nonblocking, timeout)
def settimeout(self, timeout): def settimeout(self, timeout):
self.__csock.settimeout(timeout if timeout != -1 else self.__deftout) self.__csock.settimeout(timeout if timeout != -1 else self.__deftout)
@ -65,11 +65,13 @@ class CSocket:
self.__csock = None self.__csock = None
@staticmethod @staticmethod
def receive(sock): def receive(sock, nonblocking=False, timeout=None):
msg = CSPROTO.EMPTY msg = CSPROTO.EMPTY
if nonblocking: sock.setblocking(0)
if timeout: sock.settimeout(timeout)
while msg.rfind(CSPROTO.END) == -1: while msg.rfind(CSPROTO.END) == -1:
chunk = sock.recv(6) chunk = sock.recv(512)
if chunk == '': if chunk in ('', b''): # python 3.x may return b'' instead of ''
raise RuntimeError("socket connection broken") raise RuntimeError("socket connection broken")
msg = msg + chunk msg = msg + chunk
return loads(msg) return loads(msg)

View File

@ -95,13 +95,21 @@ class RequestHandler(asynchat.async_chat):
message = dumps("ERROR: %s" % e, HIGHEST_PROTOCOL) message = dumps("ERROR: %s" % e, HIGHEST_PROTOCOL)
self.push(message + CSPROTO.END) self.push(message + CSPROTO.END)
##
# Handles an communication errors in request.
#
def handle_error(self): def handle_error(self):
e1, e2 = formatExceptionInfo() try:
logSys.error("Unexpected communication error: %s" % str(e2)) e1, e2 = formatExceptionInfo()
logSys.error(traceback.format_exc().splitlines()) logSys.error("Unexpected communication error: %s" % str(e2))
self.close() logSys.error(traceback.format_exc().splitlines())
# Sends the response to the client.
message = dumps("ERROR: %s" % e2, HIGHEST_PROTOCOL)
self.push(message + CSPROTO.END)
except Exception as e: # pragma: no cover - normally unreachable
pass
self.close_when_done()
def loop(active, timeout=None, use_poll=False): def loop(active, timeout=None, use_poll=False):
"""Custom event loop implementation """Custom event loop implementation
@ -125,7 +133,7 @@ def loop(active, timeout=None, use_poll=False):
poll(timeout) poll(timeout)
if errCount: if errCount:
errCount -= 1 errCount -= 1
except Exception as e: # pragma: no cover except Exception as e:
if not active(): if not active():
break break
errCount += 1 errCount += 1
@ -181,7 +189,7 @@ class AsyncServer(asyncore.dispatcher):
# @param sock: socket file. # @param sock: socket file.
# @param force: remove the socket file if exists. # @param force: remove the socket file if exists.
def start(self, sock, force, use_poll=False): def start(self, sock, force, timeout=None, use_poll=False):
self.__worker = threading.current_thread() self.__worker = threading.current_thread()
self.__sock = sock self.__sock = sock
# Remove socket # Remove socket
@ -207,12 +215,11 @@ class AsyncServer(asyncore.dispatcher):
if self.onstart: if self.onstart:
self.onstart() self.onstart()
# Event loop as long as active: # Event loop as long as active:
loop(lambda: self.__loop, use_poll=use_poll) loop(lambda: self.__loop, timeout=timeout, use_poll=use_poll)
self.__active = False self.__active = False
# Cleanup all # Cleanup all
self.stop() self.stop()
def close(self): def close(self):
stopflg = False stopflg = False
if self.__active: if self.__active:

View File

@ -32,7 +32,7 @@ import time
import unittest import unittest
from .. import protocol from .. import protocol
from ..server.asyncserver import RequestHandler, AsyncServer, AsyncServerException from ..server.asyncserver import asyncore, RequestHandler, AsyncServer, AsyncServerException
from ..server.utils import Utils from ..server.utils import Utils
from ..client.csocket import CSocket from ..client.csocket import CSocket
@ -42,8 +42,10 @@ from .utils import LogCaptureTestCase
def TestMsgError(*args): def TestMsgError(*args):
raise Exception('test unpickle error') raise Exception('test unpickle error')
class TestMsg(object): class TestMsg(object):
def __init__(self, unpickle=(TestMsgError, ())):
self.unpickle = unpickle
def __reduce__(self): def __reduce__(self):
return (TestMsgError, ()) return self.unpickle
class Socket(LogCaptureTestCase): class Socket(LogCaptureTestCase):
@ -63,7 +65,7 @@ class Socket(LogCaptureTestCase):
"""Call after every test case.""" """Call after every test case."""
if self.serverThread: if self.serverThread:
self.server.stop(); # stop if not already stopped self.server.stop(); # stop if not already stopped
self.serverThread.join() self._stopServerThread()
LogCaptureTestCase.tearDown(self) LogCaptureTestCase.tearDown(self)
@staticmethod @staticmethod
@ -71,14 +73,21 @@ class Socket(LogCaptureTestCase):
"""Test transmitter proceed method which just returns first arg""" """Test transmitter proceed method which just returns first arg"""
return message return message
def _createServerThread(self): def _createServerThread(self, force=False):
# start in separate thread : # start in separate thread :
self.serverThread = serverThread = threading.Thread( self.serverThread = serverThread = threading.Thread(
target=self.server.start, args=(self.sock_name, False)) target=self.server.start, args=(self.sock_name, force))
serverThread.daemon = True serverThread.daemon = True
serverThread.start() serverThread.start()
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
return serverThread return serverThread
def _stopServerThread(self):
serverThread = self.serverThread
# wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive()
or serverThread.join(Utils.DEFAULT_SLEEP_TIME), unittest.F2B.maxWaitTime(10))
self.serverThread = None
def testStopPerCloseUnexpected(self): def testStopPerCloseUnexpected(self):
# start in separate thread : # start in separate thread :
@ -86,8 +95,7 @@ class Socket(LogCaptureTestCase):
# unexpected stop directly after start: # unexpected stop directly after start:
self.server.close() self.server.close()
# wait for end of thread : # wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive() self._stopServerThread()
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
self.assertFalse(serverThread.isAlive()) self.assertFalse(serverThread.isAlive())
# clean : # clean :
self.server.stop() self.server.stop()
@ -128,17 +136,30 @@ class Socket(LogCaptureTestCase):
self.server.stop() self.server.stop()
# wait for end of thread : # wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive() self._stopServerThread()
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
self.assertFalse(serverThread.isAlive()) self.assertFalse(serverThread.isAlive())
self.assertFalse(self.server.isActive()) self.assertFalse(self.server.isActive())
self.assertFalse(os.path.exists(self.sock_name)) self.assertFalse(os.path.exists(self.sock_name))
def testSocketConnectBroken(self): def testSocketConnectBroken(self):
# start in separate thread : # start in separate thread :
serverThread = self._createServerThread() serverThread = self._createServerThread()
client = Utils.wait_for(self._serverSocket, 2) client = Utils.wait_for(self._serverSocket, 2)
# unexpected stop during message body:
testMessage = ["A", "test", "message", [protocol.CSPROTO.END]]
org_handler = RequestHandler.found_terminator
try:
RequestHandler.found_terminator = lambda self: self.close()
self.assertRaisesRegexp(RuntimeError, r"socket connection broken",
lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10)))
finally:
RequestHandler.found_terminator = org_handler
def testStopByCommunicate(self):
# start in separate thread :
serverThread = self._createServerThread()
client = Utils.wait_for(self._serverSocket, 2)
testMessage = ["A", "test", "message"] testMessage = ["A", "test", "message"]
self.assertEqual(client.send(testMessage), testMessage) self.assertEqual(client.send(testMessage), testMessage)
@ -146,12 +167,40 @@ class Socket(LogCaptureTestCase):
org_handler = RequestHandler.found_terminator org_handler = RequestHandler.found_terminator
try: try:
RequestHandler.found_terminator = lambda self: TestMsgError() RequestHandler.found_terminator = lambda self: TestMsgError()
self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage) #self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage)
self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error')
finally: finally:
RequestHandler.found_terminator = org_handler RequestHandler.found_terminator = org_handler
# check errors were logged:
self.assertLogged("Unexpected communication error", "test unpickle error", all=True) self.assertLogged("Unexpected communication error", "test unpickle error", all=True)
self.server.stop() self.server.stop()
# wait for end of thread :
self._stopServerThread()
self.assertFalse(serverThread.isAlive())
def testLoopErrors(self):
# replace poll handler to produce error in loop-cycle:
org_poll = asyncore.poll
err = {'cntr': 0}
def _produce_error(*args):
err['cntr'] += 1
if err['cntr'] < 50:
raise RuntimeError('test errors in poll')
return org_poll(*args)
try:
asyncore.poll = _produce_error
serverThread = self._createServerThread()
# wait all-cases processed:
self.assertTrue(Utils.wait_for(lambda: err['cntr'] > 50, unittest.F2B.maxWaitTime(10)))
finally:
# restore:
asyncore.poll = org_poll
# check errors were logged:
self.assertLogged("Server connection was closed: test errors in poll",
"Too many errors - stop logging connection errors", all=True)
def testSocketForce(self): def testSocketForce(self):
open(self.sock_name, 'w').close() # Create sock file open(self.sock_name, 'w').close() # Create sock file
@ -160,16 +209,12 @@ class Socket(LogCaptureTestCase):
AsyncServerException, self.server.start, self.sock_name, False) AsyncServerException, self.server.start, self.sock_name, False)
# Try again with force set # Try again with force set
serverThread = threading.Thread( serverThread = self._createServerThread(True)
target=self.server.start, args=(self.sock_name, True))
serverThread.daemon = True
serverThread.start()
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
self.server.stop() self.server.stop()
# wait for end of thread : # wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive() self._stopServerThread()
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) self.assertFalse(serverThread.isAlive())
self.assertFalse(self.server.isActive()) self.assertFalse(self.server.isActive())
self.assertFalse(os.path.exists(self.sock_name)) self.assertFalse(os.path.exists(self.sock_name))