mirror of https://github.com/fail2ban/fail2ban
Merge pull request #1099 from fail2ban/_sebres/min-connect-overhead
performance fix: minimizes connection overhead (close only at communication end)pull/1102/head
commit
cbb846bb46
|
@ -26,6 +26,8 @@ ver. 0.9.3 (2015/XX/XXX) - wanna-be-released
|
||||||
- Added regex to work with 'userlogins' log
|
- Added regex to work with 'userlogins' log
|
||||||
* action.d/sendmail*.conf - use LC_ALL (superseeding LC_TIME) to override
|
* action.d/sendmail*.conf - use LC_ALL (superseeding LC_TIME) to override
|
||||||
locale on systems with customized LC_ALL
|
locale on systems with customized LC_ALL
|
||||||
|
* performance fix: minimizes connection overhead, close socket only at
|
||||||
|
communication end (gh-1099)
|
||||||
|
|
||||||
- New Features:
|
- New Features:
|
||||||
* New filters:
|
* New filters:
|
||||||
|
|
|
@ -153,11 +153,14 @@ class Fail2banClient:
|
||||||
return self.__processCmd([["ping"]], False)
|
return self.__processCmd([["ping"]], False)
|
||||||
|
|
||||||
def __processCmd(self, cmd, showRet = True):
|
def __processCmd(self, cmd, showRet = True):
|
||||||
|
client = None
|
||||||
|
try:
|
||||||
beautifier = Beautifier()
|
beautifier = Beautifier()
|
||||||
streamRet = True
|
streamRet = True
|
||||||
for c in cmd:
|
for c in cmd:
|
||||||
beautifier.setInputCmd(c)
|
beautifier.setInputCmd(c)
|
||||||
try:
|
try:
|
||||||
|
if not client:
|
||||||
client = CSocket(self.__conf["socket"])
|
client = CSocket(self.__conf["socket"])
|
||||||
ret = client.send(c)
|
ret = client.send(c)
|
||||||
if ret[0] == 0:
|
if ret[0] == 0:
|
||||||
|
@ -177,6 +180,9 @@ class Fail2banClient:
|
||||||
if showRet:
|
if showRet:
|
||||||
logSys.error(e)
|
logSys.error(e)
|
||||||
return False
|
return False
|
||||||
|
finally:
|
||||||
|
if client:
|
||||||
|
client.close()
|
||||||
return streamRet
|
return streamRet
|
||||||
|
|
||||||
def __logSocketError(self):
|
def __logSocketError(self):
|
||||||
|
|
|
@ -26,24 +26,12 @@ __license__ = "GPL"
|
||||||
|
|
||||||
#from cPickle import dumps, loads, HIGHEST_PROTOCOL
|
#from cPickle import dumps, loads, HIGHEST_PROTOCOL
|
||||||
from pickle import dumps, loads, HIGHEST_PROTOCOL
|
from pickle import dumps, loads, HIGHEST_PROTOCOL
|
||||||
|
from ..protocol import CSPROTO
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
if sys.version_info >= (3,):
|
|
||||||
# b"" causes SyntaxError in python <= 2.5, so below implements equivalent
|
|
||||||
EMPTY_BYTES = bytes("", encoding="ascii")
|
|
||||||
else:
|
|
||||||
# python 2.x, string type is equivalent to bytes.
|
|
||||||
EMPTY_BYTES = ""
|
|
||||||
|
|
||||||
|
|
||||||
class CSocket:
|
class CSocket:
|
||||||
|
|
||||||
if sys.version_info >= (3,):
|
|
||||||
END_STRING = bytes("<F2B_END_COMMAND>", encoding='ascii')
|
|
||||||
else:
|
|
||||||
END_STRING = "<F2B_END_COMMAND>"
|
|
||||||
|
|
||||||
def __init__(self, sock="/var/run/fail2ban/fail2ban.sock"):
|
def __init__(self, sock="/var/run/fail2ban/fail2ban.sock"):
|
||||||
# Create an INET, STREAMing socket
|
# Create an INET, STREAMing socket
|
||||||
#self.csock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
#self.csock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
@ -51,18 +39,27 @@ class CSocket:
|
||||||
#self.csock.connect(("localhost", 2222))
|
#self.csock.connect(("localhost", 2222))
|
||||||
self.__csock.connect(sock)
|
self.__csock.connect(sock)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.close(False)
|
||||||
|
|
||||||
def send(self, msg):
|
def send(self, msg):
|
||||||
# Convert every list member to string
|
# Convert every list member to string
|
||||||
obj = dumps([str(m) for m in msg], HIGHEST_PROTOCOL)
|
obj = dumps([str(m) for m in msg], HIGHEST_PROTOCOL)
|
||||||
self.__csock.send(obj + CSocket.END_STRING)
|
self.__csock.send(obj + CSPROTO.END)
|
||||||
ret = self.receive(self.__csock)
|
return self.receive(self.__csock)
|
||||||
|
|
||||||
|
def close(self, sendEnd=True):
|
||||||
|
if not self.__csock:
|
||||||
|
return
|
||||||
|
if sendEnd:
|
||||||
|
self.__csock.sendall(CSPROTO.CLOSE + CSPROTO.END)
|
||||||
self.__csock.close()
|
self.__csock.close()
|
||||||
return ret
|
self.__csock = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def receive(sock):
|
def receive(sock):
|
||||||
msg = EMPTY_BYTES
|
msg = CSPROTO.EMPTY
|
||||||
while msg.rfind(CSocket.END_STRING) == -1:
|
while msg.rfind(CSPROTO.END) == -1:
|
||||||
chunk = sock.recv(6)
|
chunk = sock.recv(6)
|
||||||
if chunk == '':
|
if chunk == '':
|
||||||
raise RuntimeError("socket connection broken")
|
raise RuntimeError("socket connection broken")
|
||||||
|
|
|
@ -29,6 +29,16 @@ import textwrap
|
||||||
##
|
##
|
||||||
# Describes the protocol used to communicate with the server.
|
# Describes the protocol used to communicate with the server.
|
||||||
|
|
||||||
|
class dotdict(dict):
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return self[name]
|
||||||
|
|
||||||
|
CSPROTO = dotdict({
|
||||||
|
"EMPTY": b"",
|
||||||
|
"END": b"<F2B_END_COMMAND>",
|
||||||
|
"CLOSE": b"<F2B_CLOSE_COMMAND>"
|
||||||
|
})
|
||||||
|
|
||||||
protocol = [
|
protocol = [
|
||||||
['', "BASIC", ""],
|
['', "BASIC", ""],
|
||||||
["start", "starts the server and the jails"],
|
["start", "starts the server and the jails"],
|
||||||
|
|
|
@ -33,19 +33,12 @@ import socket
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
from ..protocol import CSPROTO
|
||||||
from ..helpers import getLogger,formatExceptionInfo
|
from ..helpers import getLogger,formatExceptionInfo
|
||||||
|
|
||||||
# Gets the instance of the logger.
|
# Gets the instance of the logger.
|
||||||
logSys = getLogger(__name__)
|
logSys = getLogger(__name__)
|
||||||
|
|
||||||
if sys.version_info >= (3,):
|
|
||||||
# b"" causes SyntaxError in python <= 2.5, so below implements equivalent
|
|
||||||
EMPTY_BYTES = bytes("", encoding="ascii")
|
|
||||||
else:
|
|
||||||
# python 2.x, string type is equivalent to bytes.
|
|
||||||
EMPTY_BYTES = ""
|
|
||||||
|
|
||||||
|
|
||||||
##
|
##
|
||||||
# Request handler class.
|
# Request handler class.
|
||||||
#
|
#
|
||||||
|
@ -54,17 +47,12 @@ else:
|
||||||
|
|
||||||
class RequestHandler(asynchat.async_chat):
|
class RequestHandler(asynchat.async_chat):
|
||||||
|
|
||||||
if sys.version_info >= (3,):
|
|
||||||
END_STRING = bytes("<F2B_END_COMMAND>", encoding="ascii")
|
|
||||||
else:
|
|
||||||
END_STRING = "<F2B_END_COMMAND>"
|
|
||||||
|
|
||||||
def __init__(self, conn, transmitter):
|
def __init__(self, conn, transmitter):
|
||||||
asynchat.async_chat.__init__(self, conn)
|
asynchat.async_chat.__init__(self, conn)
|
||||||
self.__transmitter = transmitter
|
self.__transmitter = transmitter
|
||||||
self.__buffer = []
|
self.__buffer = []
|
||||||
# Sets the terminator.
|
# Sets the terminator.
|
||||||
self.set_terminator(RequestHandler.END_STRING)
|
self.set_terminator(CSPROTO.END)
|
||||||
|
|
||||||
def collect_incoming_data(self, data):
|
def collect_incoming_data(self, data):
|
||||||
#logSys.debug("Received raw data: " + str(data))
|
#logSys.debug("Received raw data: " + str(data))
|
||||||
|
@ -76,16 +64,21 @@ class RequestHandler(asynchat.async_chat):
|
||||||
# This method is called once we have a complete request.
|
# This method is called once we have a complete request.
|
||||||
|
|
||||||
def found_terminator(self):
|
def found_terminator(self):
|
||||||
|
# Pop whole buffer
|
||||||
|
buf = self.__buffer
|
||||||
|
self.__buffer = []
|
||||||
# Joins the buffer items.
|
# Joins the buffer items.
|
||||||
message = loads(EMPTY_BYTES.join(self.__buffer))
|
message = loads(CSPROTO.EMPTY.join(buf))
|
||||||
|
# Closes the channel if close was received
|
||||||
|
if message == CSPROTO.CLOSE:
|
||||||
|
self.close_when_done()
|
||||||
|
return
|
||||||
# Gives the message to the transmitter.
|
# Gives the message to the transmitter.
|
||||||
message = self.__transmitter.proceed(message)
|
message = self.__transmitter.proceed(message)
|
||||||
# Serializes the response.
|
# Serializes the response.
|
||||||
message = dumps(message, HIGHEST_PROTOCOL)
|
message = dumps(message, HIGHEST_PROTOCOL)
|
||||||
# Sends the response to the client.
|
# Sends the response to the client.
|
||||||
self.push(message + RequestHandler.END_STRING)
|
self.push(message + CSPROTO.END)
|
||||||
# Closes the channel.
|
|
||||||
self.close_when_done()
|
|
||||||
|
|
||||||
def handle_error(self):
|
def handle_error(self):
|
||||||
e1, e2 = formatExceptionInfo()
|
e1, e2 = formatExceptionInfo()
|
||||||
|
|
|
@ -25,11 +25,13 @@ __copyright__ = "Copyright (c) 2013 Steven Hiscocks"
|
||||||
__license__ = "GPL"
|
__license__ = "GPL"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from .. import protocol
|
||||||
from ..server.asyncserver import AsyncServer, AsyncServerException
|
from ..server.asyncserver import AsyncServer, AsyncServerException
|
||||||
from ..client.csocket import CSocket
|
from ..client.csocket import CSocket
|
||||||
|
|
||||||
|
@ -63,6 +65,11 @@ class Socket(unittest.TestCase):
|
||||||
testMessage = ["A", "test", "message"]
|
testMessage = ["A", "test", "message"]
|
||||||
self.assertEqual(client.send(testMessage), testMessage)
|
self.assertEqual(client.send(testMessage), testMessage)
|
||||||
|
|
||||||
|
# test close message
|
||||||
|
client.close()
|
||||||
|
# 2nd close does nothing
|
||||||
|
client.close()
|
||||||
|
|
||||||
self.server.stop()
|
self.server.stop()
|
||||||
serverThread.join(1)
|
serverThread.join(1)
|
||||||
self.assertFalse(os.path.exists(self.sock_name))
|
self.assertFalse(os.path.exists(self.sock_name))
|
||||||
|
@ -83,3 +90,17 @@ class Socket(unittest.TestCase):
|
||||||
self.server.stop()
|
self.server.stop()
|
||||||
serverThread.join(1)
|
serverThread.join(1)
|
||||||
self.assertFalse(os.path.exists(self.sock_name))
|
self.assertFalse(os.path.exists(self.sock_name))
|
||||||
|
|
||||||
|
|
||||||
|
class ClientMisc(unittest.TestCase):
|
||||||
|
|
||||||
|
def testPrintFormattedAndWiki(self):
|
||||||
|
# redirect stdout to devnull
|
||||||
|
saved_stdout = sys.stdout
|
||||||
|
sys.stdout = open(os.devnull, 'w')
|
||||||
|
try:
|
||||||
|
protocol.printFormatted()
|
||||||
|
protocol.printWiki()
|
||||||
|
finally:
|
||||||
|
# restore stdout
|
||||||
|
sys.stdout = saved_stdout
|
||||||
|
|
|
@ -126,6 +126,7 @@ def gatherTests(regexps=None, no_network=False):
|
||||||
tests.addTest(unittest.makeSuite(clientreadertestcase.JailsReaderTestCache))
|
tests.addTest(unittest.makeSuite(clientreadertestcase.JailsReaderTestCache))
|
||||||
# CSocket and AsyncServer
|
# CSocket and AsyncServer
|
||||||
tests.addTest(unittest.makeSuite(sockettestcase.Socket))
|
tests.addTest(unittest.makeSuite(sockettestcase.Socket))
|
||||||
|
tests.addTest(unittest.makeSuite(sockettestcase.ClientMisc))
|
||||||
# Misc helpers
|
# Misc helpers
|
||||||
tests.addTest(unittest.makeSuite(misctestcase.HelpersTest))
|
tests.addTest(unittest.makeSuite(misctestcase.HelpersTest))
|
||||||
tests.addTest(unittest.makeSuite(misctestcase.SetupTest))
|
tests.addTest(unittest.makeSuite(misctestcase.SetupTest))
|
||||||
|
|
Loading…
Reference in New Issue