asyncserver (asyncore) code fixed and test cases repaired (always delete temp files, wait for end of thread/server, etc)

definitely closes gh-161, also other usage of asyncore event loop (in test_smtp.py)
repair cache in ipToName (can returns None), precaching of invalid IPs (according to RFC 5737) to stop endless wait for resolving it in test cases.
pull/1346/head
sebres 2015-07-23 20:23:07 +02:00
parent 770c219ab6
commit 72f29e9061
6 changed files with 103 additions and 29 deletions

View File

@ -27,12 +27,14 @@ __license__ = "GPL"
from pickle import dumps, loads, HIGHEST_PROTOCOL
import asynchat
import asyncore
import errno
import fcntl
import os
import socket
import sys
import traceback
from .utils import Utils
from ..protocol import CSPROTO
from ..helpers import getLogger,formatExceptionInfo
@ -89,6 +91,29 @@ class RequestHandler(asynchat.async_chat):
self.close()
def loop(active, timeout=None, use_poll=False):
# Use poll instead of loop, because of recognition of active flag,
# because of loop timeout mistake: different in poll and poll2 (sec vs ms),
# and to prevent sporadical errors like EBADF 'Bad file descriptor' etc. (see gh-161)
if timeout is None:
timeout = Utils.DEFAULT_SLEEP_TIME
poll = asyncore.poll
if use_poll and asyncore.poll2 and hasattr(asyncore.select, 'poll'): # pragma: no cover
logSys.debug('Server listener (select) uses poll')
# poll2 expected a timeout in milliseconds (but poll and loop in seconds):
timeout = float(timeout) / 1000
poll = asyncore.poll2
# Poll as long as active:
while active():
try:
poll(timeout)
except Exception as e: # pragma: no cover
if e.args[0] in (errno.ENOTCONN, errno.EBADF): # (errno.EBADF, 'Bad file descriptor')
logSys.info('Server connection was closed: %s', str(e))
else:
logSys.error('Server connection was closed: %s', str(e))
##
# Asynchronous server class.
#
@ -102,6 +127,7 @@ class AsyncServer(asyncore.dispatcher):
self.__transmitter = transmitter
self.__sock = "/var/run/fail2ban/fail2ban.sock"
self.__init = False
self.__active = False
##
# Returns False as we only read the socket first.
@ -129,7 +155,7 @@ class AsyncServer(asyncore.dispatcher):
# @param sock: socket file.
# @param force: remove the socket file if exists.
def start(self, sock, force):
def start(self, sock, force, use_poll=False):
self.__sock = sock
# Remove socket
if os.path.exists(sock):
@ -149,28 +175,31 @@ class AsyncServer(asyncore.dispatcher):
AsyncServer.__markCloseOnExec(self.socket)
self.listen(1)
# Sets the init flag.
self.__init = True
# TODO Add try..catch
# There's a bug report for Python 2.6/3.0 that use_poll=True yields some 2.5 incompatibilities:
if (sys.version_info >= (2, 7) and sys.version_info < (2, 8)) \
or (sys.version_info >= (3, 4)): # if python 2.7 ...
logSys.debug("Detected Python 2.7. asyncore.loop() using poll")
asyncore.loop(use_poll=True) # workaround for the "Bad file descriptor" issue on Python 2.7, gh-161
else:
asyncore.loop(use_poll=False) # fixes the "Unexpected communication problem" issue on Python 2.6 and 3.0
self.__init = self.__active = True
# Event loop as long as active:
loop(lambda: self.__active)
# Cleanup all
self.stop()
def close(self):
if self.__active:
asyncore.dispatcher.close(self)
# Remove socket (file) only if it was created:
if self.__init and os.path.exists(self.__sock):
logSys.debug("Removed socket file " + self.__sock)
os.remove(self.__sock)
logSys.debug("Socket shutdown")
self.__active = False
##
# Stops the communication server.
def stop(self):
if self.__init:
# Only closes the socket if it was initialized first.
self.close()
# Remove socket
if os.path.exists(self.__sock):
logSys.debug("Removed socket file " + self.__sock)
os.remove(self.__sock)
logSys.debug("Socket shutdown")
self.close()
def isActive(self):
return self.__active
##
# Marks socket as close-on-exec to avoid leaking file descriptors when

View File

@ -1023,8 +1023,8 @@ class DNSUtils:
@staticmethod
def ipToName(ip):
# cache, also prevent long wait during retrieving of name for wrong addresses, lazy dns:
v = DNSUtils.CACHE_ipToName.get(ip)
if v is not None:
v = DNSUtils.CACHE_ipToName.get(ip, ())
if v != ():
return v
# retrieve name
try:

View File

