mirror of https://github.com/fail2ban/fail2ban
asyncserver (asyncore) code fixed and test cases repaired (always delete temp files, wait for end of thread/server, etc)
definitely closes gh-161, also other usage of asyncore event loop (in test_smtp.py) repair cache in ipToName (can returns None), precaching of invalid IPs (according to RFC 5737) to stop endless wait for resolving it in test cases.pull/1346/head
parent
770c219ab6
commit
72f29e9061
|
@ -27,12 +27,14 @@ __license__ = "GPL"
|
|||
from pickle import dumps, loads, HIGHEST_PROTOCOL
|
||||
import asynchat
|
||||
import asyncore
|
||||
import errno
|
||||
import fcntl
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from .utils import Utils
|
||||
from ..protocol import CSPROTO
|
||||
from ..helpers import getLogger,formatExceptionInfo
|
||||
|
||||
|
@ -89,6 +91,29 @@ class RequestHandler(asynchat.async_chat):
|
|||
self.close()
|
||||
|
||||
|
||||
def loop(active, timeout=None, use_poll=False):
|
||||
# Use poll instead of loop, because of recognition of active flag,
|
||||
# because of loop timeout mistake: different in poll and poll2 (sec vs ms),
|
||||
# and to prevent sporadical errors like EBADF 'Bad file descriptor' etc. (see gh-161)
|
||||
if timeout is None:
|
||||
timeout = Utils.DEFAULT_SLEEP_TIME
|
||||
poll = asyncore.poll
|
||||
if use_poll and asyncore.poll2 and hasattr(asyncore.select, 'poll'): # pragma: no cover
|
||||
logSys.debug('Server listener (select) uses poll')
|
||||
# poll2 expected a timeout in milliseconds (but poll and loop in seconds):
|
||||
timeout = float(timeout) / 1000
|
||||
poll = asyncore.poll2
|
||||
# Poll as long as active:
|
||||
while active():
|
||||
try:
|
||||
poll(timeout)
|
||||
except Exception as e: # pragma: no cover
|
||||
if e.args[0] in (errno.ENOTCONN, errno.EBADF): # (errno.EBADF, 'Bad file descriptor')
|
||||
logSys.info('Server connection was closed: %s', str(e))
|
||||
else:
|
||||
logSys.error('Server connection was closed: %s', str(e))
|
||||
|
||||
|
||||
##
|
||||
# Asynchronous server class.
|
||||
#
|
||||
|
@ -102,6 +127,7 @@ class AsyncServer(asyncore.dispatcher):
|
|||
self.__transmitter = transmitter
|
||||
self.__sock = "/var/run/fail2ban/fail2ban.sock"
|
||||
self.__init = False
|
||||
self.__active = False
|
||||
|
||||
##
|
||||
# Returns False as we only read the socket first.
|
||||
|
@ -129,7 +155,7 @@ class AsyncServer(asyncore.dispatcher):
|
|||
# @param sock: socket file.
|
||||
# @param force: remove the socket file if exists.
|
||||
|
||||
def start(self, sock, force):
|
||||
def start(self, sock, force, use_poll=False):
|
||||
self.__sock = sock
|
||||
# Remove socket
|
||||
if os.path.exists(sock):
|
||||
|
@ -149,28 +175,31 @@ class AsyncServer(asyncore.dispatcher):
|
|||
AsyncServer.__markCloseOnExec(self.socket)
|
||||
self.listen(1)
|
||||
# Sets the init flag.
|
||||
self.__init = True
|
||||
# TODO Add try..catch
|
||||
# There's a bug report for Python 2.6/3.0 that use_poll=True yields some 2.5 incompatibilities:
|
||||
if (sys.version_info >= (2, 7) and sys.version_info < (2, 8)) \
|
||||
or (sys.version_info >= (3, 4)): # if python 2.7 ...
|
||||
logSys.debug("Detected Python 2.7. asyncore.loop() using poll")
|
||||
asyncore.loop(use_poll=True) # workaround for the "Bad file descriptor" issue on Python 2.7, gh-161
|
||||
else:
|
||||
asyncore.loop(use_poll=False) # fixes the "Unexpected communication problem" issue on Python 2.6 and 3.0
|
||||
|
||||
self.__init = self.__active = True
|
||||
# Event loop as long as active:
|
||||
loop(lambda: self.__active)
|
||||
# Cleanup all
|
||||
self.stop()
|
||||
|
||||
|
||||
def close(self):
|
||||
if self.__active:
|
||||
asyncore.dispatcher.close(self)
|
||||
# Remove socket (file) only if it was created:
|
||||
if self.__init and os.path.exists(self.__sock):
|
||||
logSys.debug("Removed socket file " + self.__sock)
|
||||
os.remove(self.__sock)
|
||||
logSys.debug("Socket shutdown")
|
||||
self.__active = False
|
||||
|
||||
##
|
||||
# Stops the communication server.
|
||||
|
||||
def stop(self):
|
||||
if self.__init:
|
||||
# Only closes the socket if it was initialized first.
|
||||
self.close()
|
||||
# Remove socket
|
||||
if os.path.exists(self.__sock):
|
||||
logSys.debug("Removed socket file " + self.__sock)
|
||||
os.remove(self.__sock)
|
||||
logSys.debug("Socket shutdown")
|
||||
self.close()
|
||||
|
||||
def isActive(self):
|
||||
return self.__active
|
||||
|
||||
##
|
||||
# Marks socket as close-on-exec to avoid leaking file descriptors when
|
||||
|
|
|
@ -1023,8 +1023,8 @@ class DNSUtils:
|
|||
@staticmethod
|
||||
def ipToName(ip):
|
||||
# cache, also prevent long wait during retrieving of name for wrong addresses, lazy dns:
|
||||
v = DNSUtils.CACHE_ipToName.get(ip)
|
||||
if v is not None:
|
||||
v = DNSUtils.CACHE_ipToName.get(ip, ())
|
||||
if v != ():
|
||||
return v
|
||||
# retrieve name
|
||||
try:
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
import os
|
||||
import smtpd
|
||||
import asyncore
|
||||
import threading
|
||||
import unittest
|
||||
import sys
|
||||
|
@ -30,7 +29,7 @@ else:
|
|||
|
||||
from ..dummyjail import DummyJail
|
||||
|
||||
from ..utils import CONFIG_DIR
|
||||
from ..utils import CONFIG_DIR, asyncserver
|
||||
|
||||
|
||||
class TestSMTPServer(smtpd.SMTPServer):
|
||||
|
@ -46,8 +45,6 @@ class SMTPActionTest(unittest.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
"""Call before every test case."""
|
||||
unittest.F2B.SkipIfNoNetwork()
|
||||
|
||||
self.jail = DummyJail()
|
||||
pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py")
|
||||
pythonModuleName = os.path.basename(pythonModule.rstrip(".py"))
|
||||
|
@ -64,13 +61,16 @@ class SMTPActionTest(unittest.TestCase):
|
|||
self.action = customActionModule.Action(
|
||||
self.jail, "test", host="127.0.0.1:%i" % port)
|
||||
|
||||
## because of bug in loop (see loop in asyncserver.py) use it's loop instead of asyncore.loop:
|
||||
self._active = True
|
||||
self._loop_thread = threading.Thread(
|
||||
target=asyncore.loop, kwargs={'timeout': 1})
|
||||
target=asyncserver.loop, kwargs={'active': lambda: self._active})
|
||||
self._loop_thread.start()
|
||||
|
||||
def tearDown(self):
|
||||
"""Call after every test case."""
|
||||
self.smtpd.close()
|
||||
self._active = False
|
||||
self._loop_thread.join()
|
||||
|
||||
def testStart(self):
|
||||
|
|
|
@ -62,11 +62,14 @@ class TransmitterBase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
"""Call before every test case."""
|
||||
self.transm = self.server._Server__transm
|
||||
self.tmp_files = []
|
||||
sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'transmitter')
|
||||
os.close(sock_fd)
|
||||
self.tmp_files.append(sock_name)
|
||||
pidfile_fd, pidfile_name = tempfile.mkstemp(
|
||||
'fail2ban.pid', 'transmitter')
|
||||
os.close(pidfile_fd)
|
||||
self.tmp_files.append(pidfile_name)
|
||||
self.server.start(sock_name, pidfile_name, force=False)
|
||||
self.jailName = "TestJail1"
|
||||
self.server.addJail(self.jailName, "auto")
|
||||
|
@ -74,6 +77,9 @@ class TransmitterBase(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
"""Call after every test case."""
|
||||
self.server.quit()
|
||||
for f in self.tmp_files:
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
|
||||
def setGetTest(self, cmd, inValue, outValue=(None,), outCode=0, jail=None, repr_=False):
|
||||
"""Process set/get commands and compare both return values
|
||||
|
|
|
@ -55,6 +55,24 @@ class Socket(unittest.TestCase):
|
|||
"""Test transmitter proceed method which just returns first arg"""
|
||||
return message
|
||||
|
||||
def testStopPerCloseUnexpected(self):
|
||||
# start in separate thread :
|
||||
serverThread = threading.Thread(
|
||||
target=self.server.start, args=(self.sock_name, False))
|
||||
serverThread.daemon = True
|
||||
serverThread.start()
|
||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||
# unexpected stop directly after start:
|
||||
self.server.close()
|
||||
# wait for end of thread :
|
||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
||||
self.assertFalse(serverThread.isAlive())
|
||||
# clean :
|
||||
self.server.stop()
|
||||
self.assertFalse(self.server.isActive())
|
||||
self.assertFalse(os.path.exists(self.sock_name))
|
||||
|
||||
def _serverSocket(self):
|
||||
try:
|
||||
return CSocket(self.sock_name)
|
||||
|
@ -66,6 +84,7 @@ class Socket(unittest.TestCase):
|
|||
target=self.server.start, args=(self.sock_name, False))
|
||||
serverThread.daemon = True
|
||||
serverThread.start()
|
||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||
time.sleep(Utils.DEFAULT_SLEEP_TIME)
|
||||
|
||||
client = Utils.wait_for(self._serverSocket, 2)
|
||||
|
@ -78,7 +97,11 @@ class Socket(unittest.TestCase):
|
|||
client.close()
|
||||
|
||||
self.server.stop()
|
||||
serverThread.join(Utils.DEFAULT_SLEEP_TIME)
|
||||
# wait for end of thread :
|
||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
||||
self.assertFalse(serverThread.isAlive())
|
||||
self.assertFalse(self.server.isActive())
|
||||
self.assertFalse(os.path.exists(self.sock_name))
|
||||
|
||||
def testSocketForce(self):
|
||||
|
@ -92,10 +115,13 @@ class Socket(unittest.TestCase):
|
|||
target=self.server.start, args=(self.sock_name, True))
|
||||
serverThread.daemon = True
|
||||
serverThread.start()
|
||||
time.sleep(Utils.DEFAULT_SLEEP_TIME)
|
||||
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||
|
||||
self.server.stop()
|
||||
serverThread.join(Utils.DEFAULT_SLEEP_TIME)
|
||||
# wait for end of thread :
|
||||
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
||||
self.assertFalse(self.server.isActive())
|
||||
self.assertFalse(os.path.exists(self.sock_name))
|
||||
|
||||
|
||||
|
|
|
@ -35,6 +35,8 @@ from ..helpers import getLogger
|
|||
from ..server.filter import DNSUtils
|
||||
from ..server.mytime import MyTime
|
||||
from ..server.utils import Utils
|
||||
# for action_d.test_smtp :
|
||||
from ..server import asyncserver
|
||||
|
||||
|
||||
logSys = getLogger(__name__)
|
||||
|
@ -61,6 +63,10 @@ class F2B(optparse.Values):
|
|||
pass
|
||||
def SkipIfNoNetwork(self):
|
||||
pass
|
||||
def maxWaitTime(self,wtime):
|
||||
if self.fast:
|
||||
wtime = float(wtime) / 10
|
||||
return wtime
|
||||
|
||||
|
||||
def initTests(opts):
|
||||
|
@ -87,6 +93,13 @@ def initTests(opts):
|
|||
def F2B_SkipIfNoNetwork():
|
||||
raise unittest.SkipTest('Skip test because of "--no-network"')
|
||||
unittest.F2B.SkipIfNoNetwork = F2B_SkipIfNoNetwork
|
||||
# precache all invalid ip's (TEST-NET-1, ..., TEST-NET-3 according to RFC 5737):
|
||||
c = DNSUtils.CACHE_ipToName
|
||||
for i in xrange(255):
|
||||
c.set('192.0.2.%s' % i, None)
|
||||
c.set('198.51.100.%s' % i, None)
|
||||
c.set('203.0.113.%s' % i, None)
|
||||
|
||||
|
||||
def mtimesleep():
|
||||
# no sleep now should be necessary since polling tracks now not only
|
||||
|
|
Loading…
Reference in New Issue