mirror of https://github.com/fail2ban/fail2ban
better (sane) stop server handling, AsyncServer.stop_communication back-ported to 0.10 (cherry-picked from 0.11);
parent
aa9cefc3f8
commit
bf6667d4da
|
@ -37,7 +37,7 @@ import traceback
|
||||||
|
|
||||||
from .utils import Utils
|
from .utils import Utils
|
||||||
from ..protocol import CSPROTO
|
from ..protocol import CSPROTO
|
||||||
from ..helpers import getLogger,formatExceptionInfo
|
from ..helpers import logging, getLogger, formatExceptionInfo
|
||||||
|
|
||||||
# Gets the instance of the logger.
|
# Gets the instance of the logger.
|
||||||
logSys = getLogger(__name__)
|
logSys = getLogger(__name__)
|
||||||
|
@ -80,22 +80,36 @@ class RequestHandler(asynchat.async_chat):
|
||||||
# Deserialize
|
# Deserialize
|
||||||
message = loads(message)
|
message = loads(message)
|
||||||
# Gives the message to the transmitter.
|
# Gives the message to the transmitter.
|
||||||
message = self.__transmitter.proceed(message)
|
if self.__transmitter:
|
||||||
|
message = self.__transmitter.proceed(message)
|
||||||
|
else:
|
||||||
|
message = ['SHUTDOWN']
|
||||||
# Serializes the response.
|
# Serializes the response.
|
||||||
message = dumps(message, HIGHEST_PROTOCOL)
|
message = dumps(message, HIGHEST_PROTOCOL)
|
||||||
# Sends the response to the client.
|
# Sends the response to the client.
|
||||||
self.push(message + CSPROTO.END)
|
self.push(message + CSPROTO.END)
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e:
|
||||||
logSys.error("Caught unhandled exception: %r", e,
|
logSys.error("Caught unhandled exception: %r", e,
|
||||||
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
|
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
|
||||||
|
# Sends the response to the client.
|
||||||
|
message = dumps("ERROR: %s" % e, HIGHEST_PROTOCOL)
|
||||||
|
self.push(message + CSPROTO.END)
|
||||||
|
|
||||||
|
##
|
||||||
|
# Handles an communication errors in request.
|
||||||
|
#
|
||||||
def handle_error(self):
|
def handle_error(self):
|
||||||
e1, e2 = formatExceptionInfo()
|
try:
|
||||||
logSys.error("Unexpected communication error: %s" % str(e2))
|
e1, e2 = formatExceptionInfo()
|
||||||
logSys.error(traceback.format_exc().splitlines())
|
logSys.error("Unexpected communication error: %s" % str(e2))
|
||||||
self.close()
|
logSys.error(traceback.format_exc().splitlines())
|
||||||
|
# Sends the response to the client.
|
||||||
|
message = dumps("ERROR: %s" % e2, HIGHEST_PROTOCOL)
|
||||||
|
self.push(message + CSPROTO.END)
|
||||||
|
except Exception as e: # pragma: no cover - normally unreachable
|
||||||
|
pass
|
||||||
|
self.close_when_done()
|
||||||
|
|
||||||
|
|
||||||
def loop(active, timeout=None, use_poll=False):
|
def loop(active, timeout=None, use_poll=False):
|
||||||
"""Custom event loop implementation
|
"""Custom event loop implementation
|
||||||
|
@ -119,18 +133,20 @@ def loop(active, timeout=None, use_poll=False):
|
||||||
poll(timeout)
|
poll(timeout)
|
||||||
if errCount:
|
if errCount:
|
||||||
errCount -= 1
|
errCount -= 1
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e:
|
||||||
if not active():
|
if not active():
|
||||||
break
|
break
|
||||||
errCount += 1
|
errCount += 1
|
||||||
if errCount < 20:
|
if errCount < 20:
|
||||||
if e.args[0] in (errno.ENOTCONN, errno.EBADF): # (errno.EBADF, 'Bad file descriptor')
|
# errno.ENOTCONN - 'Socket is not connected'
|
||||||
|
# errno.EBADF - 'Bad file descriptor'
|
||||||
|
if e.args[0] in (errno.ENOTCONN, errno.EBADF): # pragma: no cover (too sporadic)
|
||||||
logSys.info('Server connection was closed: %s', str(e))
|
logSys.info('Server connection was closed: %s', str(e))
|
||||||
else:
|
else:
|
||||||
logSys.error('Server connection was closed: %s', str(e))
|
logSys.error('Server connection was closed: %s', str(e))
|
||||||
elif errCount == 20:
|
elif errCount == 20:
|
||||||
logSys.info('Too many errors - stop logging connection errors')
|
|
||||||
logSys.exception(e)
|
logSys.exception(e)
|
||||||
|
logSys.error('Too many errors - stop logging connection errors')
|
||||||
|
|
||||||
|
|
||||||
##
|
##
|
||||||
|
@ -158,10 +174,10 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
def handle_accept(self):
|
def handle_accept(self):
|
||||||
try:
|
try:
|
||||||
conn, addr = self.accept()
|
conn, addr = self.accept()
|
||||||
except socket.error:
|
except socket.error: # pragma: no cover
|
||||||
logSys.warning("Socket error")
|
logSys.warning("Socket error")
|
||||||
return
|
return
|
||||||
except TypeError:
|
except TypeError: # pragma: no cover
|
||||||
logSys.warning("Type error")
|
logSys.warning("Type error")
|
||||||
return
|
return
|
||||||
AsyncServer.__markCloseOnExec(conn)
|
AsyncServer.__markCloseOnExec(conn)
|
||||||
|
@ -175,7 +191,7 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
# @param sock: socket file.
|
# @param sock: socket file.
|
||||||
# @param force: remove the socket file if exists.
|
# @param force: remove the socket file if exists.
|
||||||
|
|
||||||
def start(self, sock, force, use_poll=False):
|
def start(self, sock, force, timeout=None, use_poll=False):
|
||||||
self.__worker = threading.current_thread()
|
self.__worker = threading.current_thread()
|
||||||
self.__sock = sock
|
self.__sock = sock
|
||||||
# Remove socket
|
# Remove socket
|
||||||
|
@ -191,7 +207,7 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
self.set_reuse_addr()
|
self.set_reuse_addr()
|
||||||
try:
|
try:
|
||||||
self.bind(sock)
|
self.bind(sock)
|
||||||
except Exception:
|
except Exception: # pragma: no cover
|
||||||
raise AsyncServerException("Unable to bind socket %s" % self.__sock)
|
raise AsyncServerException("Unable to bind socket %s" % self.__sock)
|
||||||
AsyncServer.__markCloseOnExec(self.socket)
|
AsyncServer.__markCloseOnExec(self.socket)
|
||||||
self.listen(1)
|
self.listen(1)
|
||||||
|
@ -201,12 +217,11 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
if self.onstart:
|
if self.onstart:
|
||||||
self.onstart()
|
self.onstart()
|
||||||
# Event loop as long as active:
|
# Event loop as long as active:
|
||||||
loop(lambda: self.__loop, use_poll=use_poll)
|
loop(lambda: self.__loop, timeout=timeout, use_poll=use_poll)
|
||||||
self.__active = False
|
self.__active = False
|
||||||
# Cleanup all
|
# Cleanup all
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
stopflg = False
|
stopflg = False
|
||||||
if self.__active:
|
if self.__active:
|
||||||
|
@ -228,6 +243,13 @@ class AsyncServer(asyncore.dispatcher):
|
||||||
##
|
##
|
||||||
# Stops the communication server.
|
# Stops the communication server.
|
||||||
|
|
||||||
|
def stop_communication(self):
|
||||||
|
logSys.debug("Stop communication")
|
||||||
|
self.__transmitter = None
|
||||||
|
|
||||||
|
##
|
||||||
|
# Stops the server.
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
|
|
@ -82,7 +82,7 @@ class Server:
|
||||||
|
|
||||||
def __sigTERMhandler(self, signum, frame): # pragma: no cover - indirect tested
|
def __sigTERMhandler(self, signum, frame): # pragma: no cover - indirect tested
|
||||||
logSys.debug("Caught signal %d. Exiting", signum)
|
logSys.debug("Caught signal %d. Exiting", signum)
|
||||||
self.quit_signal()
|
self.quit()
|
||||||
|
|
||||||
def __sigUSR1handler(self, signum, fname): # pragma: no cover - indirect tested
|
def __sigUSR1handler(self, signum, fname): # pragma: no cover - indirect tested
|
||||||
logSys.debug("Caught signal %d. Flushing logs", signum)
|
logSys.debug("Caught signal %d. Flushing logs", signum)
|
||||||
|
@ -152,7 +152,7 @@ class Server:
|
||||||
except AsyncServerException as e:
|
except AsyncServerException as e:
|
||||||
logSys.error("Could not start server: %s", e)
|
logSys.error("Could not start server: %s", e)
|
||||||
|
|
||||||
# Stop server
|
# Stop (if not yet already executed):
|
||||||
self.quit()
|
self.quit()
|
||||||
|
|
||||||
# Removes the PID file.
|
# Removes the PID file.
|
||||||
|
@ -161,13 +161,22 @@ class Server:
|
||||||
os.remove(pidfile)
|
os.remove(pidfile)
|
||||||
except (OSError, IOError) as e: # pragma: no cover
|
except (OSError, IOError) as e: # pragma: no cover
|
||||||
logSys.error("Unable to remove PID file: %s", e)
|
logSys.error("Unable to remove PID file: %s", e)
|
||||||
logSys.info("Exiting Fail2ban")
|
|
||||||
|
|
||||||
def quit(self):
|
def quit(self):
|
||||||
self.quit_signal()
|
# Prevent to call quit twice:
|
||||||
|
self.quit = lambda: False
|
||||||
|
|
||||||
logSys.info("Shutdown in progress...")
|
logSys.info("Shutdown in progress...")
|
||||||
|
|
||||||
|
# Stop communication first because if jail's unban action
|
||||||
|
# tries to communicate via fail2ban-client we get a lockup
|
||||||
|
# among threads. So the simplest resolution is to stop all
|
||||||
|
# communications first (which should be ok anyways since we
|
||||||
|
# are exiting)
|
||||||
|
# See https://github.com/fail2ban/fail2ban/issues/7
|
||||||
|
if self.__asyncServer is not None:
|
||||||
|
self.__asyncServer.stop_communication()
|
||||||
|
|
||||||
# Restore default signal handlers:
|
# Restore default signal handlers:
|
||||||
if _thread_name() == '_MainThread':
|
if _thread_name() == '_MainThread':
|
||||||
for s, sh in self.__prev_signals.iteritems():
|
for s, sh in self.__prev_signals.iteritems():
|
||||||
|
@ -182,18 +191,11 @@ class Server:
|
||||||
self.__db.close()
|
self.__db.close()
|
||||||
self.__db = None
|
self.__db = None
|
||||||
|
|
||||||
def quit_signal(self):
|
# Stop async
|
||||||
# Prevent to call quit_signal twice:
|
|
||||||
self.quit_signal = lambda: False
|
|
||||||
# Stop communication first because if jail's unban action
|
|
||||||
# tries to communicate via fail2ban-client we get a lockup
|
|
||||||
# among threads. So the simplest resolution is to stop all
|
|
||||||
# communications first (which should be ok anyways since we
|
|
||||||
# are exiting)
|
|
||||||
# See https://github.com/fail2ban/fail2ban/issues/7
|
|
||||||
if self.__asyncServer is not None:
|
if self.__asyncServer is not None:
|
||||||
self.__asyncServer.stop()
|
self.__asyncServer.stop()
|
||||||
self.__asyncServer = None
|
self.__asyncServer = None
|
||||||
|
logSys.info("Exiting Fail2ban")
|
||||||
|
|
||||||
def addJail(self, name, backend):
|
def addJail(self, name, backend):
|
||||||
addflg = True
|
addflg = True
|
||||||
|
@ -610,7 +612,7 @@ class Server:
|
||||||
try:
|
try:
|
||||||
handler.flush()
|
handler.flush()
|
||||||
handler.close()
|
handler.close()
|
||||||
except (ValueError, KeyError): # pragma: no cover
|
except (ValueError, KeyError): # pragma: no cover
|
||||||
# Is known to be thrown after logging was shutdown once
|
# Is known to be thrown after logging was shutdown once
|
||||||
# with older Pythons -- seems to be safe to ignore there
|
# with older Pythons -- seems to be safe to ignore there
|
||||||
# At least it was still failing on 2.6.2-0ubuntu1 (jaunty)
|
# At least it was still failing on 2.6.2-0ubuntu1 (jaunty)
|
||||||
|
|
Loading…
Reference in New Issue