start observer together with the server (parametrized to prevent constantly start/stop of observer by addJail in test cases)

pull/1460/head
sebres 2015-12-29 18:59:33 +01:00
parent 9d4f163e88
commit b3d4ce291e
3 changed files with 26 additions and 9 deletions

View File

@ -52,7 +52,7 @@ except ImportError: # pragma: no cover
class Server:
def __init__(self, daemon = False):
def __init__(self, daemon=False):
self.__loggingLock = Lock()
self.__lock = RLock()
self.__jails = Jails()
@ -81,7 +81,7 @@ class Server:
logSys.debug("Caught signal %d. Flushing logs" % signum)
self.flushLogs()
def start(self, sock, pidfile, force = False):
def start(self, sock, pidfile, force=False, observer=True):
logSys.info("Starting Fail2ban v%s", version.version)
# Install signal handlers
@ -112,6 +112,12 @@ class Server:
except IOError, e:
logSys.error("Unable to create PID file: %s" % e)
# Create observers and start it:
if observer:
if Observers.Main is None:
Observers.Main = ObserverThread()
Observers.Main.start()
# Start the communication
logSys.debug("Starting communication")
try:
@ -150,15 +156,10 @@ class Server:
self.__loggingLock.release()
def addJail(self, name, backend):
# Create an observer if not yet created and start it:
if Observers.Main is None:
Observers.Main = ObserverThread()
Observers.Main.start()
# Add jail hereafter:
self.__jails.add(name, backend, self.__db)
if self.__db is not None:
self.__db.addJail(self.__jails[name])
Observers.Main.db_set(self.__db)
def delJail(self, name):
if self.__db is not None:
@ -541,6 +542,8 @@ class Server:
logSys.error(
"Unable to import fail2ban database module as sqlite "
"is not available.")
if Observers.Main is not None:
Observers.Main.db_set(self.__db)
def getDatabase(self):
return self.__db

View File

@ -71,7 +71,7 @@ class TransmitterBase(unittest.TestCase):
'fail2ban.pid', 'transmitter')
os.close(pidfile_fd)
self.tmp_files.append(pidfile_name)
self.server.start(sock_name, pidfile_name, force=False)
self.server.start(sock_name, pidfile_name, **self.server_start_args)
self.jailName = "TestJail1"
self.server.addJail(self.jailName, FAST_BACKEND)
@ -160,6 +160,7 @@ class Transmitter(TransmitterBase):
def setUp(self):
self.server = TestServer()
self.server_start_args = {'force':False, 'observer':False}
super(Transmitter, self).setUp()
def testStopServer(self):
@ -795,6 +796,7 @@ class TransmitterLogging(TransmitterBase):
self.server.setLogTarget("/dev/null")
self.server.setLogLevel("CRITICAL")
self.server.setSyslogSocket("auto")
self.server_start_args = {'force':False, 'observer':False}
super(TransmitterLogging, self).setUp()
def testLogTarget(self):
@ -912,6 +914,18 @@ class TransmitterLogging(TransmitterBase):
self.setGetTest("bantime.multipliers", "1 5 30 60 300 720 1440 2880", "1 5 30 60 300 720 1440 2880", jail=self.jailName)
self.setGetTest("bantime.overalljails", "true", "true", jail=self.jailName)
class TransmitterWithObserver(TransmitterBase):
def setUp(self):
self.server = TestServer()
self.server_start_args = {'force':False, 'observer':True}
super(TransmitterWithObserver, self).setUp()
def testObserver(self):
pass
class JailTests(unittest.TestCase):
def testLongName(self):

View File

@ -167,8 +167,8 @@ def gatherTests(regexps=None, opts=None):
tests = FilteredTestSuite()
# Server
#tests.addTest(unittest.makeSuite(servertestcase.StartStop))
tests.addTest(unittest.makeSuite(servertestcase.Transmitter))
tests.addTest(unittest.makeSuite(servertestcase.TransmitterWithObserver))
tests.addTest(unittest.makeSuite(servertestcase.JailTests))
tests.addTest(unittest.makeSuite(servertestcase.RegexTests))
tests.addTest(unittest.makeSuite(servertestcase.LoggingTests))