asyncserver (asyncore) code fixed and test cases repaired (always delete temp files, wait for end of thread/server, etc)

definitely closes gh-161, also other usage of asyncore event loop (in test_smtp.py)
repair cache in ipToName (can returns None), precaching of invalid IPs (according to RFC 5737) to stop endless wait for resolving it in test cases.
pull/1346/head
sebres 9 years ago
parent 770c219ab6
commit 72f29e9061

@ -27,12 +27,14 @@ __license__ = "GPL"
from pickle import dumps, loads, HIGHEST_PROTOCOL from pickle import dumps, loads, HIGHEST_PROTOCOL
import asynchat import asynchat
import asyncore import asyncore
import errno
import fcntl import fcntl
import os import os
import socket import socket
import sys import sys
import traceback import traceback
from .utils import Utils
from ..protocol import CSPROTO from ..protocol import CSPROTO
from ..helpers import getLogger,formatExceptionInfo from ..helpers import getLogger,formatExceptionInfo
@ -89,6 +91,29 @@ class RequestHandler(asynchat.async_chat):
self.close() self.close()
def loop(active, timeout=None, use_poll=False):
# Use poll instead of loop, because of recognition of active flag,
# because of loop timeout mistake: different in poll and poll2 (sec vs ms),
# and to prevent sporadical errors like EBADF 'Bad file descriptor' etc. (see gh-161)
if timeout is None:
timeout = Utils.DEFAULT_SLEEP_TIME
poll = asyncore.poll
if use_poll and asyncore.poll2 and hasattr(asyncore.select, 'poll'): # pragma: no cover
logSys.debug('Server listener (select) uses poll')
# poll2 expected a timeout in milliseconds (but poll and loop in seconds):
timeout = float(timeout) / 1000
poll = asyncore.poll2
# Poll as long as active:
while active():
try:
poll(timeout)
except Exception as e: # pragma: no cover
if e.args[0] in (errno.ENOTCONN, errno.EBADF): # (errno.EBADF, 'Bad file descriptor')
logSys.info('Server connection was closed: %s', str(e))
else:
logSys.error('Server connection was closed: %s', str(e))
## ##
# Asynchronous server class. # Asynchronous server class.
# #
@ -102,6 +127,7 @@ class AsyncServer(asyncore.dispatcher):
self.__transmitter = transmitter self.__transmitter = transmitter
self.__sock = "/var/run/fail2ban/fail2ban.sock" self.__sock = "/var/run/fail2ban/fail2ban.sock"
self.__init = False self.__init = False
self.__active = False
## ##
# Returns False as we only read the socket first. # Returns False as we only read the socket first.
@ -129,7 +155,7 @@ class AsyncServer(asyncore.dispatcher):
# @param sock: socket file. # @param sock: socket file.
# @param force: remove the socket file if exists. # @param force: remove the socket file if exists.
def start(self, sock, force): def start(self, sock, force, use_poll=False):
self.__sock = sock self.__sock = sock
# Remove socket # Remove socket
if os.path.exists(sock): if os.path.exists(sock):
@ -149,28 +175,31 @@ class AsyncServer(asyncore.dispatcher):
AsyncServer.__markCloseOnExec(self.socket) AsyncServer.__markCloseOnExec(self.socket)
self.listen(1) self.listen(1)
# Sets the init flag. # Sets the init flag.
self.__init = True self.__init = self.__active = True
# TODO Add try..catch # Event loop as long as active:
# There's a bug report for Python 2.6/3.0 that use_poll=True yields some 2.5 incompatibilities: loop(lambda: self.__active)
if (sys.version_info >= (2, 7) and sys.version_info < (2, 8)) \ # Cleanup all
or (sys.version_info >= (3, 4)): # if python 2.7 ... self.stop()
logSys.debug("Detected Python 2.7. asyncore.loop() using poll")
asyncore.loop(use_poll=True) # workaround for the "Bad file descriptor" issue on Python 2.7, gh-161
else: def close(self):
asyncore.loop(use_poll=False) # fixes the "Unexpected communication problem" issue on Python 2.6 and 3.0 if self.__active:
asyncore.dispatcher.close(self)
# Remove socket (file) only if it was created:
if self.__init and os.path.exists(self.__sock):
logSys.debug("Removed socket file " + self.__sock)
os.remove(self.__sock)
logSys.debug("Socket shutdown")
self.__active = False
## ##
# Stops the communication server. # Stops the communication server.
def stop(self): def stop(self):
if self.__init: self.close()
# Only closes the socket if it was initialized first.
self.close() def isActive(self):
# Remove socket return self.__active
if os.path.exists(self.__sock):
logSys.debug("Removed socket file " + self.__sock)
os.remove(self.__sock)
logSys.debug("Socket shutdown")
## ##
# Marks socket as close-on-exec to avoid leaking file descriptors when # Marks socket as close-on-exec to avoid leaking file descriptors when

