Merge branch 'speedup-client-status' into 0.10

pull/2842/head^2^2
sebres 2020-09-23 13:03:45 +02:00
commit 4c2539856c
8 changed files with 52 additions and 40 deletions

View File

@ -48,7 +48,8 @@ class CSocket:
def send(self, msg, nonblocking=False, timeout=None): def send(self, msg, nonblocking=False, timeout=None):
# Convert every list member to string # Convert every list member to string
obj = dumps(map(CSocket.convert, msg), HIGHEST_PROTOCOL) obj = dumps(map(CSocket.convert, msg), HIGHEST_PROTOCOL)
self.__csock.send(obj + CSPROTO.END) self.__csock.send(obj)
self.__csock.send(CSPROTO.END)
return self.receive(self.__csock, nonblocking, timeout) return self.receive(self.__csock, nonblocking, timeout)
def settimeout(self, timeout): def settimeout(self, timeout):
@ -81,9 +82,12 @@ class CSocket:
msg = CSPROTO.EMPTY msg = CSPROTO.EMPTY
if nonblocking: sock.setblocking(0) if nonblocking: sock.setblocking(0)
if timeout: sock.settimeout(timeout) if timeout: sock.settimeout(timeout)
while msg.rfind(CSPROTO.END) == -1: bufsize = 1024
chunk = sock.recv(512) while msg.rfind(CSPROTO.END, -32) == -1:
if chunk in ('', b''): # python 3.x may return b'' instead of '' chunk = sock.recv(bufsize)
raise RuntimeError("socket connection broken") if not len(chunk):
raise socket.error(104, 'Connection reset by peer')
if chunk == CSPROTO.END: break
msg = msg + chunk msg = msg + chunk
if bufsize < 32768: bufsize <<= 1
return loads(msg) return loads(msg)

View File

@ -27,13 +27,17 @@ import sys
from ..version import version, normVersion from ..version import version, normVersion
from ..protocol import printFormatted from ..protocol import printFormatted
from ..helpers import getLogger, str2LogLevel, getVerbosityFormat from ..helpers import getLogger, str2LogLevel, getVerbosityFormat, BrokenPipeError
# Gets the instance of the logger. # Gets the instance of the logger.
logSys = getLogger("fail2ban") logSys = getLogger("fail2ban")
def output(s): # pragma: no cover def output(s): # pragma: no cover
try:
print(s) print(s)
except (BrokenPipeError, IOError) as e: # pragma: no cover
if e.errno != 32: # closed / broken pipe
raise
# Config parameters required to start fail2ban which can be also set via command line (overwrite fail2ban.conf), # Config parameters required to start fail2ban which can be also set via command line (overwrite fail2ban.conf),
CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket") CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket")
@ -310,12 +314,16 @@ class Fail2banCmdLine():
def _exit(code=0): def _exit(code=0):
# implicit flush without to produce broken pipe error (32): # implicit flush without to produce broken pipe error (32):
sys.stderr.close() sys.stderr.close()
sys.stdout.close() try:
sys.stdout.flush()
# exit: # exit:
if hasattr(os, '_exit') and os._exit: if hasattr(sys, 'exit') and sys.exit:
os._exit(code)
else:
sys.exit(code) sys.exit(code)
else:
os._exit(code)
except (BrokenPipeError, IOError) as e: # pragma: no cover
if e.errno != 32: # closed / broken pipe
raise
@staticmethod @staticmethod
def exit(code=0): def exit(code=0):

View File

@ -224,9 +224,10 @@ def __stopOnIOError(logSys=None, logHndlr=None): # pragma: no cover
sys.exit(0) sys.exit(0)
try: try:
BrokenPipeError BrokenPipeError = BrokenPipeError
except NameError: # pragma: 3.x no cover except NameError: # pragma: 3.x no cover
BrokenPipeError = IOError BrokenPipeError = IOError
__origLog = logging.Logger._log __origLog = logging.Logger._log
def __safeLog(self, level, msg, args, **kwargs): def __safeLog(self, level, msg, args, **kwargs):
"""Safe log inject to avoid possible errors by unsafe log-handlers, """Safe log inject to avoid possible errors by unsafe log-handlers,

View File

