Merge pull request #154 from fail2ban/_tent/fixup_tests_racing

Multiple ENHs + fixup tests racing. fixes #103
pull/155/merge
Yaroslav Halchenko 2013-03-27 05:42:44 -07:00
commit 01b4870adc
5 changed files with 159 additions and 31 deletions

View File

@ -36,6 +36,8 @@ from testcases import servertestcase
from testcases import datedetectortestcase from testcases import datedetectortestcase
from testcases import actiontestcase from testcases import actiontestcase
from testcases import sockettestcase from testcases import sockettestcase
from testcases.utils import FormatterWithTraceBack
from server.mytime import MyTime from server.mytime import MyTime
from optparse import OptionParser, Option from optparse import OptionParser, Option
@ -52,12 +54,14 @@ def get_opt_parser():
choices=('debug', 'info', 'warn', 'error', 'fatal'), choices=('debug', 'info', 'warn', 'error', 'fatal'),
default=None, default=None,
help="Log level for the logger to use during running tests"), help="Log level for the logger to use during running tests"),
])
p.add_options([
Option('-n', "--no-network", action="store_true", Option('-n', "--no-network", action="store_true",
dest="no_network", dest="no_network",
help="Do not run tests that require the network"), help="Do not run tests that require the network"),
Option("-t", "--log-traceback", action='store_true',
help="Enrich log-messages with compressed tracebacks"),
Option("--full-traceback", action='store_true',
help="Either to make the tracebacks full, not compressed (as by default)"),
]) ])
return p return p
@ -89,12 +93,21 @@ else: # pragma: no cover
# Add the default logging handler # Add the default logging handler
stdout = logging.StreamHandler(sys.stdout) stdout = logging.StreamHandler(sys.stdout)
fmt = ' %(message)s'
if opts.log_traceback:
Formatter = FormatterWithTraceBack
fmt = (opts.full_traceback and ' %(tb)s' or ' %(tbc)s') + fmt
else:
Formatter = logging.Formatter
# Custom log format for the verbose tests runs # Custom log format for the verbose tests runs
if verbosity > 1: # pragma: no cover if verbosity > 1: # pragma: no cover
stdout.setFormatter(logging.Formatter(' %(asctime)-15s %(thread)s %(message)s')) stdout.setFormatter(Formatter(' %(asctime)-15s %(thread)s' + fmt))
else: # pragma: no cover else: # pragma: no cover
# just prefix with the space # just prefix with the space
stdout.setFormatter(logging.Formatter(' %(message)s')) stdout.setFormatter(Formatter(fmt))
logSys.addHandler(stdout) logSys.addHandler(stdout)
# #

View File

@ -105,9 +105,17 @@ class FailManager:
fData.setLastReset(unixTime) fData.setLastReset(unixTime)
fData.setLastTime(unixTime) fData.setLastTime(unixTime)
self.__failList[ip] = fData self.__failList[ip] = fData
logSys.debug("Currently have failures from %d IPs: %s"
% (len(self.__failList), self.__failList.keys()))
self.__failTotal += 1 self.__failTotal += 1
if logSys.getEffectiveLevel() <= logging.DEBUG:
# yoh: Since composing this list might be somewhat time consuming
# in case of having many active failures, it should be ran only
# if debug level is "low" enough
failures_summary = ', '.join(['%s:%d' % (k, v.getRetry())
for k,v in self.__failList.iteritems()])
logSys.debug("Total # of detected failures: %d. Current failures from %d IPs (IP:count): %s"
% (self.__failTotal, len(self.__failList), failures_summary))
finally: finally:
self.__lock.release() self.__lock.release()

View File