@ -1023,8 +1023,8 @@ class DNSUtils:
@staticmethod @staticmethod
def ipToName(ip): def ipToName(ip):
# cache, also prevent long wait during retrieving of name for wrong addresses, lazy dns: # cache, also prevent long wait during retrieving of name for wrong addresses, lazy dns:
v = DNSUtils.CACHE_ipToName.get(ip) v = DNSUtils.CACHE_ipToName.get(ip, ())
if v is not None: if v != ():
return v return v
# retrieve name # retrieve name
try: try:

@ -19,7 +19,6 @@
import os import os
import smtpd import smtpd
import asyncore
import threading import threading
import unittest import unittest
import sys import sys
@ -30,7 +29,7 @@ else:
from ..dummyjail import DummyJail from ..dummyjail import DummyJail
from ..utils import CONFIG_DIR from ..utils import CONFIG_DIR, asyncserver
class TestSMTPServer(smtpd.SMTPServer): class TestSMTPServer(smtpd.SMTPServer):
@ -46,8 +45,6 @@ class SMTPActionTest(unittest.TestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
self.jail = DummyJail() self.jail = DummyJail()
pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py") pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py")
pythonModuleName = os.path.basename(pythonModule.rstrip(".py")) pythonModuleName = os.path.basename(pythonModule.rstrip(".py"))
@ -64,13 +61,16 @@ class SMTPActionTest(unittest.TestCase):
self.action = customActionModule.Action( self.action = customActionModule.Action(
self.jail, "test", host="127.0.0.1:%i" % port) self.jail, "test", host="127.0.0.1:%i" % port)
## because of bug in loop (see loop in asyncserver.py) use it's loop instead of asyncore.loop:
self._active = True
self._loop_thread = threading.Thread( self._loop_thread = threading.Thread(
target=asyncore.loop, kwargs={'timeout': 1}) target=asyncserver.loop, kwargs={'active': lambda: self._active})
self._loop_thread.start() self._loop_thread.start()
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
self.smtpd.close() self.smtpd.close()
self._active = False
self._loop_thread.join() self._loop_thread.join()
def testStart(self): def testStart(self):

