diff --git a/ChangeLog b/ChangeLog index e7dcefdd..b38b6f78 100644 --- a/ChangeLog +++ b/ChangeLog @@ -26,6 +26,8 @@ ver. 0.9.3 (2015/XX/XXX) - wanna-be-released - Added regex to work with 'userlogins' log * action.d/sendmail*.conf - use LC_ALL (superseeding LC_TIME) to override locale on systems with customized LC_ALL + * performance fix: minimizes connection overhead, close socket only at + communication end (gh-1099) - New Features: * New filters: diff --git a/bin/fail2ban-client b/bin/fail2ban-client index ada4b376..7f3f5639 100755 --- a/bin/fail2ban-client +++ b/bin/fail2ban-client @@ -153,30 +153,36 @@ class Fail2banClient: return self.__processCmd([["ping"]], False) def __processCmd(self, cmd, showRet = True): - beautifier = Beautifier() - streamRet = True - for c in cmd: - beautifier.setInputCmd(c) - try: - client = CSocket(self.__conf["socket"]) - ret = client.send(c) - if ret[0] == 0: - logSys.debug("OK : " + `ret[1]`) + client = None + try: + beautifier = Beautifier() + streamRet = True + for c in cmd: + beautifier.setInputCmd(c) + try: + if not client: + client = CSocket(self.__conf["socket"]) + ret = client.send(c) + if ret[0] == 0: + logSys.debug("OK : " + `ret[1]`) + if showRet: + print beautifier.beautify(ret[1]) + else: + logSys.error("NOK: " + `ret[1].args`) + if showRet: + print beautifier.beautifyError(ret[1]) + streamRet = False + except socket.error: if showRet: - print beautifier.beautify(ret[1]) - else: - logSys.error("NOK: " + `ret[1].args`) + self.__logSocketError() + return False + except Exception, e: if showRet: - print beautifier.beautifyError(ret[1]) - streamRet = False - except socket.error: - if showRet: - self.__logSocketError() - return False - except Exception, e: - if showRet: - logSys.error(e) - return False + logSys.error(e) + return False + finally: + if client: + client.close() return streamRet def __logSocketError(self): diff --git a/fail2ban/client/csocket.py b/fail2ban/client/csocket.py index 9ac0eff1..2e22e5ee 100644 --- a/fail2ban/client/csocket.py +++ b/fail2ban/client/csocket.py @@ -26,43 +26,40 @@ __license__ = "GPL" #from cPickle import dumps, loads, HIGHEST_PROTOCOL from pickle import dumps, loads, HIGHEST_PROTOCOL +from ..protocol import CSPROTO import socket import sys -if sys.version_info >= (3,): - # b"" causes SyntaxError in python <= 2.5, so below implements equivalent - EMPTY_BYTES = bytes("", encoding="ascii") -else: - # python 2.x, string type is equivalent to bytes. - EMPTY_BYTES = "" - - class CSocket: - if sys.version_info >= (3,): - END_STRING = bytes("", encoding='ascii') - else: - END_STRING = "" - - def __init__(self, sock = "/var/run/fail2ban/fail2ban.sock"): + def __init__(self, sock="/var/run/fail2ban/fail2ban.sock"): # Create an INET, STREAMing socket #self.csock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.__csock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) #self.csock.connect(("localhost", 2222)) self.__csock.connect(sock) + + def __del__(self): + self.close(False) def send(self, msg): # Convert every list member to string obj = dumps([str(m) for m in msg], HIGHEST_PROTOCOL) - self.__csock.send(obj + CSocket.END_STRING) - ret = self.receive(self.__csock) + self.__csock.send(obj + CSPROTO.END) + return self.receive(self.__csock) + + def close(self, sendEnd=True): + if not self.__csock: + return + if sendEnd: + self.__csock.sendall(CSPROTO.CLOSE + CSPROTO.END) self.__csock.close() - return ret + self.__csock = None @staticmethod def receive(sock): - msg = EMPTY_BYTES - while msg.rfind(CSocket.END_STRING) == -1: + msg = CSPROTO.EMPTY + while msg.rfind(CSPROTO.END) == -1: chunk = sock.recv(6) if chunk == '': raise RuntimeError("socket connection broken") diff --git a/fail2ban/protocol.py b/fail2ban/protocol.py index 9b690067..2cace91f 100644 --- a/fail2ban/protocol.py +++ b/fail2ban/protocol.py @@ -29,6 +29,16 @@ import textwrap ## # Describes the protocol used to communicate with the server. +class dotdict(dict): + def __getattr__(self, name): + return self[name] + +CSPROTO = dotdict({ + "EMPTY": b"", + "END": b"", + "CLOSE": b"" +}) + protocol = [ ['', "BASIC", ""], ["start", "starts the server and the jails"], diff --git a/fail2ban/server/asyncserver.py b/fail2ban/server/asyncserver.py index 4a8bc987..4caa702f 100644 --- a/fail2ban/server/asyncserver.py +++ b/fail2ban/server/asyncserver.py @@ -33,19 +33,12 @@ import socket import sys import traceback +from ..protocol import CSPROTO from ..helpers import getLogger,formatExceptionInfo # Gets the instance of the logger. logSys = getLogger(__name__) -if sys.version_info >= (3,): - # b"" causes SyntaxError in python <= 2.5, so below implements equivalent - EMPTY_BYTES = bytes("", encoding="ascii") -else: - # python 2.x, string type is equivalent to bytes. - EMPTY_BYTES = "" - - ## # Request handler class. # @@ -54,17 +47,12 @@ else: class RequestHandler(asynchat.async_chat): - if sys.version_info >= (3,): - END_STRING = bytes("", encoding="ascii") - else: - END_STRING = "" - def __init__(self, conn, transmitter): asynchat.async_chat.__init__(self, conn) self.__transmitter = transmitter self.__buffer = [] # Sets the terminator. - self.set_terminator(RequestHandler.END_STRING) + self.set_terminator(CSPROTO.END) def collect_incoming_data(self, data): #logSys.debug("Received raw data: " + str(data)) @@ -76,16 +64,21 @@ class RequestHandler(asynchat.async_chat): # This method is called once we have a complete request. def found_terminator(self): + # Pop whole buffer + buf = self.__buffer + self.__buffer = [] # Joins the buffer items. - message = loads(EMPTY_BYTES.join(self.__buffer)) + message = loads(CSPROTO.EMPTY.join(buf)) + # Closes the channel if close was received + if message == CSPROTO.CLOSE: + self.close_when_done() + return # Gives the message to the transmitter. message = self.__transmitter.proceed(message) # Serializes the response. message = dumps(message, HIGHEST_PROTOCOL) # Sends the response to the client. - self.push(message + RequestHandler.END_STRING) - # Closes the channel. - self.close_when_done() + self.push(message + CSPROTO.END) def handle_error(self): e1, e2 = formatExceptionInfo() diff --git a/fail2ban/tests/sockettestcase.py b/fail2ban/tests/sockettestcase.py index 01e72847..8eeb7b51 100644 --- a/fail2ban/tests/sockettestcase.py +++ b/fail2ban/tests/sockettestcase.py @@ -25,11 +25,13 @@ __copyright__ = "Copyright (c) 2013 Steven Hiscocks" __license__ = "GPL" import os +import sys import tempfile import threading import time import unittest +from .. import protocol from ..server.asyncserver import AsyncServer, AsyncServerException from ..client.csocket import CSocket @@ -63,6 +65,11 @@ class Socket(unittest.TestCase): testMessage = ["A", "test", "message"] self.assertEqual(client.send(testMessage), testMessage) + # test close message + client.close() + # 2nd close does nothing + client.close() + self.server.stop() serverThread.join(1) self.assertFalse(os.path.exists(self.sock_name)) @@ -83,3 +90,17 @@ class Socket(unittest.TestCase): self.server.stop() serverThread.join(1) self.assertFalse(os.path.exists(self.sock_name)) + + +class ClientMisc(unittest.TestCase): + + def testPrintFormattedAndWiki(self): + # redirect stdout to devnull + saved_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + try: + protocol.printFormatted() + protocol.printWiki() + finally: + # restore stdout + sys.stdout = saved_stdout diff --git a/fail2ban/tests/utils.py b/fail2ban/tests/utils.py index bf992024..89539107 100644 --- a/fail2ban/tests/utils.py +++ b/fail2ban/tests/utils.py @@ -126,6 +126,7 @@ def gatherTests(regexps=None, no_network=False): tests.addTest(unittest.makeSuite(clientreadertestcase.JailsReaderTestCache)) # CSocket and AsyncServer tests.addTest(unittest.makeSuite(sockettestcase.Socket)) + tests.addTest(unittest.makeSuite(sockettestcase.ClientMisc)) # Misc helpers tests.addTest(unittest.makeSuite(misctestcase.HelpersTest)) tests.addTest(unittest.makeSuite(misctestcase.SetupTest))