@ -63,16 +63,17 @@ class FilterPyinotify(FileFilter):
logSys.debug("Created FilterPyinotify") logSys.debug("Created FilterPyinotify")
def callback(self, event): def callback(self, event, origin=''):
logSys.debug("%sCallback for Event: %s", origin, event)
path = event.pathname path = event.pathname
if event.mask & pyinotify.IN_CREATE: if event.mask & pyinotify.IN_CREATE:
# skip directories altogether # skip directories altogether
if event.mask & pyinotify.IN_ISDIR: if event.mask & pyinotify.IN_ISDIR:
logSys.debug("Ignoring creation of directory %s" % path) logSys.debug("Ignoring creation of directory %s", path)
return return
# check if that is a file we care about # check if that is a file we care about
if not path in self.__watches: if not path in self.__watches:
logSys.debug("Ignoring creation of %s we do not monitor" % path) logSys.debug("Ignoring creation of %s we do not monitor", path)
return return
else: else:
# we need to substitute the watcher with a new one, so first # we need to substitute the watcher with a new one, so first
@ -104,8 +105,8 @@ class FilterPyinotify(FileFilter):
def _addFileWatcher(self, path): def _addFileWatcher(self, path):
wd = self.__monitor.add_watch(path, pyinotify.IN_MODIFY) wd = self.__monitor.add_watch(path, pyinotify.IN_MODIFY)
self.__watches.update(wd) self.__watches.update(wd)
logSys.debug("Added file watcher for %s" % path) logSys.debug("Added file watcher for %s", path)
# process the file since we did get even # process the file since we did get even
self._process_file(path) self._process_file(path)
@ -114,7 +115,7 @@ class FilterPyinotify(FileFilter):
wd = self.__monitor.rm_watch(wdInt) wd = self.__monitor.rm_watch(wdInt)
if wd[wdInt]: if wd[wdInt]:
del self.__watches[path] del self.__watches[path]
logSys.debug("Removed file watcher for %s" % path) logSys.debug("Removed file watcher for %s", path)
return True return True
else: else:
return False return False
@ -130,7 +131,7 @@ class FilterPyinotify(FileFilter):
# we need to watch also the directory for IN_CREATE # we need to watch also the directory for IN_CREATE
self.__watches.update( self.__watches.update(
self.__monitor.add_watch(path_dir, pyinotify.IN_CREATE)) self.__monitor.add_watch(path_dir, pyinotify.IN_CREATE))
logSys.debug("Added monitor for the parent directory %s" % path_dir) logSys.debug("Added monitor for the parent directory %s", path_dir)
self._addFileWatcher(path) self._addFileWatcher(path)
@ -151,7 +152,7 @@ class FilterPyinotify(FileFilter):
# since there is no other monitored file under this directory # since there is no other monitored file under this directory
wdInt = self.__watches.pop(path_dir) wdInt = self.__watches.pop(path_dir)
_ = self.__monitor.rm_watch(wdInt) _ = self.__monitor.rm_watch(wdInt)
logSys.debug("Removed monitor for the parent directory %s" % path_dir) logSys.debug("Removed monitor for the parent directory %s", path_dir)
## ##
@ -165,7 +166,7 @@ class FilterPyinotify(FileFilter):
self.__notifier = pyinotify.ThreadedNotifier(self.__monitor, self.__notifier = pyinotify.ThreadedNotifier(self.__monitor,
ProcessPyinotify(self)) ProcessPyinotify(self))
self.__notifier.start() self.__notifier.start()
logSys.debug("pyinotifier started for %s." % self.jail.getName()) logSys.debug("pyinotifier started for %s.", self.jail.getName())
# TODO: verify that there is nothing really to be done for # TODO: verify that there is nothing really to be done for
# idle jails # idle jails
return True return True
@ -201,5 +202,4 @@ class ProcessPyinotify(pyinotify.ProcessEvent):
# just need default, since using mask on watch to limit events # just need default, since using mask on watch to limit events
def process_default(self, event): def process_default(self, event):
logSys.debug("Callback for Event: %s" % event) self.__FileFilter.callback(event, origin='Default ')
self.__FileFilter.callback(event)

View File