@ -661,13 +661,19 @@ class Actions(JailThread, Mapping):
"""Status of current and total ban counts and current banned IP list. """Status of current and total ban counts and current banned IP list.
""" """
# TODO: Allow this list to be printed as 'status' output # TODO: Allow this list to be printed as 'status' output
supported_flavors = ["basic", "cymru"] supported_flavors = ["short", "basic", "cymru"]
if flavor is None or flavor not in supported_flavors: if flavor is None or flavor not in supported_flavors:
logSys.warning("Unsupported extended jail status flavor %r. Supported: %s" % (flavor, supported_flavors)) logSys.warning("Unsupported extended jail status flavor %r. Supported: %s" % (flavor, supported_flavors))
# Always print this information (basic) # Always print this information (basic)
ret = [("Currently banned", self.__banManager.size()), if flavor != "short":
("Total banned", self.__banManager.getBanTotal()), banned = self.__banManager.getBanList()
("Banned IP list", self.__banManager.getBanList())] cnt = len(banned)
else:
cnt = self.__banManager.size()
ret = [("Currently banned", cnt),
("Total banned", self.__banManager.getBanTotal())]
if flavor != "short":
ret += [("Banned IP list", banned)]
if flavor == "cymru": if flavor == "cymru":
cymru_info = self.__banManager.getBanListExtendedCymruInfo() cymru_info = self.__banManager.getBanListExtendedCymruInfo()
ret += \ ret += \

View File

@ -66,7 +66,6 @@ class BanManager:
# @param value the time # @param value the time
def setBanTime(self, value): def setBanTime(self, value):
with self.__lock:
self.__banTime = int(value) self.__banTime = int(value)
## ##
@ -76,7 +75,6 @@ class BanManager:
# @return the time # @return the time
def getBanTime(self): def getBanTime(self):
with self.__lock:
return self.__banTime return self.__banTime
## ##
@ -85,7 +83,6 @@ class BanManager:
# @param value total number # @param value total number
def setBanTotal(self, value): def setBanTotal(self, value):
with self.__lock:
self.__banTotal = value self.__banTotal = value
## ##
@ -94,7 +91,6 @@ class BanManager:
# @return the total number # @return the total number
def getBanTotal(self): def getBanTotal(self):
with self.__lock:
return self.__banTotal return self.__banTotal
## ##
@ -103,7 +99,6 @@ class BanManager:
# @return IP list # @return IP list
def getBanList(self): def getBanList(self):
with self.__lock:
return list(self.__banList.keys()) return list(self.__banList.keys())
## ##
@ -112,8 +107,7 @@ class BanManager:
# @return ban list iterator # @return ban list iterator
def __iter__(self): def __iter__(self):
# ensure iterator is safe (traverse over the list in snapshot created within lock): # ensure iterator is safe - traverse over the list in snapshot created within lock (GIL):
with self.__lock:
return iter(list(self.__banList.values())) return iter(list(self.__banList.values()))
## ##

View File

@ -47,11 +47,9 @@ class FailManager:
self.__bgSvc = BgService() self.__bgSvc = BgService()
def setFailTotal(self, value): def setFailTotal(self, value):
with self.__lock:
self.__failTotal = value self.__failTotal = value
def getFailTotal(self): def getFailTotal(self):
with self.__lock:
return self.__failTotal return self.__failTotal
def getFailCount(self): def getFailCount(self):
@ -123,7 +121,6 @@ class FailManager:
return attempts return attempts
def size(self): def size(self):
with self.__lock:
return len(self.__failList) return len(self.__failList)
def cleanup(self, time): def cleanup(self, time):

View File

@ -96,6 +96,8 @@ class ExecuteActions(LogCaptureTestCase):
self.assertLogged("stdout: %r" % 'ip flush', "stdout: %r" % 'ip stop') self.assertLogged("stdout: %r" % 'ip flush', "stdout: %r" % 'ip stop')
self.assertEqual(self.__actions.status(),[("Currently banned", 0 ), self.assertEqual(self.__actions.status(),[("Currently banned", 0 ),
("Total banned", 0 ), ("Banned IP list", [] )]) ("Total banned", 0 ), ("Banned IP list", [] )])
self.assertEqual(self.__actions.status('short'),[("Currently banned", 0 ),
("Total banned", 0 )])
def testAddActionPython(self): def testAddActionPython(self):
self.__actions.add( self.__actions.add(

View File

@ -152,7 +152,7 @@ class Socket(LogCaptureTestCase):
org_handler = RequestHandler.found_terminator org_handler = RequestHandler.found_terminator
try: try:
RequestHandler.found_terminator = lambda self: self.close() RequestHandler.found_terminator = lambda self: self.close()
self.assertRaisesRegexp(RuntimeError, r"socket connection broken", self.assertRaisesRegexp(Exception, r"reset by peer|Broken pipe",
lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10))) lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10)))
finally: finally:
RequestHandler.found_terminator = org_handler RequestHandler.found_terminator = org_handler
@ -168,7 +168,7 @@ class Socket(LogCaptureTestCase):
org_handler = RequestHandler.found_terminator org_handler = RequestHandler.found_terminator
try: try:
RequestHandler.found_terminator = lambda self: TestMsgError() RequestHandler.found_terminator = lambda self: TestMsgError()
#self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage) #self.assertRaisesRegexp(Exception, r"reset by peer|Broken pipe", client.send, testMessage)
self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error') self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error')
finally: finally:
RequestHandler.found_terminator = org_handler RequestHandler.found_terminator = org_handler