@ -19,7 +19,6 @@
import os
import smtpd
import asyncore
import threading
import unittest
import sys
@ -30,7 +29,7 @@ else:
from ..dummyjail import DummyJail
from ..utils import CONFIG_DIR
from ..utils import CONFIG_DIR, asyncserver
class TestSMTPServer(smtpd.SMTPServer):
@ -46,8 +45,6 @@ class SMTPActionTest(unittest.TestCase):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
self.jail = DummyJail()
pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py")
pythonModuleName = os.path.basename(pythonModule.rstrip(".py"))
@ -64,13 +61,16 @@ class SMTPActionTest(unittest.TestCase):
self.action = customActionModule.Action(
self.jail, "test", host="127.0.0.1:%i" % port)
## because of bug in loop (see loop in asyncserver.py) use it's loop instead of asyncore.loop:
self._active = True
self._loop_thread = threading.Thread(
target=asyncore.loop, kwargs={'timeout': 1})
target=asyncserver.loop, kwargs={'active': lambda: self._active})
self._loop_thread.start()
def tearDown(self):
"""Call after every test case."""
self.smtpd.close()
self._active = False
self._loop_thread.join()
def testStart(self):

View File

@ -62,11 +62,14 @@ class TransmitterBase(unittest.TestCase):
def setUp(self):
"""Call before every test case."""
self.transm = self.server._Server__transm
self.tmp_files = []
sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'transmitter')
os.close(sock_fd)
self.tmp_files.append(sock_name)
pidfile_fd, pidfile_name = tempfile.mkstemp(
'fail2ban.pid', 'transmitter')
os.close(pidfile_fd)
self.tmp_files.append(pidfile_name)
self.server.start(sock_name, pidfile_name, force=False)
self.jailName = "TestJail1"
self.server.addJail(self.jailName, "auto")
@ -74,6 +77,9 @@ class TransmitterBase(unittest.TestCase):
def tearDown(self):
"""Call after every test case."""
self.server.quit()
for f in self.tmp_files:
if os.path.exists(f):
os.remove(f)
def setGetTest(self, cmd, inValue, outValue=(None,), outCode=0, jail=None, repr_=False):
"""Process set/get commands and compare both return values

View File

@ -55,6 +55,24 @@ class Socket(unittest.TestCase):
"""Test transmitter proceed method which just returns first arg"""
return message
def testStopPerCloseUnexpected(self):
# start in separate thread :
serverThread = threading.Thread(
target=self.server.start, args=(self.sock_name, False))
serverThread.daemon = True
serverThread.start()
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
# unexpected stop directly after start:
self.server.close()
# wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive()
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
self.assertFalse(serverThread.isAlive())
# clean :
self.server.stop()
self.assertFalse(self.server.isActive())
self.assertFalse(os.path.exists(self.sock_name))
def _serverSocket(self):
try:
return CSocket(self.sock_name)
@ -66,6 +84,7 @@ class Socket(unittest.TestCase):
target=self.server.start, args=(self.sock_name, False))
serverThread.daemon = True
serverThread.start()
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
time.sleep(Utils.DEFAULT_SLEEP_TIME)
client = Utils.wait_for(self._serverSocket, 2)
@ -78,7 +97,11 @@ class Socket(unittest.TestCase):
client.close()
self.server.stop()
serverThread.join(Utils.DEFAULT_SLEEP_TIME)
# wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive()
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
self.assertFalse(serverThread.isAlive())
self.assertFalse(self.server.isActive())
self.assertFalse(os.path.exists(self.sock_name))
def testSocketForce(self):
@ -92,10 +115,13 @@ class Socket(unittest.TestCase):
target=self.server.start, args=(self.sock_name, True))
serverThread.daemon = True
serverThread.start()
time.sleep(Utils.DEFAULT_SLEEP_TIME)
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
self.server.stop()
serverThread.join(Utils.DEFAULT_SLEEP_TIME)
# wait for end of thread :
Utils.wait_for(lambda: not serverThread.isAlive()
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
self.assertFalse(self.server.isActive())
self.assertFalse(os.path.exists(self.sock_name))

View File

@ -35,6 +35,8 @@ from ..helpers import getLogger
from ..server.filter import DNSUtils
from ..server.mytime import MyTime
from ..server.utils import Utils
# for action_d.test_smtp :
from ..server import asyncserver
logSys = getLogger(__name__)
@ -61,6 +63,10 @@ class F2B(optparse.Values):
pass
def SkipIfNoNetwork(self):
pass
def maxWaitTime(self,wtime):
if self.fast:
wtime = float(wtime) / 10
return wtime
def initTests(opts):
@ -87,6 +93,13 @@ def initTests(opts):
def F2B_SkipIfNoNetwork():
raise unittest.SkipTest('Skip test because of "--no-network"')
unittest.F2B.SkipIfNoNetwork = F2B_SkipIfNoNetwork
# precache all invalid ip's (TEST-NET-1, ..., TEST-NET-3 according to RFC 5737):
c = DNSUtils.CACHE_ipToName
for i in xrange(255):
c.set('192.0.2.%s' % i, None)
c.set('198.51.100.%s' % i, None)
c.set('203.0.113.%s' % i, None)
def mtimesleep():
# no sleep now should be necessary since polling tracks now not only