mirror of https://github.com/fail2ban/fail2ban
Merge branch 'speedup-client-status' into 0.10
commit
4c2539856c
|
@ -48,7 +48,8 @@ class CSocket:
|
||||||
def send(self, msg, nonblocking=False, timeout=None):
|
def send(self, msg, nonblocking=False, timeout=None):
|
||||||
# Convert every list member to string
|
# Convert every list member to string
|
||||||
obj = dumps(map(CSocket.convert, msg), HIGHEST_PROTOCOL)
|
obj = dumps(map(CSocket.convert, msg), HIGHEST_PROTOCOL)
|
||||||
self.__csock.send(obj + CSPROTO.END)
|
self.__csock.send(obj)
|
||||||
|
self.__csock.send(CSPROTO.END)
|
||||||
return self.receive(self.__csock, nonblocking, timeout)
|
return self.receive(self.__csock, nonblocking, timeout)
|
||||||
|
|
||||||
def settimeout(self, timeout):
|
def settimeout(self, timeout):
|
||||||
|
@ -81,9 +82,12 @@ class CSocket:
|
||||||
msg = CSPROTO.EMPTY
|
msg = CSPROTO.EMPTY
|
||||||
if nonblocking: sock.setblocking(0)
|
if nonblocking: sock.setblocking(0)
|
||||||
if timeout: sock.settimeout(timeout)
|
if timeout: sock.settimeout(timeout)
|
||||||
while msg.rfind(CSPROTO.END) == -1:
|
bufsize = 1024
|
||||||
chunk = sock.recv(512)
|
while msg.rfind(CSPROTO.END, -32) == -1:
|
||||||
if chunk in ('', b''): # python 3.x may return b'' instead of ''
|
chunk = sock.recv(bufsize)
|
||||||
raise RuntimeError("socket connection broken")
|
if not len(chunk):
|
||||||
|
raise socket.error(104, 'Connection reset by peer')
|
||||||
|
if chunk == CSPROTO.END: break
|
||||||
msg = msg + chunk
|
msg = msg + chunk
|
||||||
|
if bufsize < 32768: bufsize <<= 1
|
||||||
return loads(msg)
|
return loads(msg)
|
||||||
|
|
|
@ -27,13 +27,17 @@ import sys
|
||||||
|
|
||||||
from ..version import version, normVersion
|
from ..version import version, normVersion
|
||||||
from ..protocol import printFormatted
|
from ..protocol import printFormatted
|
||||||
from ..helpers import getLogger, str2LogLevel, getVerbosityFormat
|
from ..helpers import getLogger, str2LogLevel, getVerbosityFormat, BrokenPipeError
|
||||||
|
|
||||||
# Gets the instance of the logger.
|
# Gets the instance of the logger.
|
||||||
logSys = getLogger("fail2ban")
|
logSys = getLogger("fail2ban")
|
||||||
|
|
||||||
def output(s): # pragma: no cover
|
def output(s): # pragma: no cover
|
||||||
|
try:
|
||||||
print(s)
|
print(s)
|
||||||
|
except (BrokenPipeError, IOError) as e: # pragma: no cover
|
||||||
|
if e.errno != 32: # closed / broken pipe
|
||||||
|
raise
|
||||||
|
|
||||||
# Config parameters required to start fail2ban which can be also set via command line (overwrite fail2ban.conf),
|
# Config parameters required to start fail2ban which can be also set via command line (overwrite fail2ban.conf),
|
||||||
CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket")
|
CONFIG_PARAMS = ("socket", "pidfile", "logtarget", "loglevel", "syslogsocket")
|
||||||
|
@ -310,12 +314,16 @@ class Fail2banCmdLine():
|
||||||
def _exit(code=0):
|
def _exit(code=0):
|
||||||
# implicit flush without to produce broken pipe error (32):
|
# implicit flush without to produce broken pipe error (32):
|
||||||
sys.stderr.close()
|
sys.stderr.close()
|
||||||
sys.stdout.close()
|
try:
|
||||||
|
sys.stdout.flush()
|
||||||
# exit:
|
# exit:
|
||||||
if hasattr(os, '_exit') and os._exit:
|
if hasattr(sys, 'exit') and sys.exit:
|
||||||
os._exit(code)
|
|
||||||
else:
|
|
||||||
sys.exit(code)
|
sys.exit(code)
|
||||||
|
else:
|
||||||
|
os._exit(code)
|
||||||
|
except (BrokenPipeError, IOError) as e: # pragma: no cover
|
||||||
|
if e.errno != 32: # closed / broken pipe
|
||||||
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def exit(code=0):
|
def exit(code=0):
|
||||||
|
|
|
@ -224,9 +224,10 @@ def __stopOnIOError(logSys=None, logHndlr=None): # pragma: no cover
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
BrokenPipeError
|
BrokenPipeError = BrokenPipeError
|
||||||
except NameError: # pragma: 3.x no cover
|
except NameError: # pragma: 3.x no cover
|
||||||
BrokenPipeError = IOError
|
BrokenPipeError = IOError
|
||||||
|
|
||||||
__origLog = logging.Logger._log
|
__origLog = logging.Logger._log
|
||||||
def __safeLog(self, level, msg, args, **kwargs):
|
def __safeLog(self, level, msg, args, **kwargs):
|
||||||
"""Safe log inject to avoid possible errors by unsafe log-handlers,
|
"""Safe log inject to avoid possible errors by unsafe log-handlers,
|
||||||
|
|
|
@ -661,13 +661,19 @@ class Actions(JailThread, Mapping):
|
||||||
"""Status of current and total ban counts and current banned IP list.
|
"""Status of current and total ban counts and current banned IP list.
|
||||||
"""
|
"""
|
||||||
# TODO: Allow this list to be printed as 'status' output
|
# TODO: Allow this list to be printed as 'status' output
|
||||||
supported_flavors = ["basic", "cymru"]
|
supported_flavors = ["short", "basic", "cymru"]
|
||||||
if flavor is None or flavor not in supported_flavors:
|
if flavor is None or flavor not in supported_flavors:
|
||||||
logSys.warning("Unsupported extended jail status flavor %r. Supported: %s" % (flavor, supported_flavors))
|
logSys.warning("Unsupported extended jail status flavor %r. Supported: %s" % (flavor, supported_flavors))
|
||||||
# Always print this information (basic)
|
# Always print this information (basic)
|
||||||
ret = [("Currently banned", self.__banManager.size()),
|
if flavor != "short":
|
||||||
("Total banned", self.__banManager.getBanTotal()),
|
banned = self.__banManager.getBanList()
|
||||||
("Banned IP list", self.__banManager.getBanList())]
|
cnt = len(banned)
|
||||||
|
else:
|
||||||
|
cnt = self.__banManager.size()
|
||||||
|
ret = [("Currently banned", cnt),
|
||||||
|
("Total banned", self.__banManager.getBanTotal())]
|
||||||
|
if flavor != "short":
|
||||||
|
ret += [("Banned IP list", banned)]
|
||||||
if flavor == "cymru":
|
if flavor == "cymru":
|
||||||
cymru_info = self.__banManager.getBanListExtendedCymruInfo()
|
cymru_info = self.__banManager.getBanListExtendedCymruInfo()
|
||||||
ret += \
|
ret += \
|
||||||
|
|
|
@ -66,7 +66,6 @@ class BanManager:
|
||||||
# @param value the time
|
# @param value the time
|
||||||
|
|
||||||
def setBanTime(self, value):
|
def setBanTime(self, value):
|
||||||
with self.__lock:
|
|
||||||
self.__banTime = int(value)
|
self.__banTime = int(value)
|
||||||
|
|
||||||
##
|
##
|
||||||
|
@ -76,7 +75,6 @@ class BanManager:
|
||||||
# @return the time
|
# @return the time
|
||||||
|
|
||||||
def getBanTime(self):
|
def getBanTime(self):
|
||||||
with self.__lock:
|
|
||||||
return self.__banTime
|
return self.__banTime
|
||||||
|
|
||||||
##
|
##
|
||||||
|
@ -85,7 +83,6 @@ class BanManager:
|
||||||
# @param value total number
|
# @param value total number
|
||||||
|
|
||||||
def setBanTotal(self, value):
|
def setBanTotal(self, value):
|
||||||
with self.__lock:
|
|
||||||
self.__banTotal = value
|
self.__banTotal = value
|
||||||
|
|
||||||
##
|
##
|
||||||
|
@ -94,7 +91,6 @@ class BanManager:
|
||||||
# @return the total number
|
# @return the total number
|
||||||
|
|
||||||
def getBanTotal(self):
|
def getBanTotal(self):
|
||||||
with self.__lock:
|
|
||||||
return self.__banTotal
|
return self.__banTotal
|
||||||
|
|
||||||
##
|
##
|
||||||
|
@ -103,7 +99,6 @@ class BanManager:
|
||||||
# @return IP list
|
# @return IP list
|
||||||
|
|
||||||
def getBanList(self):
|
def getBanList(self):
|
||||||
with self.__lock:
|
|
||||||
return list(self.__banList.keys())
|
return list(self.__banList.keys())
|
||||||
|
|
||||||
##
|
##
|
||||||
|
@ -112,8 +107,7 @@ class BanManager:
|
||||||
# @return ban list iterator
|
# @return ban list iterator
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
# ensure iterator is safe (traverse over the list in snapshot created within lock):
|
# ensure iterator is safe - traverse over the list in snapshot created within lock (GIL):
|
||||||
with self.__lock:
|
|
||||||
return iter(list(self.__banList.values()))
|
return iter(list(self.__banList.values()))
|
||||||
|
|
||||||
##
|
##
|
||||||
|
|
|
@ -47,11 +47,9 @@ class FailManager:
|
||||||
self.__bgSvc = BgService()
|
self.__bgSvc = BgService()
|
||||||
|
|
||||||
def setFailTotal(self, value):
|
def setFailTotal(self, value):
|
||||||
with self.__lock:
|
|
||||||
self.__failTotal = value
|
self.__failTotal = value
|
||||||
|
|
||||||
def getFailTotal(self):
|
def getFailTotal(self):
|
||||||
with self.__lock:
|
|
||||||
return self.__failTotal
|
return self.__failTotal
|
||||||
|
|
||||||
def getFailCount(self):
|
def getFailCount(self):
|
||||||
|
@ -123,7 +121,6 @@ class FailManager:
|
||||||
return attempts
|
return attempts
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
with self.__lock:
|
|
||||||
return len(self.__failList)
|
return len(self.__failList)
|
||||||
|
|
||||||
def cleanup(self, time):
|
def cleanup(self, time):
|
||||||
|
|
|
@ -96,6 +96,8 @@ class ExecuteActions(LogCaptureTestCase):
|
||||||
self.assertLogged("stdout: %r" % 'ip flush', "stdout: %r" % 'ip stop')
|
self.assertLogged("stdout: %r" % 'ip flush', "stdout: %r" % 'ip stop')
|
||||||
self.assertEqual(self.__actions.status(),[("Currently banned", 0 ),
|
self.assertEqual(self.__actions.status(),[("Currently banned", 0 ),
|
||||||
("Total banned", 0 ), ("Banned IP list", [] )])
|
("Total banned", 0 ), ("Banned IP list", [] )])
|
||||||
|
self.assertEqual(self.__actions.status('short'),[("Currently banned", 0 ),
|
||||||
|
("Total banned", 0 )])
|
||||||
|
|
||||||
def testAddActionPython(self):
|
def testAddActionPython(self):
|
||||||
self.__actions.add(
|
self.__actions.add(
|
||||||
|
|
|
@ -152,7 +152,7 @@ class Socket(LogCaptureTestCase):
|
||||||
org_handler = RequestHandler.found_terminator
|
org_handler = RequestHandler.found_terminator
|
||||||
try:
|
try:
|
||||||
RequestHandler.found_terminator = lambda self: self.close()
|
RequestHandler.found_terminator = lambda self: self.close()
|
||||||
self.assertRaisesRegexp(RuntimeError, r"socket connection broken",
|
self.assertRaisesRegexp(Exception, r"reset by peer|Broken pipe",
|
||||||
lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10)))
|
lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10)))
|
||||||
finally:
|
finally:
|
||||||
RequestHandler.found_terminator = org_handler
|
RequestHandler.found_terminator = org_handler
|
||||||
|
@ -168,7 +168,7 @@ class Socket(LogCaptureTestCase):
|
||||||
org_handler = RequestHandler.found_terminator
|
org_handler = RequestHandler.found_terminator
|
||||||
try:
|
try:
|
||||||
RequestHandler.found_terminator = lambda self: TestMsgError()
|
RequestHandler.found_terminator = lambda self: TestMsgError()
|
||||||
#self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage)
|
#self.assertRaisesRegexp(Exception, r"reset by peer|Broken pipe", client.send, testMessage)
|
||||||
self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error')
|
self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error')
|
||||||
finally:
|
finally:
|
||||||
RequestHandler.found_terminator = org_handler
|
RequestHandler.found_terminator = org_handler
|
||||||
|
|
Loading…
Reference in New Issue