mirror of https://github.com/fail2ban/fail2ban
asyncserver (asyncore) code fixed and test cases repaired (always delete temp files, wait for end of thread/server, etc)
definitely closes gh-161, also other usage of asyncore event loop (in test_smtp.py) repair cache in ipToName (can returns None), precaching of invalid IPs (according to RFC 5737) to stop endless wait for resolving it in test cases.pull/1346/head
parent
770c219ab6
commit
72f29e9061
|
@ -27,12 +27,14 @@ __license__ = "GPL"
|
||||||
from pickle import dumps, loads, HIGHEST_PROTOCOL
|
from pickle import dumps, loads, HIGHEST_PROTOCOL
|
||||||
import asynchat
|
import asynchat
|
||||||
import asyncore
|
import asyncore
|
||||||
|
import errno
|
||||||
import fcntl
|
import fcntl
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
from .utils import Utils
|
||||||
from ..protocol import CSPROTO
|
from ..protocol import CSPROTO
|
||||||
from ..helpers import getLogger,formatExceptionInfo
|
from ..helpers import getLogger,formatExceptionInfo
|
||||||
|
|
||||||
|
@ -89,6 +91,29 @@ class RequestHandler(asynchat.async_chat):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
|
def loop(active, timeout=None, use_poll=False):
|
||||||
|
# Use poll instead of loop, because of recognition of active flag,
|
||||||
|
# because of loop timeout mistake: different in poll and poll2 (sec vs ms),
|
||||||
|
# and to prevent sporadical errors like EBADF 'Bad file descriptor' etc. (see gh-161)
|
||||||
|
if timeout is None:
|
||||||
|
timeout = Utils.DEFAULT_SLEEP_TIME
|
||||||
|
poll = asyncore.poll
|
||||||
|
if use_poll and asyncore.poll2 and hasattr(asyncore.select, 'poll'): # pragma: no cover
|
||||||
|
logSys.debug('Server listener (select) uses poll')
|
||||||
|
# poll2 expected a timeout in milliseconds (but poll and loop in seconds):
|
||||||
|
timeout = float(timeout) / 1000
|
||||||
|
poll = asyncore.poll2
|
||||||
|
# Poll as long as active:
|
||||||
|
while active():
|
||||||
|
try:
|
||||||
|
poll(timeout)
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
if e.args[0] in (errno.ENOTCONN, errno.EBADF): # (errno.EBADF, 'Bad file descriptor')
|
||||||
|
logSys.info('Server connection was closed: %s', str(e))
|
||||||
|
else:
|
||||||
|
logSys.error('Server connection was closed: %s', str(e))
|
||||||
|
|
||||||
|
|
||||||
##
|
##
|
||||||
# Asynchronous server class.
|
# Asynchronous server class.
|
||||||
#
|
#
|
||||||
|
@ -102,6 +127,7 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
self.__transmitter = transmitter
|
self.__transmitter = transmitter
|
||||||
self.__sock = "/var/run/fail2ban/fail2ban.sock"
|
self.__sock = "/var/run/fail2ban/fail2ban.sock"
|
||||||
self.__init = False
|
self.__init = False
|
||||||
|
self.__active = False
|
||||||
|
|
||||||
##
|
##
|
||||||
# Returns False as we only read the socket first.
|
# Returns False as we only read the socket first.
|
||||||
|
@ -129,7 +155,7 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
# @param sock: socket file.
|
# @param sock: socket file.
|
||||||
# @param force: remove the socket file if exists.
|
# @param force: remove the socket file if exists.
|
||||||
|
|
||||||
def start(self, sock, force):
|
def start(self, sock, force, use_poll=False):
|
||||||
self.__sock = sock
|
self.__sock = sock
|
||||||
# Remove socket
|
# Remove socket
|
||||||
if os.path.exists(sock):
|
if os.path.exists(sock):
|
||||||
|
@ -149,28 +175,31 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
AsyncServer.__markCloseOnExec(self.socket)
|
AsyncServer.__markCloseOnExec(self.socket)
|
||||||
self.listen(1)
|
self.listen(1)
|
||||||
# Sets the init flag.
|
# Sets the init flag.
|
||||||
self.__init = True
|
self.__init = self.__active = True
|
||||||
# TODO Add try..catch
|
# Event loop as long as active:
|
||||||
# There's a bug report for Python 2.6/3.0 that use_poll=True yields some 2.5 incompatibilities:
|
loop(lambda: self.__active)
|
||||||
if (sys.version_info >= (2, 7) and sys.version_info < (2, 8)) \
|
# Cleanup all
|
||||||
or (sys.version_info >= (3, 4)): # if python 2.7 ...
|
self.stop()
|
||||||
logSys.debug("Detected Python 2.7. asyncore.loop() using poll")
|
|
||||||
asyncore.loop(use_poll=True) # workaround for the "Bad file descriptor" issue on Python 2.7, gh-161
|
|
||||||
else:
|
def close(self):
|
||||||
asyncore.loop(use_poll=False) # fixes the "Unexpected communication problem" issue on Python 2.6 and 3.0
|
if self.__active:
|
||||||
|
asyncore.dispatcher.close(self)
|
||||||
|
# Remove socket (file) only if it was created:
|
||||||
|
if self.__init and os.path.exists(self.__sock):
|
||||||
|
logSys.debug("Removed socket file " + self.__sock)
|
||||||
|
os.remove(self.__sock)
|
||||||
|
logSys.debug("Socket shutdown")
|
||||||
|
self.__active = False
|
||||||
|
|
||||||
##
|
##
|
||||||
# Stops the communication server.
|
# Stops the communication server.
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if self.__init:
|
|
||||||
# Only closes the socket if it was initialized first.
|
|
||||||
self.close()
|
self.close()
|
||||||
# Remove socket
|
|
||||||
if os.path.exists(self.__sock):
|
def isActive(self):
|
||||||
logSys.debug("Removed socket file " + self.__sock)
|
return self.__active
|
||||||
os.remove(self.__sock)
|
|
||||||
logSys.debug("Socket shutdown")
|
|
||||||
|
|
||||||
##
|
##
|
||||||
# Marks socket as close-on-exec to avoid leaking file descriptors when
|
# Marks socket as close-on-exec to avoid leaking file descriptors when
|
||||||
|
|
|
@ -1023,8 +1023,8 @@ class DNSUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ipToName(ip):
|
def ipToName(ip):
|
||||||
# cache, also prevent long wait during retrieving of name for wrong addresses, lazy dns:
|
# cache, also prevent long wait during retrieving of name for wrong addresses, lazy dns:
|
||||||
v = DNSUtils.CACHE_ipToName.get(ip)
|
v = DNSUtils.CACHE_ipToName.get(ip, ())
|
||||||
if v is not None:
|
if v != ():
|
||||||
return v
|
return v
|
||||||
# retrieve name
|
# retrieve name
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import smtpd
|
import smtpd
|
||||||
import asyncore
|
|
||||||
import threading
|
import threading
|
||||||
import unittest
|
import unittest
|
||||||
import sys
|
import sys
|
||||||
|
@ -30,7 +29,7 @@ else:
|
||||||
|
|
||||||
from ..dummyjail import DummyJail
|
from ..dummyjail import DummyJail
|
||||||
|
|
||||||
from ..utils import CONFIG_DIR
|
from ..utils import CONFIG_DIR, asyncserver
|
||||||
|
|
||||||
|
|
||||||
class TestSMTPServer(smtpd.SMTPServer):
|
class TestSMTPServer(smtpd.SMTPServer):
|
||||||
|
@ -46,8 +45,6 @@ class SMTPActionTest(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Call before every test case."""
|
"""Call before every test case."""
|
||||||
unittest.F2B.SkipIfNoNetwork()
|
|
||||||
|
|
||||||
self.jail = DummyJail()
|
self.jail = DummyJail()
|
||||||
pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py")
|
pythonModule = os.path.join(CONFIG_DIR, "action.d", "smtp.py")
|
||||||
pythonModuleName = os.path.basename(pythonModule.rstrip(".py"))
|
pythonModuleName = os.path.basename(pythonModule.rstrip(".py"))
|
||||||
|
@ -64,13 +61,16 @@ class SMTPActionTest(unittest.TestCase):
|
||||||
self.action = customActionModule.Action(
|
self.action = customActionModule.Action(
|
||||||
self.jail, "test", host="127.0.0.1:%i" % port)
|
self.jail, "test", host="127.0.0.1:%i" % port)
|
||||||
|
|
||||||
|
## because of bug in loop (see loop in asyncserver.py) use it's loop instead of asyncore.loop:
|
||||||
|
self._active = True
|
||||||
self._loop_thread = threading.Thread(
|
self._loop_thread = threading.Thread(
|
||||||
target=asyncore.loop, kwargs={'timeout': 1})
|
target=asyncserver.loop, kwargs={'active': lambda: self._active})
|
||||||
self._loop_thread.start()
|
self._loop_thread.start()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Call after every test case."""
|
"""Call after every test case."""
|
||||||
self.smtpd.close()
|
self.smtpd.close()
|
||||||
|
self._active = False
|
||||||
self._loop_thread.join()
|
self._loop_thread.join()
|
||||||
|
|
||||||
def testStart(self):
|
def testStart(self):
|
||||||
|
|
|
@ -62,11 +62,14 @@ class TransmitterBase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Call before every test case."""
|
"""Call before every test case."""
|
||||||
self.transm = self.server._Server__transm
|
self.transm = self.server._Server__transm
|
||||||
|
self.tmp_files = []
|
||||||
sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'transmitter')
|
sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'transmitter')
|
||||||
os.close(sock_fd)
|
os.close(sock_fd)
|
||||||
|
self.tmp_files.append(sock_name)
|
||||||
pidfile_fd, pidfile_name = tempfile.mkstemp(
|
pidfile_fd, pidfile_name = tempfile.mkstemp(
|
||||||
'fail2ban.pid', 'transmitter')
|
'fail2ban.pid', 'transmitter')
|
||||||
os.close(pidfile_fd)
|
os.close(pidfile_fd)
|
||||||
|
self.tmp_files.append(pidfile_name)
|
||||||
self.server.start(sock_name, pidfile_name, force=False)
|
self.server.start(sock_name, pidfile_name, force=False)
|
||||||
self.jailName = "TestJail1"
|
self.jailName = "TestJail1"
|
||||||
self.server.addJail(self.jailName, "auto")
|
self.server.addJail(self.jailName, "auto")
|
||||||
|
@ -74,6 +77,9 @@ class TransmitterBase(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Call after every test case."""
|
"""Call after every test case."""
|
||||||
self.server.quit()
|
self.server.quit()
|
||||||
|
for f in self.tmp_files:
|
||||||
|
if os.path.exists(f):
|
||||||
|
os.remove(f)
|
||||||
|
|
||||||
def setGetTest(self, cmd, inValue, outValue=(None,), outCode=0, jail=None, repr_=False):
|
def setGetTest(self, cmd, inValue, outValue=(None,), outCode=0, jail=None, repr_=False):
|
||||||
"""Process set/get commands and compare both return values
|
"""Process set/get commands and compare both return values
|
||||||
|
|
|
@ -55,6 +55,24 @@ class Socket(unittest.TestCase):
|
||||||
"""Test transmitter proceed method which just returns first arg"""
|
"""Test transmitter proceed method which just returns first arg"""
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
def testStopPerCloseUnexpected(self):
|
||||||
|
# start in separate thread :
|
||||||
|
serverThread = threading.Thread(
|
||||||
|
target=self.server.start, args=(self.sock_name, False))
|
||||||
|
serverThread.daemon = True
|
||||||
|
serverThread.start()
|
||||||
|
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||||
|
# unexpected stop directly after start:
|
||||||
|
self.server.close()
|
||||||
|
# wait for end of thread :
|
||||||
|
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||||
|
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
||||||
|
self.assertFalse(serverThread.isAlive())
|
||||||
|
# clean :
|
||||||
|
self.server.stop()
|
||||||
|
self.assertFalse(self.server.isActive())
|
||||||
|
self.assertFalse(os.path.exists(self.sock_name))
|
||||||
|
|
||||||
def _serverSocket(self):
|
def _serverSocket(self):
|
||||||
try:
|
try:
|
||||||
return CSocket(self.sock_name)
|
return CSocket(self.sock_name)
|
||||||
|
@ -66,6 +84,7 @@ class Socket(unittest.TestCase):
|
||||||
target=self.server.start, args=(self.sock_name, False))
|
target=self.server.start, args=(self.sock_name, False))
|
||||||
serverThread.daemon = True
|
serverThread.daemon = True
|
||||||
serverThread.start()
|
serverThread.start()
|
||||||
|
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||||
time.sleep(Utils.DEFAULT_SLEEP_TIME)
|
time.sleep(Utils.DEFAULT_SLEEP_TIME)
|
||||||
|
|
||||||
client = Utils.wait_for(self._serverSocket, 2)
|
client = Utils.wait_for(self._serverSocket, 2)
|
||||||
|
@ -78,7 +97,11 @@ class Socket(unittest.TestCase):
|
||||||
client.close()
|
client.close()
|
||||||
|
|
||||||
self.server.stop()
|
self.server.stop()
|
||||||
serverThread.join(Utils.DEFAULT_SLEEP_TIME)
|
# wait for end of thread :
|
||||||
|
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||||
|
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
||||||
|
self.assertFalse(serverThread.isAlive())
|
||||||
|
self.assertFalse(self.server.isActive())
|
||||||
self.assertFalse(os.path.exists(self.sock_name))
|
self.assertFalse(os.path.exists(self.sock_name))
|
||||||
|
|
||||||
def testSocketForce(self):
|
def testSocketForce(self):
|
||||||
|
@ -92,10 +115,13 @@ class Socket(unittest.TestCase):
|
||||||
target=self.server.start, args=(self.sock_name, True))
|
target=self.server.start, args=(self.sock_name, True))
|
||||||
serverThread.daemon = True
|
serverThread.daemon = True
|
||||||
serverThread.start()
|
serverThread.start()
|
||||||
time.sleep(Utils.DEFAULT_SLEEP_TIME)
|
self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10)))
|
||||||
|
|
||||||
self.server.stop()
|
self.server.stop()
|
||||||
serverThread.join(Utils.DEFAULT_SLEEP_TIME)
|
# wait for end of thread :
|
||||||
|
Utils.wait_for(lambda: not serverThread.isAlive()
|
||||||
|
or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10))
|
||||||
|
self.assertFalse(self.server.isActive())
|
||||||
self.assertFalse(os.path.exists(self.sock_name))
|
self.assertFalse(os.path.exists(self.sock_name))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,8 @@ from ..helpers import getLogger
|
||||||
from ..server.filter import DNSUtils
|
from ..server.filter import DNSUtils
|
||||||
from ..server.mytime import MyTime
|
from ..server.mytime import MyTime
|
||||||
from ..server.utils import Utils
|
from ..server.utils import Utils
|
||||||
|
# for action_d.test_smtp :
|
||||||
|
from ..server import asyncserver
|
||||||
|
|
||||||
|
|
||||||
logSys = getLogger(__name__)
|
logSys = getLogger(__name__)
|
||||||
|
@ -61,6 +63,10 @@ class F2B(optparse.Values):
|
||||||
pass
|
pass
|
||||||
def SkipIfNoNetwork(self):
|
def SkipIfNoNetwork(self):
|
||||||
pass
|
pass
|
||||||
|
def maxWaitTime(self,wtime):
|
||||||
|
if self.fast:
|
||||||
|
wtime = float(wtime) / 10
|
||||||
|
return wtime
|
||||||
|
|
||||||
|
|
||||||
def initTests(opts):
|
def initTests(opts):
|
||||||
|
@ -87,6 +93,13 @@ def initTests(opts):
|
||||||
def F2B_SkipIfNoNetwork():
|
def F2B_SkipIfNoNetwork():
|
||||||
raise unittest.SkipTest('Skip test because of "--no-network"')
|
raise unittest.SkipTest('Skip test because of "--no-network"')
|
||||||
unittest.F2B.SkipIfNoNetwork = F2B_SkipIfNoNetwork
|
unittest.F2B.SkipIfNoNetwork = F2B_SkipIfNoNetwork
|
||||||
|
# precache all invalid ip's (TEST-NET-1, ..., TEST-NET-3 according to RFC 5737):
|
||||||
|
c = DNSUtils.CACHE_ipToName
|
||||||
|
for i in xrange(255):
|
||||||
|
c.set('192.0.2.%s' % i, None)
|
||||||
|
c.set('198.51.100.%s' % i, None)
|
||||||
|
c.set('203.0.113.%s' % i, None)
|
||||||
|
|
||||||
|
|
||||||
def mtimesleep():
|
def mtimesleep():
|
||||||
# no sleep now should be necessary since polling tracks now not only
|
# no sleep now should be necessary since polling tracks now not only
|
||||||
|
|
Loading…
Reference in New Issue