Merge branch 'speedup-client-status' into 0.10

pull/2842/head^2^2
sebres 2020-09-23 13:03:45 +02:00
commit 4c2539856c
8 changed files with 52 additions and 40 deletions

View File

@ -48,7 +48,8 @@ class CSocket:
def send(self, msg, nonblocking=False, timeout=None): def send(self, msg, nonblocking=False, timeout=None):
# Convert every list member to string # Convert every list member to string
obj = dumps(map(CSocket.convert, msg), HIGHEST_PROTOCOL) obj = dumps(map(CSocket.convert, msg), HIGHEST_PROTOCOL)
self.__csock.send(obj + CSPROTO.END) self.__csock.send(obj)
self.__csock.send(CSPROTO.END)
return self.receive(self.__csock, nonblocking, timeout) return self.receive(self.__csock, nonblocking, timeout)
def settimeout(self, timeout): def settimeout(self, timeout):
@ -81,9 +82,12 @@ class CSocket:
msg = CSPROTO.EMPTY msg = CSPROTO.EMPTY
if nonblocking: sock.setblocking(0) if nonblocking: sock.setblocking(0)
if timeout: sock.settimeout(timeout) if timeout: sock.settimeout(timeout)
while msg.rfind(CSPROTO.END) == -1: bufsize = 1024
chunk = sock.recv(512) while msg.rfind(CSPROTO.END, -32) == -1:
if chunk in ('', b''): # python 3.x may return b'' instead of '' chunk = sock.recv(bufsize)
raise RuntimeError("socket connection broken") if not len(chunk):
raise socket.error(104, 'Connection reset by peer')
if chunk == CSPROTO.END: break
msg = msg + chunk msg = msg + chunk
if bufsize < 32768: bufsize <<= 1
return loads(msg) return loads(msg)

View File

@ -27,13 +27,17 @@ import sys
from ..version import version, normVersion from ..version import version, normVersion
from ..protocol import printFormatted from ..protocol import printFormatted
from ..helpers import getLogger, str2LogLevel, getVerbosityFormat from ..helpers import getLogger, str2LogLevel, getVerbosityFormat, BrokenPipeError
# Gets the instance of the logger. # Gets the instance of the logger.
logSys = getLogger("fail2ban") logSys = getLogger("fail2ban")
def output(s): # pragma: no cover def output(s): # pragma: no cover
print(s) try:
print(s)
except (BrokenPipeError, IOError) as e: # pragma: no cover
if e.errno != 32: # closed / broken pipe
raise
# Config parameters required to start fail2ban which can be also set via command line (overwrite fail2ban.conf), # Config parameters required to start fail2ban which can be also set via command line (overwrite fail2ban.conf),
CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket") CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket")
@ -310,12 +314,16 @@ class Fail2banCmdLine():
def _exit(code=0): def _exit(code=0):
# implicit flush without to produce broken pipe error (32): # implicit flush without to produce broken pipe error (32):
sys.stderr.close() sys.stderr.close()
sys.stdout.close() try:
# exit: sys.stdout.flush()
if hasattr(os, '_exit') and os._exit: # exit:
os._exit(code) if hasattr(sys, 'exit') and sys.exit:
else: sys.exit(code)
sys.exit(code) else:
os._exit(code)
except (BrokenPipeError, IOError) as e: # pragma: no cover
if e.errno != 32: # closed / broken pipe
raise
@staticmethod @staticmethod
def exit(code=0): def exit(code=0):

View File

@ -224,9 +224,10 @@ def __stopOnIOError(logSys=None, logHndlr=None): # pragma: no cover
sys.exit(0) sys.exit(0)
try: try:
BrokenPipeError BrokenPipeError = BrokenPipeError
except NameError: # pragma: 3.x no cover except NameError: # pragma: 3.x no cover
BrokenPipeError = IOError BrokenPipeError = IOError
__origLog = logging.Logger._log __origLog = logging.Logger._log
def __safeLog(self, level, msg, args, **kwargs): def __safeLog(self, level, msg, args, **kwargs):
"""Safe log inject to avoid possible errors by unsafe log-handlers, """Safe log inject to avoid possible errors by unsafe log-handlers,

View File

@ -661,13 +661,19 @@ class Actions(JailThread, Mapping):
"""Status of current and total ban counts and current banned IP list. """Status of current and total ban counts and current banned IP list.
""" """
# TODO: Allow this list to be printed as 'status' output # TODO: Allow this list to be printed as 'status' output
supported_flavors = ["basic", "cymru"] supported_flavors = ["short", "basic", "cymru"]
if flavor is None or flavor not in supported_flavors: if flavor is None or flavor not in supported_flavors:
logSys.warning("Unsupported extended jail status flavor %r. Supported: %s" % (flavor, supported_flavors)) logSys.warning("Unsupported extended jail status flavor %r. Supported: %s" % (flavor, supported_flavors))
# Always print this information (basic) # Always print this information (basic)
ret = [("Currently banned", self.__banManager.size()), if flavor != "short":
("Total banned", self.__banManager.getBanTotal()), banned = self.__banManager.getBanList()
("Banned IP list", self.__banManager.getBanList())] cnt = len(banned)
else:
cnt = self.__banManager.size()
ret = [("Currently banned", cnt),
("Total banned", self.__banManager.getBanTotal())]
if flavor != "short":
ret += [("Banned IP list", banned)]
if flavor == "cymru": if flavor == "cymru":
cymru_info = self.__banManager.getBanListExtendedCymruInfo() cymru_info = self.__banManager.getBanListExtendedCymruInfo()
ret += \ ret += \

