Merge branch '0.9/gh-1492'

pull/1486/merge
sebres 2016-08-01 14:45:05 +02:00
commit 0083036b5f
9 changed files with 204 additions and 47 deletions

View File

@ -293,8 +293,12 @@ class Fail2BanDb(object):
Jail to be added to the database.
"""
cur.execute(
"INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)",
"INSERT OR IGNORE INTO jails(name, enabled) VALUES(?, 1)",
(jail.name,))
if cur.rowcount <= 0:
cur.execute(
"UPDATE jails SET enabled = 1 WHERE name = ? AND enabled != 1",
(jail.name,))
@commitandrollback
def delJail(self, cur, jail):
@ -317,7 +321,7 @@ class Fail2BanDb(object):
cur.execute("UPDATE jails SET enabled=0")
@commitandrollback
def getJailNames(self, cur):
def getJailNames(self, cur, enabled=None):
"""Get name of jails in database.
Currently only used for testing purposes.
@ -327,7 +331,11 @@ class Fail2BanDb(object):
set
Set of jail names.
"""
cur.execute("SELECT name FROM jails")
if enabled is None:
cur.execute("SELECT name FROM jails")
else:
cur.execute("SELECT name FROM jails WHERE enabled=%s" %
(int(enabled),))
return set(row[0] for row in cur.fetchmany())
@commitandrollback

View File

@ -49,7 +49,7 @@ if sys.version_info >= (2,7):
def testCategory(self):
categories = self.action.getCategories()
self.assertTrue("ssh" in categories)
self.assertIn("ssh", categories)
self.assertTrue(len(categories) >= 10)
self.assertRaises(

View File

@ -101,21 +101,21 @@ class SMTPActionTest(unittest.TestCase):
self.assertEqual(self.smtpd.rcpttos, ["root"])
subject = "Subject: [Fail2Ban] %s: banned %s" % (
self.jail.name, aInfo['ip'])
self.assertTrue(subject in self.smtpd.data.replace("\n", ""))
self.assertIn(subject, self.smtpd.data.replace("\n", ""))
self.assertTrue(
"%i attempts" % aInfo['failures'] in self.smtpd.data)
self.action.matches = "matches"
self.action.ban(aInfo)
self.assertTrue(aInfo['matches'] in self.smtpd.data)
self.assertIn(aInfo['matches'], self.smtpd.data)
self.action.matches = "ipjailmatches"
self.action.ban(aInfo)
self.assertTrue(aInfo['ipjailmatches'] in self.smtpd.data)
self.assertIn(aInfo['ipjailmatches'], self.smtpd.data)
self.action.matches = "ipmatches"
self.action.ban(aInfo)
self.assertTrue(aInfo['ipmatches'] in self.smtpd.data)
self.assertIn(aInfo['ipmatches'], self.smtpd.data)
def testOptions(self):
self.action.start()

View File

@ -65,12 +65,12 @@ class ExecuteActions(LogCaptureTestCase):
def testActionsManipulation(self):
self.__actions.add('test')
self.assertTrue(self.__actions['test'])
self.assertTrue('test' in self.__actions)
self.assertFalse('nonexistant action' in self.__actions)
self.assertIn('test', self.__actions)
self.assertNotIn('nonexistant action', self.__actions)
self.__actions.add('test1')
del self.__actions['test']
del self.__actions['test1']
self.assertFalse('test' in self.__actions)
self.assertNotIn('test', self.__actions)
self.assertEqual(len(self.__actions), 0)
self.__actions.setBanTime(127)

View File

@ -525,12 +525,12 @@ class JailsReaderTest(LogCaptureTestCase):
self.assertTrue(actionReader.read())
actionReader.getOptions({}) # populate _opts
if not actionName.endswith('-common'):
self.assertTrue('Definition' in actionReader.sections(),
self.assertIn('Definition', actionReader.sections(),
msg="Action file %r is lacking [Definition] section" % actionConfig)
# all must have some actionban defined
self.assertTrue(actionReader._opts.get('actionban', '').strip(),
msg="Action file %r is lacking actionban" % actionConfig)
self.assertTrue('Init' in actionReader.sections(),
self.assertIn('Init', actionReader.sections(),
msg="Action file %r is lacking [Init] section" % actionConfig)
def testReadStockJailConf(self):
@ -582,7 +582,7 @@ class JailsReaderTest(LogCaptureTestCase):
self.assertTrue(len(actName))
self.assertTrue(isinstance(actOpt, dict))
if actName == 'iptables-multiport':
self.assertTrue('port' in actOpt)
self.assertIn('port', actOpt)
actionReader = ActionReader(
actName, jail, {}, basedir=CONFIG_DIR)
@ -632,11 +632,13 @@ class JailsReaderTest(LogCaptureTestCase):
# and we know even some of them by heart
for j in ['sshd', 'recidive']:
# by default we have 'auto' backend ATM
self.assertTrue(['add', j, 'auto'] in comm_commands)
# by default we have 'auto' backend ATM, but some distributions can overwrite it,
# (e.g. fedora default is 'systemd') therefore let check it without backend...
self.assertIn(['add', j],
(cmd[:2] for cmd in comm_commands if len(cmd) == 3 and cmd[0] == 'add'))
# and warn on useDNS
self.assertTrue(['set', j, 'usedns', 'warn'] in comm_commands)
self.assertTrue(['start', j] in comm_commands)
self.assertIn(['set', j, 'usedns', 'warn'], comm_commands)
self.assertIn(['start', j], comm_commands)
# last commands should be the 'start' commands
self.assertEqual(comm_commands[-1][0], 'start')
@ -655,7 +657,7 @@ class JailsReaderTest(LogCaptureTestCase):
action_name = action.getName()
if '<blocktype>' in str(commands):
# Verify that it is among cInfo
self.assertTrue('blocktype' in action._initOpts)
self.assertIn('blocktype', action._initOpts)
# Verify that we have a call to set it up
blocktype_present = False
target_command = ['set', jail_name, 'action', action_name, 'blocktype']

View File

@ -123,7 +123,7 @@ class DatabaseTest(LogCaptureTestCase):
self.db.addLog(self.jail, self.fileContainer)
self.assertTrue(filename in self.db.getLogPaths(self.jail))
self.assertIn(filename, self.db.getLogPaths(self.jail))
os.remove(filename)
def testUpdateLog(self):
@ -318,6 +318,25 @@ class DatabaseTest(LogCaptureTestCase):
actions._Actions__checkBan()
self.assertLogged("ban ainfo %s, %s, %s, %s" % (True, True, True, True))
def testDelAndAddJail(self):
self.testAddJail() # Add jail
# Delete jail (just disabled it):
self.db.delJail(self.jail)
jails = self.db.getJailNames()
self.assertIn(len(jails) == 1 and self.jail.name, jails)
jails = self.db.getJailNames(enabled=False)
self.assertIn(len(jails) == 1 and self.jail.name, jails)
jails = self.db.getJailNames(enabled=True)
self.assertTrue(len(jails) == 0)
# Add it again - should just enable it:
self.db.addJail(self.jail)
jails = self.db.getJailNames()
self.assertIn(len(jails) == 1 and self.jail.name, jails)
jails = self.db.getJailNames(enabled=True)
self.assertIn(len(jails) == 1 and self.jail.name, jails)
jails = self.db.getJailNames(enabled=False)
self.assertTrue(len(jails) == 0)
def testPurge(self):
if Fail2BanDb is None: # pragma: no cover
return

View File

@ -23,6 +23,7 @@ __license__ = "GPL"
import logging
import os
import re
import sys
import unittest
import tempfile
@ -32,6 +33,8 @@ import datetime
from glob import glob
from StringIO import StringIO
from utils import LogCaptureTestCase, logSys as DefLogSys
from ..helpers import formatExceptionInfo, mbasename, TraceBack, FormatterWithTraceBack, getLogger
from ..helpers import splitwords
from ..server.datetemplate import DatePatternRegex
@ -130,7 +133,7 @@ class SetupTest(unittest.TestCase):
% (sys.executable, self.setup))
class TestsUtilsTest(unittest.TestCase):
class TestsUtilsTest(LogCaptureTestCase):
def testmbasename(self):
self.assertEqual(mbasename("sample.py"), 'sample')
@ -165,12 +168,88 @@ class TestsUtilsTest(unittest.TestCase):
if not ('fail2ban-testcases' in s):
# we must be calling it from setup or nosetests but using at least
# nose's core etc
self.assertTrue('>' in s, msg="no '>' in %r" % s)
self.assertIn('>', s)
elif not ('coverage' in s):
# There is only "fail2ban-testcases" in this case, no true traceback
self.assertFalse('>' in s, msg="'>' present in %r" % s)
self.assertNotIn('>', s)
self.assertTrue(':' in s, msg="no ':' in %r" % s)
self.assertIn(':', s)
def _testAssertionErrorRE(self, regexp, fun, *args, **kwargs):
self.assertRaisesRegexp(AssertionError, regexp, fun, *args, **kwargs)
def testExtendedAssertRaisesRE(self):
## test _testAssertionErrorRE several fail cases:
def _key_err(msg):
raise KeyError(msg)
self.assertRaises(KeyError,
self._testAssertionErrorRE, r"^failed$",
_key_err, 'failed')
self.assertRaises(AssertionError,
self._testAssertionErrorRE, r"^failed$",
self.fail, '__failed__')
self._testAssertionErrorRE(r'failed.* does not match .*__failed__',
lambda: self._testAssertionErrorRE(r"^failed$",
self.fail, '__failed__')
)
## no exception in callable:
self.assertRaises(AssertionError,
self._testAssertionErrorRE, r"", int, 1)
self._testAssertionErrorRE(r'0 AssertionError not raised X.* does not match .*AssertionError not raised',
lambda: self._testAssertionErrorRE(r"^0 AssertionError not raised X$",
lambda: self._testAssertionErrorRE(r"", int, 1))
)
def testExtendedAssertMethods(self):
## assertIn, assertNotIn positive case:
self.assertIn('a', ['a', 'b', 'c', 'd'])
self.assertIn('a', ('a', 'b', 'c', 'd',))
self.assertIn('a', 'cba')
self.assertIn('a', (c for c in 'cba' if c != 'b'))
self.assertNotIn('a', ['b', 'c', 'd'])
self.assertNotIn('a', ('b', 'c', 'd',))
self.assertNotIn('a', 'cbd')
self.assertNotIn('a', (c.upper() for c in 'cba' if c != 'b'))
## assertIn, assertNotIn negative case:
self._testAssertionErrorRE(r"'a' unexpectedly found in 'cba'",
self.assertNotIn, 'a', 'cba')
self._testAssertionErrorRE(r"1 unexpectedly found in \[0, 1, 2\]",
self.assertNotIn, 1, xrange(3))
self._testAssertionErrorRE(r"'A' unexpectedly found in \['C', 'A'\]",
self.assertNotIn, 'A', (c.upper() for c in 'cba' if c != 'b'))
self._testAssertionErrorRE(r"'a' was not found in 'xyz'",
self.assertIn, 'a', 'xyz')
self._testAssertionErrorRE(r"5 was not found in \[0, 1, 2\]",
self.assertIn, 5, xrange(3))
self._testAssertionErrorRE(r"'A' was not found in \['C', 'B'\]",
self.assertIn, 'A', (c.upper() for c in 'cba' if c != 'a'))
## assertLogged, assertNotLogged positive case:
logSys = DefLogSys
self.pruneLog()
logSys.debug('test "xyz"')
self.assertLogged('test "xyz"')
self.assertLogged('test', 'xyz', all=True)
self.assertNotLogged('test', 'zyx', all=False)
self.assertNotLogged('test_zyx', 'zyx', all=True)
self.assertLogged('test', 'zyx', all=False)
self.pruneLog()
logSys.debug('xxxx "xxx"')
self.assertNotLogged('test "xyz"')
self.assertNotLogged('test', 'xyz', all=False)
self.assertNotLogged('test', 'xyz', 'zyx', all=True)
## assertLogged, assertNotLogged negative case:
self.pruneLog()
logSys.debug('test "xyz"')
self._testAssertionErrorRE(r"All of the .* were found present in the log",
self.assertNotLogged, 'test "xyz"')
self._testAssertionErrorRE(r"was found in the log",
self.assertNotLogged, 'test', 'xyz', all=True)
self._testAssertionErrorRE(r"was not found in the log",
self.assertLogged, 'test', 'zyx', all=True)
self._testAssertionErrorRE(r"None among .* was found in the log",
self.assertLogged, 'test_zyx', 'zyx', all=False)
self._testAssertionErrorRE(r"All of the .* were found present in the log",
self.assertNotLogged, 'test', 'xyz', all=False)
def testFormatterWithTraceBack(self):
strout = StringIO()

View File

@ -228,7 +228,7 @@ class Transmitter(TransmitterBase):
time.sleep(1)
self.assertEqual(
self.transm.proceed(["stop", self.jailName]), (0, None))
self.assertTrue(self.jailName not in self.server._Server__jails)
self.assertNotIn(self.jailName, self.server._Server__jails)
def testStartStopAllJail(self):
self.server.addJail("TestJail2", "auto")
@ -242,8 +242,8 @@ class Transmitter(TransmitterBase):
time.sleep(0.1)
self.assertEqual(self.transm.proceed(["stop", "all"]), (0, None))
time.sleep(1)
self.assertTrue(self.jailName not in self.server._Server__jails)
self.assertTrue("TestJail2" not in self.server._Server__jails)
self.assertNotIn(self.jailName, self.server._Server__jails)
self.assertNotIn("TestJail2", self.server._Server__jails)
def testJailIdle(self):
self.assertEqual(

View File

@ -22,6 +22,7 @@ __author__ = "Yaroslav Halchenko"
__copyright__ = "Copyright (c) 2013 Yaroslav Halchenko"
__license__ = "GPL"
import itertools
import logging
import os
import re
@ -208,16 +209,45 @@ def gatherTests(regexps=None, no_network=False):
return tests
# forwards compatibility of unittest.TestCase for some early python versions
if not hasattr(unittest.TestCase, 'assertIn'):
def __assertIn(self, a, b, msg=None):
if a not in b: # pragma: no cover
self.fail(msg or "%r was not found in %r" % (a, b))
unittest.TestCase.assertIn = __assertIn
def __assertNotIn(self, a, b, msg=None):
if a in b: # pragma: no cover
self.fail(msg or "%r was found in %r" % (a, b))
unittest.TestCase.assertNotIn = __assertNotIn
#
# Forwards compatibility of unittest.TestCase for some early python versions
#
if not hasattr(unittest.TestCase, 'assertRaisesRegexp'):
def assertRaisesRegexp(self, exccls, regexp, fun, *args, **kwargs):
try:
fun(*args, **kwargs)
except exccls as e:
if re.search(regexp, e.message) is None:
self.fail('\"%s\" does not match \"%s\"' % (regexp, e.message))
else:
self.fail('%s not raised' % getattr(exccls, '__name__'))
unittest.TestCase.assertRaisesRegexp = assertRaisesRegexp
# always custom following methods, because we use atm better version of both (support generators)
if True: ## if not hasattr(unittest.TestCase, 'assertIn'):
def assertIn(self, a, b, msg=None):
bb = b
wrap = False
if msg is None and hasattr(b, '__iter__') and not isinstance(b, basestring):
b, bb = itertools.tee(b)
wrap = True
if a not in b:
if wrap: bb = list(bb)
msg = msg or "%r was not found in %r" % (a, bb)
self.fail(msg)
unittest.TestCase.assertIn = assertIn
def assertNotIn(self, a, b, msg=None):
bb = b
wrap = False
if msg is None and hasattr(b, '__iter__') and not isinstance(b, basestring):
b, bb = itertools.tee(b)
wrap = True
if a in b:
if wrap: bb = list(bb)
msg = msg or "%r unexpectedly found in %r" % (a, bb)
self.fail(msg)
unittest.TestCase.assertNotIn = assertNotIn
class LogCaptureTestCase(unittest.TestCase):
@ -241,6 +271,7 @@ class LogCaptureTestCase(unittest.TestCase):
def tearDown(self):
"""Call after every test case."""
# print "O: >>%s<<" % self._log.getvalue()
self.pruneLog()
logSys = getLogger("fail2ban")
logSys.handlers = self._old_handlers
logSys.level = self._old_level
@ -248,7 +279,7 @@ class LogCaptureTestCase(unittest.TestCase):
def _is_logged(self, s):
return s in self._log.getvalue()
def assertLogged(self, *s):
def assertLogged(self, *s, **kwargs):
"""Assert that one of the strings was logged
Preferable to assertTrue(self._is_logged(..)))
@ -258,14 +289,23 @@ class LogCaptureTestCase(unittest.TestCase):
----------
s : string or list/set/tuple of strings
Test should succeed if string (or any of the listed) is present in the log
all : boolean (default False) if True should fail if any of s not logged
"""
logged = self._log.getvalue()
for s_ in s:
if s_ in logged:
return
raise AssertionError("None among %r was found in the log: %r" % (s, logged))
if not kwargs.get('all', False):
# at least one entry should be found:
for s_ in s:
if s_ in logged:
return
if True: # pragma: no cover
self.fail("None among %r was found in the log: ===\n%s===" % (s, logged))
else:
# each entry should be found:
for s_ in s:
if s_ not in logged: # pragma: no cover
self.fail("%r was not found in the log: ===\n%s===" % (s_, logged))
def assertNotLogged(self, *s):
def assertNotLogged(self, *s, **kwargs):
"""Assert that strings were not logged
Parameters
@ -273,13 +313,22 @@ class LogCaptureTestCase(unittest.TestCase):
s : string or list/set/tuple of strings
Test should succeed if the string (or at least one of the listed) is not
present in the log
all : boolean (default False) if True should fail if any of s logged
"""
logged = self._log.getvalue()
for s_ in s:
if s_ not in logged:
return
raise AssertionError("All of the %r were found present in the log: %r" % (s, logged))
if not kwargs.get('all', False):
for s_ in s:
if s_ not in logged:
return
if True: # pragma: no cover
self.fail("All of the %r were found present in the log: ===\n%s===" % (s, logged))
else:
for s_ in s:
if s_ in logged: # pragma: no cover
self.fail("%r was found in the log: ===\n%s===" % (s_, logged))
def pruneLog(self):
self._log.truncate(0)
def getLog(self):
return self._log.getvalue()