mirror of https://github.com/fail2ban/fail2ban
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
4.1 KiB
135 lines
4.1 KiB
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
|
|
# vi: set ft=python sts=4 ts=4 sw=4 noet :
|
|
|
|
# This file is part of Fail2Ban.
|
|
#
|
|
# Fail2Ban is free software; you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation; either version 2 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# Fail2Ban is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with Fail2Ban; if not, write to the Free Software
|
|
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
|
|
|
|
|
__author__ = "Yaroslav Halchenko"
|
|
__copyright__ = "Copyright (c) 2013 Yaroslav Halchenko"
|
|
__license__ = "GPL"
|
|
|
|
import unittest, logging, os, re, tempfile, sys, time, traceback
|
|
from os.path import basename, dirname
|
|
from StringIO import StringIO
|
|
|
|
#
|
|
# The following "traceback" functions are adapted from PyMVPA, distributed
|
|
# under MIT/Expat and copyright by PyMVPA developers (i.e. me and
|
|
# Michael). Hereby I re-license derivative work on these pieces under GPL
|
|
# to stay in line with the main Fail2Ban license
|
|
#
|
|
def mbasename(s):
	"""Return a short, module-like name for the file path *s*.

	The directory part is dropped and a trailing '.py' is removed.  If what
	remains is too generic to be informative ('base' or '__init__'), the
	name of the containing directory is prepended, joined with a dot.
	"""
	name = basename(s)
	if name.endswith('.py'):
		name = name[:-len('.py')]
	if name not in ('base', '__init__'):
		return name
	return '.'.join((basename(dirname(s)), name))
|
|
|
|
class TraceBack(object):
	"""Customized traceback to be included in debug messages

	Calling an instance returns a one-line summary of the current call stack,
	e.g. 'server.base:10,12>filter:55': each frame is rendered as
	'module:lineno' (module names via mbasename), frames are joined with '>',
	frames from the unittest package and from logging/__init__ are skipped,
	and consecutive frames from the same module are merged into one entry
	with comma-separated line numbers.
	"""

	def __init__(self, compress=False):
		"""Initialize TraceBack metric

		Parameters
		----------
		compress : bool
		  if True then the leading entries shared with the previous
		  invocation get replaced with '...'
		"""
		self.__prev = ""
		self.__compress = compress

	@staticmethod
	def _collapse(entries):
		"""Merge consecutive [module, lineno] entries of the same module.

		Returns a new list of [module, linenos] pairs where the line numbers of
		merged entries are joined with ','.  Empty input yields an empty list.
		"""
		collapsed = []
		for module, lineno in entries:
			if collapsed and collapsed[-1][0] == module:
				collapsed[-1][1] += ',%s' % lineno
			else:
				collapsed.append([module, lineno])
		return collapsed

	@staticmethod
	def _strip_common(prev, sftb):
		"""Elide the part of *sftb* that it shares with *prev*.

		Everything before the last '>' lying inside the common prefix of
		*prev* and *sftb* is replaced by '...', so only whole entries are
		ever elided.  If the common prefix contains no '>' (not even the first
		entry is shared), *sftb* is returned unchanged.
		"""
		common_prefix = os.path.commonprefix((prev, sftb))
		# cut only at an entry boundary: text after the last '>' of the
		# common prefix may be a partially matching entry
		cut = common_prefix.rfind('>')
		if cut <= 0:
			return sftb
		return '...' + sftb[cut:]

	def __call__(self):
		"""Return the compact traceback string for the current call stack.

		Returns '' if no frames remain after filtering.
		"""
		# drop the two innermost frames: this method and its immediate caller
		# (e.g. FormatterWithTraceBack.format).  Must stay directly in
		# __call__ so that the frame count is right.
		ftb = traceback.extract_stack(limit=100)[:-2]
		entries = []
		for frame in ftb:
			filename, lineno = frame[0], frame[1]
			module = mbasename(filename)
			# skip frames of the unittest machinery and of logging itself
			if module in ('unittest', 'logging.__init__') \
					or dirname(filename).endswith('/unittest'):
				continue
			entries.append([module, str(lineno)])

		# lets make it more concise; entries already hold mbasename'd names
		sftb = '>'.join(['%s:%s' % (module, linenos)
			for module, linenos in self._collapse(entries)])

		if self.__compress:
			# remember the full string so the next call compares against the
			# complete traceback, not the elided one
			full = sftb
			sftb = self._strip_common(self.__prev, sftb)
			self.__prev = full

		return sftb
|
|
|
|
class FormatterWithTraceBack(logging.Formatter):
	"""Custom formatter which expands %(tb)s and %(tbc)s with tracebacks

	Both record attributes are set to the same string produced by a
	TraceBack instance.  Compression (eliding entries shared with the
	previously formatted record) is enabled only when the format string
	contains '%(tbc)s'; in that case '%(tb)s' gets the compressed value too.

	TODO: might need locking in case of compressed tracebacks
	"""
	def __init__(self, fmt=None, *args, **kwargs):
		# pass fmt positionally: combining fmt=fmt with extra positional
		# arguments (e.g. datefmt) would raise "got multiple values for
		# argument 'fmt'"
		logging.Formatter.__init__(self, fmt, *args, **kwargs)
		# fmt may be None, in which case logging.Formatter uses its default
		# format, which does not reference %(tbc)s
		compress = fmt is not None and '%(tbc)s' in fmt
		self._tb = TraceBack(compress=compress)

	def format(self, record):
		"""Attach the current traceback as record.tb/record.tbc, then format."""
		# called directly here: TraceBack drops exactly two innermost frames
		record.tbc = record.tb = self._tb()
		return logging.Formatter.format(self, record)
|
|
|
|
def mtimesleep():
	"""No-op placeholder; intentionally does not sleep.

	Sleeping is no longer necessary because polling now tracks not only
	mtime but also the inode and size of files.
	"""
	return None
|
|
|
|
class LogCaptureTestCase(unittest.TestCase):
	"""Test case that captures everything logged to the "fail2ban" logger.

	setUp() replaces the handlers of the "fail2ban" logger with a single
	handler writing into an in-memory string buffer and lowers its level to
	DEBUG; tearDown() restores the previous handlers and level.  Use
	_is_logged() to check the captured output.
	"""

	def setUp(self):
		"""Redirect "fail2ban" logging into a string buffer at DEBUG level."""
		# For extended testing of what gets output into logging
		# system, we will redirect it to a string
		logSys = logging.getLogger("fail2ban")

		# Keep old settings, restored in tearDown
		self._old_level = logSys.level
		self._old_handlers = logSys.handlers
		# Let's log everything into a string
		self._log = StringIO()
		self._log_handler = logging.StreamHandler(self._log)
		# assign a new list (instead of mutating) so the saved list of old
		# handlers stays untouched
		logSys.handlers = [self._log_handler]
		logSys.setLevel(logging.DEBUG)

	def tearDown(self):
		"""Call after every test case."""
		logSys = logging.getLogger("fail2ban")
		logSys.handlers = self._old_handlers
		# use setLevel() rather than assigning .level directly: on Python 3.7+
		# setLevel() also invalidates the cached isEnabledFor() results, which
		# would otherwise keep DEBUG enabled after the restore
		logSys.setLevel(self._old_level)
		# release the capture handler; the string buffer stays readable
		self._log_handler.close()

	def _is_logged(self, s):
		"""Return True if *s* occurs anywhere in the captured log output."""
		return s in self._log.getvalue()
|