View File

@ -66,8 +66,7 @@ class BanManager:
# @param value the time # @param value the time
def setBanTime(self, value): def setBanTime(self, value):
with self.__lock: self.__banTime = int(value)
self.__banTime = int(value)
## ##
# Get the ban time. # Get the ban time.
@ -76,8 +75,7 @@ class BanManager:
# @return the time # @return the time
def getBanTime(self): def getBanTime(self):
with self.__lock: return self.__banTime
return self.__banTime
## ##
# Set the total number of banned address. # Set the total number of banned address.
@ -85,8 +83,7 @@ class BanManager:
# @param value total number # @param value total number
def setBanTotal(self, value): def setBanTotal(self, value):
with self.__lock: self.__banTotal = value
self.__banTotal = value
## ##
# Get the total number of banned address. # Get the total number of banned address.
@ -94,8 +91,7 @@ class BanManager:
# @return the total number # @return the total number
def getBanTotal(self): def getBanTotal(self):
with self.__lock: return self.__banTotal
return self.__banTotal
## ##
# Returns a copy of the IP list. # Returns a copy of the IP list.
@ -103,8 +99,7 @@ class BanManager:
# @return IP list # @return IP list
def getBanList(self): def getBanList(self):
with self.__lock: return list(self.__banList.keys())
return list(self.__banList.keys())
## ##
# Returns a iterator to ban list (used in reload, so idle). # Returns a iterator to ban list (used in reload, so idle).
@ -112,9 +107,8 @@ class BanManager:
# @return ban list iterator # @return ban list iterator
def __iter__(self): def __iter__(self):
# ensure iterator is safe (traverse over the list in snapshot created within lock): # ensure iterator is safe - traverse over the list in snapshot created within lock (GIL):
with self.__lock: return iter(list(self.__banList.values()))
return iter(list(self.__banList.values()))
## ##
# Returns normalized value # Returns normalized value

View File

@ -47,12 +47,10 @@ class FailManager:
self.__bgSvc = BgService() self.__bgSvc = BgService()
def setFailTotal(self, value): def setFailTotal(self, value):
with self.__lock: self.__failTotal = value
self.__failTotal = value
def getFailTotal(self): def getFailTotal(self):
with self.__lock: return self.__failTotal
return self.__failTotal
def getFailCount(self): def getFailCount(self):
# may be slow on large list of failures, should be used for test purposes only... # may be slow on large list of failures, should be used for test purposes only...
@ -123,8 +121,7 @@ class FailManager:
return attempts return attempts
def size(self): def size(self):
with self.__lock: return len(self.__failList)
return len(self.__failList)
def cleanup(self, time): def cleanup(self, time):
with self.__lock: with self.__lock:

View File

@ -96,6 +96,8 @@ class ExecuteActions(LogCaptureTestCase):
self.assertLogged("stdout: %r" % 'ip flush', "stdout: %r" % 'ip stop') self.assertLogged("stdout: %r" % 'ip flush', "stdout: %r" % 'ip stop')
self.assertEqual(self.__actions.status(),[("Currently banned", 0 ), self.assertEqual(self.__actions.status(),[("Currently banned", 0 ),
("Total banned", 0 ), ("Banned IP list", [] )]) ("Total banned", 0 ), ("Banned IP list", [] )])
self.assertEqual(self.__actions.status('short'),[("Currently banned", 0 ),
("Total banned", 0 )])
def testAddActionPython(self): def testAddActionPython(self):
self.__actions.add( self.__actions.add(

View File

@ -152,7 +152,7 @@ class Socket(LogCaptureTestCase):
org_handler = RequestHandler.found_terminator org_handler = RequestHandler.found_terminator
try: try:
RequestHandler.found_terminator = lambda self: self.close() RequestHandler.found_terminator = lambda self: self.close()
self.assertRaisesRegexp(RuntimeError, r"socket connection broken", self.assertRaisesRegexp(Exception, r"reset by peer|Broken pipe",
lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10))) lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10)))
finally: finally:
RequestHandler.found_terminator = org_handler RequestHandler.found_terminator = org_handler
@ -168,7 +168,7 @@ class Socket(LogCaptureTestCase):
org_handler = RequestHandler.found_terminator org_handler = RequestHandler.found_terminator
try: try:
RequestHandler.found_terminator = lambda self: TestMsgError() RequestHandler.found_terminator = lambda self: TestMsgError()
#self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage) #self.assertRaisesRegexp(Exception, r"reset by peer|Broken pipe", client.send, testMessage)
self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error') self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error')
finally: finally:
RequestHandler.found_terminator = org_handler RequestHandler.found_terminator = org_handler