mirror of https://github.com/fail2ban/fail2ban
socket, asyncserver: several fixes, python version dependency removed + test coverage extended
parent
9129a414e3
commit
4b53c6b975
|
@ -45,13 +45,13 @@ class CSocket:
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.close(False)
|
self.close(False)
|
||||||
|
|
||||||
def send(self, msg):
|
def send(self, msg, nonblocking=False, timeout=None):
|
||||||
# Convert every list member to string
|
# Convert every list member to string
|
||||||
obj = dumps(map(
|
obj = dumps(map(
|
||||||
lambda m: str(m) if not isinstance(m, (list, dict, set)) else m, msg),
|
lambda m: str(m) if not isinstance(m, (list, dict, set)) else m, msg),
|
||||||
HIGHEST_PROTOCOL)
|
HIGHEST_PROTOCOL)
|
||||||
self.__csock.send(obj + CSPROTO.END)
|
self.__csock.send(obj + CSPROTO.END)
|
||||||
return self.receive(self.__csock)
|
return self.receive(self.__csock, nonblocking, timeout)
|
||||||
|
|
||||||
def settimeout(self, timeout):
|
def settimeout(self, timeout):
|
||||||
self.__csock.settimeout(timeout if timeout != -1 else self.__deftout)
|
self.__csock.settimeout(timeout if timeout != -1 else self.__deftout)
|
||||||
|
@ -65,11 +65,13 @@ class CSocket:
|
||||||
self.__csock = None
|
self.__csock = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def receive(sock):
|
def receive(sock, nonblocking=False, timeout=None):
|
||||||
msg = CSPROTO.EMPTY
|
msg = CSPROTO.EMPTY
|
||||||
|
if nonblocking: sock.setblocking(0)
|
||||||
|
if timeout: sock.settimeout(timeout)
|
||||||
while msg.rfind(CSPROTO.END) == -1:
|
while msg.rfind(CSPROTO.END) == -1:
|
||||||
chunk = sock.recv(6)
|
chunk = sock.recv(512)
|
||||||
if chunk == '':
|
if chunk in ('', b''): # python 3.x may return b'' instead of ''
|
||||||
raise RuntimeError("socket connection broken")
|
raise RuntimeError("socket connection broken")
|
||||||
msg = msg + chunk
|
msg = msg + chunk
|
||||||
return loads(msg)
|
return loads(msg)
|
||||||
|
|
|
@ -95,12 +95,20 @@ class RequestHandler(asynchat.async_chat):
|
||||||
message = dumps("ERROR: %s" % e, HIGHEST_PROTOCOL)
|
message = dumps("ERROR: %s" % e, HIGHEST_PROTOCOL)
|
||||||
self.push(message + CSPROTO.END)
|
self.push(message + CSPROTO.END)
|
||||||
|
|
||||||
|
##
|
||||||
|
# Handles an communication errors in request.
|
||||||
|
#
|
||||||
def handle_error(self):
|
def handle_error(self):
|
||||||
e1, e2 = formatExceptionInfo()
|
try:
|
||||||
logSys.error("Unexpected communication error: %s" % str(e2))
|
e1, e2 = formatExceptionInfo()
|
||||||
logSys.error(traceback.format_exc().splitlines())
|
logSys.error("Unexpected communication error: %s" % str(e2))
|
||||||
self.close()
|
logSys.error(traceback.format_exc().splitlines())
|
||||||
|
# Sends the response to the client.
|
||||||
|
message = dumps("ERROR: %s" % e2, HIGHEST_PROTOCOL)
|
||||||
|
self.push(message + CSPROTO.END)
|
||||||
|
except Exception as e: # pragma: no cover - normally unreachable
|
||||||
|
pass
|
||||||
|
self.close_when_done()
|
||||||
|
|
||||||
|
|
||||||
def loop(active, timeout=None, use_poll=False):
|
def loop(active, timeout=None, use_poll=False):
|
||||||
|
@ -125,7 +133,7 @@ def loop(active, timeout=None, use_poll=False):
|
||||||
poll(timeout)
|
poll(timeout)
|
||||||
if errCount:
|
if errCount:
|
||||||
errCount -= 1
|
errCount -= 1
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e:
|
||||||
if not active():
|
if not active():
|
||||||
break
|
break
|
||||||
errCount += 1
|
errCount += 1
|
||||||
|
@ -181,7 +189,7 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
# @param sock: socket file.
|
# @param sock: socket file.
|
||||||
# @param force: remove the socket file if exists.
|
# @param force: remove the socket file if exists.
|
||||||
|
|
||||||
def start(self, sock, force, use_poll=False):
|
def start(self, sock, force, timeout=None, use_poll=False):
|
||||||
self.__worker = threading.current_thread()
|
self.__worker = threading.current_thread()
|
||||||
self.__sock = sock
|
self.__sock = sock
|
||||||
# Remove socket
|
# Remove socket
|
||||||
|
@ -207,12 +215,11 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
if self.onstart:
|
if self.onstart:
|
||||||
self.onstart()
|
self.onstart()
|
||||||
# Event loop as long as active:
|
# Event loop as long as active:
|
||||||
loop(lambda: self.__loop, use_poll=use_poll)
|
loop(lambda: self.__loop, timeout=timeout, use_poll=use_poll)
|
||||||
self.__active = False
|
self.__active = False
|
||||||
# Cleanup all
|
# Cleanup all
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
stopflg = False
|
stopflg = False
|
||||||
if self.__active:
|
if self.__active:
|
||||||
|
|
|
@ -32,7 +32,7 @@ import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from .. import protocol
|
from .. import protocol
|
||||||
from ..server.asyncserver import RequestHandler, AsyncServer, AsyncServerException
|
from ..server.asyncserver import asyncore, RequestHandler, AsyncServer, AsyncServerException
|
||||||
from ..server.utils import Utils
|
from ..server.utils import Utils
|
||||||
from ..client.csocket import CSocket
|
from ..client.csocket import CSocket
|
||||||
|
|
||||||
|
@ -42,8 +42,10 @@ from .utils import LogCaptureTestCase
|
||||||
def TestMsgError(*args):
|
def TestMsgError(*args):
|
||||||
raise Exception('test unpickle error')
|
raise Exception('test unpickle error')
|
||||||
class TestMsg(object):
|
class TestMsg(object):
|
||||||
|
def __init__(self, unpickle=(TestMsgError, ())):
|
||||||
|
self.unpickle = unpickle
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (TestMsgError, ())
|
return self.unpickle
|
||||||
|
|
||||||
|
|
||||||
class Socket(LogCaptureTestCase):
|
class Socket(LogCaptureTestCase):
|
||||||
|
@ -63,7 +65,7 @@ class Socket(LogCaptureTestCase):
|
||||||
"""Call after every test case."""
|
"""Call after every test case."""
|
||||||
if self.serverThread:
|
if self.serverThread:
|
||||||
self.server.stop(); # stop if not already stopped
|
self.server.stop(); # stop if not already stopped
|
||||||
self.serverThread.join()
|
self._stopServerThread()
|
||||||
LogCaptureTestCase.tearDown(self)
|
LogCaptureTestCase.tearDown(self)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -71,23 +73,29 @@ class Socket(LogCaptureTestCase):
|
||||||
"""Test transmitter proceed method which just returns first arg"""
|
"""Test transmitter proceed method which just returns first arg"""
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def _createServerThread(self):
|
def _createServerThread(self, force=False):
|
||||||
# start in separate thread :
|
# start in separate thread :
|
||||||
self.serverThread = serverThread = threading.Thread(
|
self.serverThread = serverThread = threading.Thread(
|
||||||
target=self.server.start, args=(self.sock_name, False))
|
target=self.server.start, args=(self.sock_name, force))
|
||||||
serverThread.daemon = True
|
serverThread.daemon = True
|
||||||
serverThread.start()
|
serverThread.start()
|
||||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||||
return serverThread
|
return serverThread
|
||||||
|
|
||||||
|
def _stopServerThread(self):
|
||||||
|
serverThread = self.serverThread
|
||||||
|
# wait for end of thread :
|
||||||
|
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||||
|
or serverThread.join(Utils.DEFAULT_SLEEP_TIME), unittest.F2B.maxWaitTime(10))
|
||||||
|
self.serverThread = None
|
||||||
|
|
||||||
def testStopPerCloseUnexpected(self):
|
def testStopPerCloseUnexpected(self):
|
||||||
# start in separate thread :
|
# start in separate thread :
|
||||||
serverThread = self._createServerThread()
|
serverThread = self._createServerThread()
|
||||||
# unexpected stop directly after start:
|
# unexpected stop directly after start:
|
||||||
self.server.close()
|
self.server.close()
|
||||||
# wait for end of thread :
|
# wait for end of thread :
|
||||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
self._stopServerThread()
|
||||||
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
|
||||||
self.assertFalse(serverThread.isAlive())
|
self.assertFalse(serverThread.isAlive())
|
||||||
# clean :
|
# clean :
|
||||||
self.server.stop()
|
self.server.stop()
|
||||||
|
@ -128,17 +136,30 @@ class Socket(LogCaptureTestCase):
|
||||||
|
|
||||||
self.server.stop()
|
self.server.stop()
|
||||||
# wait for end of thread :
|
# wait for end of thread :
|
||||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
self._stopServerThread()
|
||||||
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
|
||||||
self.assertFalse(serverThread.isAlive())
|
self.assertFalse(serverThread.isAlive())
|
||||||
self.assertFalse(self.server.isActive())
|
self.assertFalse(self.server.isActive())
|
||||||
self.assertFalse(os.path.exists(self.sock_name))
|
self.assertFalse(os.path.exists(self.sock_name))
|
||||||
|
|
||||||
|
|
||||||
def testSocketConnectBroken(self):
|
def testSocketConnectBroken(self):
|
||||||
# start in separate thread :
|
# start in separate thread :
|
||||||
serverThread = self._createServerThread()
|
serverThread = self._createServerThread()
|
||||||
client = Utils.wait_for(self._serverSocket, 2)
|
client = Utils.wait_for(self._serverSocket, 2)
|
||||||
|
# unexpected stop during message body:
|
||||||
|
testMessage = ["A", "test", "message", [protocol.CSPROTO.END]]
|
||||||
|
|
||||||
|
org_handler = RequestHandler.found_terminator
|
||||||
|
try:
|
||||||
|
RequestHandler.found_terminator = lambda self: self.close()
|
||||||
|
self.assertRaisesRegexp(RuntimeError, r"socket connection broken",
|
||||||
|
lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10)))
|
||||||
|
finally:
|
||||||
|
RequestHandler.found_terminator = org_handler
|
||||||
|
|
||||||
|
def testStopByCommunicate(self):
|
||||||
|
# start in separate thread :
|
||||||
|
serverThread = self._createServerThread()
|
||||||
|
client = Utils.wait_for(self._serverSocket, 2)
|
||||||
|
|
||||||
testMessage = ["A", "test", "message"]
|
testMessage = ["A", "test", "message"]
|
||||||
self.assertEqual(client.send(testMessage), testMessage)
|
self.assertEqual(client.send(testMessage), testMessage)
|
||||||
|
@ -146,12 +167,40 @@ class Socket(LogCaptureTestCase):
|
||||||
org_handler = RequestHandler.found_terminator
|
org_handler = RequestHandler.found_terminator
|
||||||
try:
|
try:
|
||||||
RequestHandler.found_terminator = lambda self: TestMsgError()
|
RequestHandler.found_terminator = lambda self: TestMsgError()
|
||||||
self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage)
|
#self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage)
|
||||||
|
self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error')
|
||||||
finally:
|
finally:
|
||||||
RequestHandler.found_terminator = org_handler
|
RequestHandler.found_terminator = org_handler
|
||||||
|
|
||||||
|
# check errors were logged:
|
||||||
self.assertLogged("Unexpected communication error", "test unpickle error", all=True)
|
self.assertLogged("Unexpected communication error", "test unpickle error", all=True)
|
||||||
|
|
||||||
self.server.stop()
|
self.server.stop()
|
||||||
|
# wait for end of thread :
|
||||||
|
self._stopServerThread()
|
||||||
|
self.assertFalse(serverThread.isAlive())
|
||||||
|
|
||||||
|
def testLoopErrors(self):
|
||||||
|
# replace poll handler to produce error in loop-cycle:
|
||||||
|
org_poll = asyncore.poll
|
||||||
|
err = {'cntr': 0}
|
||||||
|
def _produce_error(*args):
|
||||||
|
err['cntr'] += 1
|
||||||
|
if err['cntr'] < 50:
|
||||||
|
raise RuntimeError('test errors in poll')
|
||||||
|
return org_poll(*args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncore.poll = _produce_error
|
||||||
|
serverThread = self._createServerThread()
|
||||||
|
# wait all-cases processed:
|
||||||
|
self.assertTrue(Utils.wait_for(lambda: err['cntr'] > 50, unittest.F2B.maxWaitTime(10)))
|
||||||
|
finally:
|
||||||
|
# restore:
|
||||||
|
asyncore.poll = org_poll
|
||||||
|
# check errors were logged:
|
||||||
|
self.assertLogged("Server connection was closed: test errors in poll",
|
||||||
|
"Too many errors - stop logging connection errors", all=True)
|
||||||
|
|
||||||
def testSocketForce(self):
|
def testSocketForce(self):
|
||||||
open(self.sock_name, 'w').close() # Create sock file
|
open(self.sock_name, 'w').close() # Create sock file
|
||||||
|
@ -160,16 +209,12 @@ class Socket(LogCaptureTestCase):
|
||||||
AsyncServerException, self.server.start, self.sock_name, False)
|
AsyncServerException, self.server.start, self.sock_name, False)
|
||||||
|
|
||||||
# Try again with force set
|
# Try again with force set
|
||||||
serverThread = threading.Thread(
|
serverThread = self._createServerThread(True)
|
||||||
target=self.server.start, args=(self.sock_name, True))
|
|
||||||
serverThread.daemon = True
|
|
||||||
serverThread.start()
|
|
||||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
|
||||||
|
|
||||||
self.server.stop()
|
self.server.stop()
|
||||||
# wait for end of thread :
|
# wait for end of thread :
|
||||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
self._stopServerThread()
|
||||||
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
self.assertFalse(serverThread.isAlive())
|
||||||
self.assertFalse(self.server.isActive())
|
self.assertFalse(self.server.isActive())
|
||||||
self.assertFalse(os.path.exists(self.sock_name))
|
self.assertFalse(os.path.exists(self.sock_name))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue