diff --git a/fail2ban/server/asyncserver.py b/fail2ban/server/asyncserver.py index a9be0ae2..7322e1e1 100644 --- a/fail2ban/server/asyncserver.py +++ b/fail2ban/server/asyncserver.py @@ -27,12 +27,14 @@ __license__ = "GPL" from pickle import dumps, loads, HIGHEST_PROTOCOL import asynchat import asyncore +import errno import fcntl import os import socket import sys import traceback +from .utils import Utils from ..protocol import CSPROTO from ..helpers import getLogger,formatExceptionInfo @@ -89,6 +91,29 @@ class RequestHandler(asynchat.async_chat): self.close() +def loop(active, timeout=None, use_poll=False): + # Use poll instead of loop, because of recognition of active flag, + # because of loop timeout mistake: different in poll and poll2 (sec vs ms), + # and to prevent sporadical errors like EBADF 'Bad file descriptor' etc. (see gh-161) + if timeout is None: + timeout = Utils.DEFAULT_SLEEP_INTERVAL + poll = asyncore.poll + if use_poll and asyncore.poll2 and hasattr(asyncore.select, 'poll'): # pragma: no cover + logSys.debug('Server listener (select) uses poll') + # poll2 expected a timeout in milliseconds (but poll and loop in seconds): + timeout = float(timeout) / 1000 + poll = asyncore.poll2 + # Poll as long as active: + while active(): + try: + poll() + except Exception as e: # pragma: no cover + if e.args[0] in (errno.ENOTCONN, errno.EBADF): # (errno.EBADF, 'Bad file descriptor') + logSys.info('Server connection was closed: %s', str(e)) + else: + logSys.error('Server connection was closed: %s', str(e)) + + ## # Asynchronous server class. # @@ -102,6 +127,7 @@ class AsyncServer(asyncore.dispatcher): self.__transmitter = transmitter self.__sock = "/var/run/fail2ban/fail2ban.sock" self.__init = False + self.__active = False ## # Returns False as we only read the socket first. @@ -129,7 +155,7 @@ class AsyncServer(asyncore.dispatcher): # @param sock: socket file. # @param force: remove the socket file if exists. - def start(self, sock, force): + def start(self, sock, force, use_poll=False): self.__sock = sock # Remove socket if os.path.exists(sock): @@ -149,28 +175,31 @@ class AsyncServer(asyncore.dispatcher): AsyncServer.__markCloseOnExec(self.socket) self.listen(1) # Sets the init flag. - self.__init = True - # TODO Add try..catch - # There's a bug report for Python 2.6/3.0 that use_poll=True yields some 2.5 incompatibilities: - if (sys.version_info >= (2, 7) and sys.version_info < (2, 8)) \ - or (sys.version_info >= (3, 4)): # if python 2.7 ... - logSys.debug("Detected Python 2.7. asyncore.loop() using poll") - asyncore.loop(use_poll=True) # workaround for the "Bad file descriptor" issue on Python 2.7, gh-161 - else: - asyncore.loop(use_poll=False) # fixes the "Unexpected communication problem" issue on Python 2.6 and 3.0 - + self.__init = self.__active = True + # Event loop as long as active: + loop(lambda: self.__active) + # Cleanup all + self.stop() + + + def close(self): + if self.__active: + asyncore.dispatcher.close(self) + # Remove socket (file) only if it was created: + if self.__init and os.path.exists(self.__sock): + logSys.debug("Removed socket file " + self.__sock) + os.remove(self.__sock) + logSys.debug("Socket shutdown") + self.__active = False + ## # Stops the communication server. def stop(self): - if self.__init: - # Only closes the socket if it was initialized first. - self.close() - # Remove socket - if os.path.exists(self.__sock): - logSys.debug("Removed socket file " + self.__sock) - os.remove(self.__sock) - logSys.debug("Socket shutdown") + self.close() + + def isActive(self): + return self.__active ## # Marks socket as close-on-exec to avoid leaking file descriptors when diff --git a/fail2ban/server/filter.py b/fail2ban/server/filter.py index a080e0c1..32c60199 100644 --- a/fail2ban/server/filter.py +++ b/fail2ban/server/filter.py @@ -1007,8 +1007,8 @@ class DNSUtils: @staticmethod def ipToName(ip): # cache, also prevent long wait during retrieving of name for wrong addresses, lazy dns: - v = DNSUtils.CACHE_ipToName.get(ip) - if v is not None: + v = DNSUtils.CACHE_ipToName.get(ip, ()) + if v != (): return v # retrieve name try: diff --git a/fail2ban/tests/action_d/test_smtp.py b/fail2ban/tests/action_d/test_smtp.py index 27442832..37fe0138 100644 --- a/fail2ban/tests/action_d/test_smtp.py +++ b/fail2ban/tests/action_d/test_smtp.py @@ -19,7 +19,6 @@ import os import smtpd -import asyncore import threading import unittest import sys @@ -30,7 +29,7 @@ else: from ..dummyjail import DummyJail -from ..utils import CONFIG_DIR +from ..utils import CONFIG_DIR, asyncserver class TestSMTPServer(smtpd.SMTPServer): @@ -46,8 +45,6 @@ class SMTPActionTest(unittest.TestCase): def setUp(self): """Call before every test case.""" - unittest.F2B.SkipIfNoNetwork() - self.jail = DummyJail() pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py") pythonModuleName = os.path.basename(pythonModule.rstrip(".py")) @@ -64,13 +61,16 @@ class SMTPActionTest(unittest.TestCase): self.action = customActionModule.Action( self.jail, "test", host="127.0.0.1:%i" % port) + ## because of bug in loop (see loop in asyncserver.py) use it's loop instead of asyncore.loop: + self._active = True self._loop_thread = threading.Thread( - target=asyncore.loop, kwargs={'timeout': 1}) + target=asyncserver.loop, kwargs={'active': lambda: self._active}) self._loop_thread.start() def tearDown(self): """Call after every test case.""" self.smtpd.close() + self._active = False self._loop_thread.join() def testStart(self): diff --git a/fail2ban/tests/servertestcase.py b/fail2ban/tests/servertestcase.py index 1029bda5..133de80a 100644 --- a/fail2ban/tests/servertestcase.py +++ b/fail2ban/tests/servertestcase.py @@ -62,11 +62,14 @@ class TransmitterBase(unittest.TestCase): def setUp(self): """Call before every test case.""" self.transm = self.server._Server__transm + self.tmp_files = [] sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'transmitter') os.close(sock_fd) + self.tmp_files.append(sock_name) pidfile_fd, pidfile_name = tempfile.mkstemp( 'fail2ban.pid', 'transmitter') os.close(pidfile_fd) + self.tmp_files.append(pidfile_name) self.server.start(sock_name, pidfile_name, force=False) self.jailName = "TestJail1" self.server.addJail(self.jailName, "auto") @@ -74,6 +77,9 @@ class TransmitterBase(unittest.TestCase): def tearDown(self): """Call after every test case.""" self.server.quit() + for f in self.tmp_files: + if os.path.exists(f): + os.remove(f) def setGetTest(self, cmd, inValue, outValue=(None,), outCode=0, jail=None, repr_=False): """Process set/get commands and compare both return values diff --git a/fail2ban/tests/sockettestcase.py b/fail2ban/tests/sockettestcase.py index a9408fde..5bf0be57 100644 --- a/fail2ban/tests/sockettestcase.py +++ b/fail2ban/tests/sockettestcase.py @@ -55,6 +55,24 @@ class Socket(unittest.TestCase): """Test transmitter proceed method which just returns first arg""" return message + def testStopPerCloseUnexpected(self): + # start in separate thread : + serverThread = threading.Thread( + target=self.server.start, args=(self.sock_name, False)) + serverThread.daemon = True + serverThread.start() + self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) + # unexpected stop directly after start: + self.server.close() + # wait for end of thread : + Utils.wait_for(lambda: not serverThread.isAlive() + or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self.assertFalse(serverThread.isAlive()) + # clean : + self.server.stop() + self.assertFalse(self.server.isActive()) + self.assertFalse(os.path.exists(self.sock_name)) + def _serverSocket(self): try: return CSocket(self.sock_name) @@ -66,6 +84,7 @@ class Socket(unittest.TestCase): target=self.server.start, args=(self.sock_name, False)) serverThread.daemon = True serverThread.start() + self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) time.sleep(Utils.DEFAULT_SLEEP_TIME) client = Utils.wait_for(self._serverSocket, 2) @@ -78,7 +97,11 @@ class Socket(unittest.TestCase): client.close() self.server.stop() - serverThread.join(Utils.DEFAULT_SLEEP_TIME) + # wait for end of thread : + Utils.wait_for(lambda: not serverThread.isAlive() + or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self.assertFalse(serverThread.isAlive()) + self.assertFalse(self.server.isActive()) self.assertFalse(os.path.exists(self.sock_name)) def testSocketForce(self): @@ -92,10 +115,13 @@ class Socket(unittest.TestCase): target=self.server.start, args=(self.sock_name, True)) serverThread.daemon = True serverThread.start() - time.sleep(Utils.DEFAULT_SLEEP_TIME) + self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) self.server.stop() - serverThread.join(Utils.DEFAULT_SLEEP_TIME) + # wait for end of thread : + Utils.wait_for(lambda: not serverThread.isAlive() + or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self.assertFalse(self.server.isActive()) self.assertFalse(os.path.exists(self.sock_name)) diff --git a/fail2ban/tests/utils.py b/fail2ban/tests/utils.py index 685f51c5..bb9e5fa1 100644 --- a/fail2ban/tests/utils.py +++ b/fail2ban/tests/utils.py @@ -35,6 +35,8 @@ from ..helpers import getLogger from ..server.filter import DNSUtils from ..server.mytime import MyTime from ..server.utils import Utils +# for action_d.test_smtp : +from ..server import asyncserver logSys = getLogger(__name__) @@ -61,6 +63,10 @@ class F2B(optparse.Values): pass def SkipIfNoNetwork(self): pass + def maxWaitTime(self,wtime): + if self.fast: + wtime = float(wtime) / 10 + return wtime def initTests(opts): @@ -87,6 +93,13 @@ def initTests(opts): def F2B_SkipIfNoNetwork(): raise unittest.SkipTest('Skip test because of "--no-network"') unittest.F2B.SkipIfNoNetwork = F2B_SkipIfNoNetwork + # precache all invalid ip's (TEST-NET-1, ..., TEST-NET-3 according to RFC 5737): + c = DNSUtils.CACHE_ipToName + for i in xrange(255): + c.set('192.0.2.%s' % i, None) + c.set('198.51.100.%s' % i, None) + c.set('203.0.113.%s' % i, None) + def mtimesleep(): # no sleep now should be necessary since polling tracks now not only