@ -62,11 +62,14 @@ class TransmitterBase(unittest.TestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
self.transm = self.server._Server__transm self.transm = self.server._Server__transm
self.tmp_files = []
sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'transmitter') sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'transmitter')
os.close(sock_fd) os.close(sock_fd)
self.tmp_files.append(sock_name)
pidfile_fd, pidfile_name = tempfile.mkstemp( pidfile_fd, pidfile_name = tempfile.mkstemp(
'fail2ban.pid', 'transmitter') 'fail2ban.pid', 'transmitter')
os.close(pidfile_fd) os.close(pidfile_fd)
self.tmp_files.append(pidfile_name)
self.server.start(sock_name, pidfile_name, force=False) self.server.start(sock_name, pidfile_name, force=False)
self.jailName = "TestJail1" self.jailName = "TestJail1"
self.server.addJail(self.jailName, "auto") self.server.addJail(self.jailName, "auto")
@ -74,6 +77,9 @@ class TransmitterBase(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
self.server.quit() self.server.quit()
for f in self.tmp_files:
if os.path.exists(f):
os.remove(f)
def setGetTest(self, cmd, inValue, outValue=(None,), outCode=0, jail=None, repr_=False): def setGetTest(self, cmd, inValue, outValue=(None,), outCode=0, jail=None, repr_=False):
"""Process set/get commands and compare both return values """Process set/get commands and compare both return values

@ -55,6 +55,24 @@ class Socket(unittest.TestCase):
"""Test transmitter proceed method which just returns first arg""" """Test transmitter proceed method which just returns first arg"""
return message return message
def testStopPerCloseUnexpected(self):
# start in separate thread :
serverThread = threading.Thread(
target=self.server.start, args=(self.sock_name, False))
serverThread.daemon = True
serverThread.start()
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
# unexpected stop directly after start:
self.server.close()
# wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive()
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
self.assertFalse(serverThread.isAlive())
# clean :
self.server.stop()
self.assertFalse(self.server.isActive())
self.assertFalse(os.path.exists(self.sock_name))
def _serverSocket(self): def _serverSocket(self):
try: try:
return CSocket(self.sock_name) return CSocket(self.sock_name)
@ -66,6 +84,7 @@ class Socket(unittest.TestCase):
target=self.server.start, args=(self.sock_name, False)) target=self.server.start, args=(self.sock_name, False))
serverThread.daemon = True serverThread.daemon = True
serverThread.start() serverThread.start()
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
time.sleep(Utils.DEFAULT_SLEEP_TIME) time.sleep(Utils.DEFAULT_SLEEP_TIME)
client = Utils.wait_for(self._serverSocket, 2) client = Utils.wait_for(self._serverSocket, 2)
@ -78,7 +97,11 @@ class Socket(unittest.TestCase):
client.close() client.close()
self.server.stop() self.server.stop()
serverThread.join(Utils.DEFAULT_SLEEP_TIME) # wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive()
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
self.assertFalse(serverThread.isAlive())
self.assertFalse(self.server.isActive())
self.assertFalse(os.path.exists(self.sock_name)) self.assertFalse(os.path.exists(self.sock_name))
def testSocketForce(self): def testSocketForce(self):
@ -92,10 +115,13 @@ class Socket(unittest.TestCase):
target=self.server.start, args=(self.sock_name, True)) target=self.server.start, args=(self.sock_name, True))
serverThread.daemon = True serverThread.daemon = True
serverThread.start() serverThread.start()
time.sleep(Utils.DEFAULT_SLEEP_TIME) self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
self.server.stop() self.server.stop()
serverThread.join(Utils.DEFAULT_SLEEP_TIME) # wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive()
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
self.assertFalse(self.server.isActive())
self.assertFalse(os.path.exists(self.sock_name)) self.assertFalse(os.path.exists(self.sock_name))

@ -35,6 +35,8 @@ from ..helpers import getLogger
from ..server.filter import DNSUtils from ..server.filter import DNSUtils
from ..server.mytime import MyTime from ..server.mytime import MyTime
from ..server.utils import Utils from ..server.utils import Utils
# for action_d.test_smtp :
from ..server import asyncserver
logSys = getLogger(__name__) logSys = getLogger(__name__)
@ -61,6 +63,10 @@ class F2B(optparse.Values):
pass pass
def SkipIfNoNetwork(self): def SkipIfNoNetwork(self):
pass pass
def maxWaitTime(self,wtime):
if self.fast:
wtime = float(wtime) / 10
return wtime
def initTests(opts): def initTests(opts):
@ -87,6 +93,13 @@ def initTests(opts):
def F2B_SkipIfNoNetwork(): def F2B_SkipIfNoNetwork():
raise unittest.SkipTest('Skip test because of "--no-network"') raise unittest.SkipTest('Skip test because of "--no-network"')
unittest.F2B.SkipIfNoNetwork = F2B_SkipIfNoNetwork unittest.F2B.SkipIfNoNetwork = F2B_SkipIfNoNetwork
# precache all invalid ip's (TEST-NET-1, ..., TEST-NET-3 according to RFC 5737):
c = DNSUtils.CACHE_ipToName
for i in xrange(255):
c.set('192.0.2.%s' % i, None)
c.set('198.51.100.%s' % i, None)
c.set('203.0.113.%s' % i, None)
def mtimesleep(): def mtimesleep():
# no sleep now should be necessary since polling tracks now not only # no sleep now should be necessary since polling tracks now not only

Loading…
Cancel
Save