@ -105,20 +105,23 @@ def _copy_lines_between_files(fin, fout, n=None, skip=0, mode='a', terminal_line
time.sleep(1) time.sleep(1)
if isinstance(fin, str): # pragma: no branch - only used with str in test cases if isinstance(fin, str): # pragma: no branch - only used with str in test cases
fin = open(fin, 'r') fin = open(fin, 'r')
if isinstance(fout, str):
fout = open(fout, mode)
# Skip # Skip
for i in xrange(skip): for i in xrange(skip):
_ = fin.readline() _ = fin.readline()
# Read/Write # Read
i = 0 i = 0
lines = []
while n is None or i < n: while n is None or i < n:
l = fin.readline() l = fin.readline()
if terminal_line is not None and l == terminal_line: if terminal_line is not None and l == terminal_line:
break break
fout.write(l) lines.append(l)
fout.flush()
i += 1 i += 1
# Write: all at once and flush
if isinstance(fout, str):
fout = open(fout, mode)
fout.write('\n'.join(lines))
fout.flush()
# to give other threads possibly some time to crunch # to give other threads possibly some time to crunch
time.sleep(0.1) time.sleep(0.1)
return fout return fout
@ -324,11 +327,15 @@ def get_monitor_failures_testcase(Filter_):
"""Generator of TestCase's for different filters/backends """Generator of TestCase's for different filters/backends
""" """
_, testclass_name = tempfile.mkstemp('fail2ban', 'monitorfailures')
class MonitorFailures(unittest.TestCase): class MonitorFailures(unittest.TestCase):
count = 0
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
self.filter = self.name = 'NA' self.filter = self.name = 'NA'
_, self.name = tempfile.mkstemp('fail2ban', 'monitorfailures') self.name = '%s-%d' % (testclass_name, self.count)
MonitorFailures.count += 1 # so we have unique filenames across tests
self.file = open(self.name, 'a') self.file = open(self.name, 'a')
self.jail = DummyJail() self.jail = DummyJail()
self.filter = Filter_(self.jail) self.filter = Filter_(self.jail)
@ -351,12 +358,9 @@ def get_monitor_failures_testcase(Filter_):
self.filter.join() # wait for the thread to terminate self.filter.join() # wait for the thread to terminate
#print "D: KILLING THE FILE" #print "D: KILLING THE FILE"
_killfile(self.file, self.name) _killfile(self.file, self.name)
#time.sleep(0.2) # Give FS time to ack the removal
pass pass
def __str__(self): # pragma: no cover - will only show up if unexpected exception is thrown
return "MonitorFailures%s(%s)" \
% (Filter_, hasattr(self, 'name') and self.name or 'tempfile')
def isFilled(self, delay=2.): def isFilled(self, delay=2.):
"""Wait up to `delay` sec to assure that it was modified or not """Wait up to `delay` sec to assure that it was modified or not
""" """
@ -431,9 +435,10 @@ def get_monitor_failures_testcase(Filter_):
# if we move file into a new location while it has been open already # if we move file into a new location while it has been open already
self.file = _copy_lines_between_files(GetFailures.FILENAME_01, self.name, self.file = _copy_lines_between_files(GetFailures.FILENAME_01, self.name,
n=14, mode='w') n=14, mode='w')
self.assertTrue(self.isEmpty(2)) # Poll might need more time
self.assertTrue(self.isEmpty(2 + int(isinstance(self.filter, FilterPoll))*4))
self.assertRaises(FailManagerEmpty, self.filter.failManager.toBan) self.assertRaises(FailManagerEmpty, self.filter.failManager.toBan)
self.assertEqual(self.filter.failManager.getFailTotal(), 2) # Fails with Poll from time to time self.assertEqual(self.filter.failManager.getFailTotal(), 2)
# move aside, but leaving the handle still open... # move aside, but leaving the handle still open...
os.rename(self.name, self.name + '.bak') os.rename(self.name, self.name + '.bak')
@ -488,7 +493,8 @@ def get_monitor_failures_testcase(Filter_):
# yoh: not sure why count here is not 9... TODO # yoh: not sure why count here is not 9... TODO
self.assert_correct_last_attempt(GetFailures.FAILURES_01)#, count=9) self.assert_correct_last_attempt(GetFailures.FAILURES_01)#, count=9)
MonitorFailures.__name__ = "MonitorFailures<%s>(%s)" \
% (Filter_.__name__, testclass_name) # 'tempfile')
return MonitorFailures return MonitorFailures

101
testcases/utils.py Normal file
View File

@ -0,0 +1,101 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
# vi: set ft=python sts=4 ts=4 sw=4 noet :
# This file is part of Fail2Ban.
#
# Fail2Ban is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Fail2Ban is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Fail2Ban; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
__author__ = "Yaroslav Halchenko"
__copyright__ = "Copyright (c) 2013 Yaroslav Halchenko"
__license__ = "GPL"
import logging, os, re, traceback
from os.path import basename, dirname
#
# Following "traceback" functions are adopted from PyMVPA distributed
# under MIT/Expat and copyright by PyMVPA developers (i.e. me and
# Michael). Hereby I re-license derivative work on these pieces under GPL
# to stay in line with the main Fail2Ban license
#
def mbasename(s):
"""Custom function to include directory name if filename is too common
Also strip .py at the end
"""
base = basename(s)
if base.endswith('.py'):
base = base[:-3]
if base in set(['base', '__init__']):
base = basename(dirname(s)) + '.' + base
return base
class TraceBack(object):
"""Customized traceback to be included in debug messages
"""
def __init__(self, compress=False):
"""Initialize TrackBack metric
Parameters
----------
compress : bool
if True then prefix common with previous invocation gets
replaced with ...
"""
self.__prev = ""
self.__compress = compress
def __call__(self):
ftb = traceback.extract_stack(limit=100)[:-2]
entries = [[mbasename(x[0]), str(x[1])] for x in ftb]
entries = [ e for e in entries
if not e[0] in ['unittest', 'logging.__init__' ]]
# lets make it more consize
entries_out = [entries[0]]
for entry in entries[1:]:
if entry[0] == entries_out[-1][0]:
entries_out[-1][1] += ',%s' % entry[1]
else:
entries_out.append(entry)
sftb = '>'.join(['%s:%s' % (mbasename(x[0]),
x[1]) for x in entries_out])
if self.__compress:
# lets remove part which is common with previous invocation
prev_next = sftb
common_prefix = os.path.commonprefix((self.__prev, sftb))
common_prefix2 = re.sub('>[^>]*$', '', common_prefix)
if common_prefix2 != "":
sftb = '...' + sftb[len(common_prefix2):]
self.__prev = prev_next
return sftb
class FormatterWithTraceBack(logging.Formatter):
"""Custom formatter which expands %(tb) and %(tbc) with tracebacks
TODO: might need locking in case of compressed tracebacks
"""
def __init__(self, fmt, *args, **kwargs):
logging.Formatter.__init__(self, fmt=fmt, *args, **kwargs)
compress = '%(tbc)s' in fmt
self._tb = TraceBack(compress=compress)
def format(self, record):
record.tbc = record.tb = self._tb()
return logging.Formatter.format(self, record)