diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index 560fbfe5..9f562511 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -293,8 +293,12 @@ class Fail2BanDb(object): Jail to be added to the database. """ cur.execute( - "INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)", + "INSERT OR IGNORE INTO jails(name, enabled) VALUES(?, 1)", (jail.name,)) + if cur.rowcount <= 0: + cur.execute( + "UPDATE jails SET enabled = 1 WHERE name = ? AND enabled != 1", + (jail.name,)) @commitandrollback def delJail(self, cur, jail): @@ -317,7 +321,7 @@ class Fail2BanDb(object): cur.execute("UPDATE jails SET enabled=0") @commitandrollback - def getJailNames(self, cur): + def getJailNames(self, cur, enabled=None): """Get name of jails in database. Currently only used for testing purposes. @@ -327,7 +331,11 @@ class Fail2BanDb(object): set Set of jail names. """ - cur.execute("SELECT name FROM jails") + if enabled is None: + cur.execute("SELECT name FROM jails") + else: + cur.execute("SELECT name FROM jails WHERE enabled=%s" % + (int(enabled),)) return set(row[0] for row in cur.fetchmany()) @commitandrollback diff --git a/fail2ban/tests/action_d/test_badips.py b/fail2ban/tests/action_d/test_badips.py index 3f71b7a3..b0f8b3c3 100644 --- a/fail2ban/tests/action_d/test_badips.py +++ b/fail2ban/tests/action_d/test_badips.py @@ -49,7 +49,7 @@ if sys.version_info >= (2,7): def testCategory(self): categories = self.action.getCategories() - self.assertTrue("ssh" in categories) + self.assertIn("ssh", categories) self.assertTrue(len(categories) >= 10) self.assertRaises( diff --git a/fail2ban/tests/action_d/test_smtp.py b/fail2ban/tests/action_d/test_smtp.py index 35ac2393..b8328743 100644 --- a/fail2ban/tests/action_d/test_smtp.py +++ b/fail2ban/tests/action_d/test_smtp.py @@ -101,21 +101,21 @@ class SMTPActionTest(unittest.TestCase): self.assertEqual(self.smtpd.rcpttos, ["root"]) subject = "Subject: [Fail2Ban] %s: banned %s" % ( self.jail.name, aInfo['ip']) - self.assertTrue(subject in self.smtpd.data.replace("\n", "")) + self.assertIn(subject, self.smtpd.data.replace("\n", "")) self.assertTrue( "%i attempts" % aInfo['failures'] in self.smtpd.data) self.action.matches = "matches" self.action.ban(aInfo) - self.assertTrue(aInfo['matches'] in self.smtpd.data) + self.assertIn(aInfo['matches'], self.smtpd.data) self.action.matches = "ipjailmatches" self.action.ban(aInfo) - self.assertTrue(aInfo['ipjailmatches'] in self.smtpd.data) + self.assertIn(aInfo['ipjailmatches'], self.smtpd.data) self.action.matches = "ipmatches" self.action.ban(aInfo) - self.assertTrue(aInfo['ipmatches'] in self.smtpd.data) + self.assertIn(aInfo['ipmatches'], self.smtpd.data) def testOptions(self): self.action.start() diff --git a/fail2ban/tests/actionstestcase.py b/fail2ban/tests/actionstestcase.py index 0ceb35d5..3b9d2d01 100644 --- a/fail2ban/tests/actionstestcase.py +++ b/fail2ban/tests/actionstestcase.py @@ -65,12 +65,12 @@ class ExecuteActions(LogCaptureTestCase): def testActionsManipulation(self): self.__actions.add('test') self.assertTrue(self.__actions['test']) - self.assertTrue('test' in self.__actions) - self.assertFalse('nonexistant action' in self.__actions) + self.assertIn('test', self.__actions) + self.assertNotIn('nonexistant action', self.__actions) self.__actions.add('test1') del self.__actions['test'] del self.__actions['test1'] - self.assertFalse('test' in self.__actions) + self.assertNotIn('test', self.__actions) self.assertEqual(len(self.__actions), 0) self.__actions.setBanTime(127) diff --git a/fail2ban/tests/clientreadertestcase.py b/fail2ban/tests/clientreadertestcase.py index 0a3734e5..ee362e3d 100644 --- a/fail2ban/tests/clientreadertestcase.py +++ b/fail2ban/tests/clientreadertestcase.py @@ -525,12 +525,12 @@ class JailsReaderTest(LogCaptureTestCase): self.assertTrue(actionReader.read()) actionReader.getOptions({}) # populate _opts if not actionName.endswith('-common'): - self.assertTrue('Definition' in actionReader.sections(), + self.assertIn('Definition', actionReader.sections(), msg="Action file %r is lacking [Definition] section" % actionConfig) # all must have some actionban defined self.assertTrue(actionReader._opts.get('actionban', '').strip(), msg="Action file %r is lacking actionban" % actionConfig) - self.assertTrue('Init' in actionReader.sections(), + self.assertIn('Init', actionReader.sections(), msg="Action file %r is lacking [Init] section" % actionConfig) def testReadStockJailConf(self): @@ -582,7 +582,7 @@ class JailsReaderTest(LogCaptureTestCase): self.assertTrue(len(actName)) self.assertTrue(isinstance(actOpt, dict)) if actName == 'iptables-multiport': - self.assertTrue('port' in actOpt) + self.assertIn('port', actOpt) actionReader = ActionReader( actName, jail, {}, basedir=CONFIG_DIR) @@ -632,11 +632,13 @@ class JailsReaderTest(LogCaptureTestCase): # and we know even some of them by heart for j in ['sshd', 'recidive']: - # by default we have 'auto' backend ATM - self.assertTrue(['add', j, 'auto'] in comm_commands) + # by default we have 'auto' backend ATM, but some distributions can overwrite it, + # (e.g. fedora default is 'systemd') therefore let check it without backend... + self.assertIn(['add', j], + (cmd[:2] for cmd in comm_commands if len(cmd) == 3 and cmd[0] == 'add')) # and warn on useDNS - self.assertTrue(['set', j, 'usedns', 'warn'] in comm_commands) - self.assertTrue(['start', j] in comm_commands) + self.assertIn(['set', j, 'usedns', 'warn'], comm_commands) + self.assertIn(['start', j], comm_commands) # last commands should be the 'start' commands self.assertEqual(comm_commands[-1][0], 'start') @@ -655,7 +657,7 @@ class JailsReaderTest(LogCaptureTestCase): action_name = action.getName() if '' in str(commands): # Verify that it is among cInfo - self.assertTrue('blocktype' in action._initOpts) + self.assertIn('blocktype', action._initOpts) # Verify that we have a call to set it up blocktype_present = False target_command = ['set', jail_name, 'action', action_name, 'blocktype'] diff --git a/fail2ban/tests/databasetestcase.py b/fail2ban/tests/databasetestcase.py index 3d156eda..e934ba45 100644 --- a/fail2ban/tests/databasetestcase.py +++ b/fail2ban/tests/databasetestcase.py @@ -123,7 +123,7 @@ class DatabaseTest(LogCaptureTestCase): self.db.addLog(self.jail, self.fileContainer) - self.assertTrue(filename in self.db.getLogPaths(self.jail)) + self.assertIn(filename, self.db.getLogPaths(self.jail)) os.remove(filename) def testUpdateLog(self): @@ -318,6 +318,25 @@ class DatabaseTest(LogCaptureTestCase): actions._Actions__checkBan() self.assertLogged("ban ainfo %s, %s, %s, %s" % (True, True, True, True)) + def testDelAndAddJail(self): + self.testAddJail() # Add jail + # Delete jail (just disabled it): + self.db.delJail(self.jail) + jails = self.db.getJailNames() + self.assertIn(len(jails) == 1 and self.jail.name, jails) + jails = self.db.getJailNames(enabled=False) + self.assertIn(len(jails) == 1 and self.jail.name, jails) + jails = self.db.getJailNames(enabled=True) + self.assertTrue(len(jails) == 0) + # Add it again - should just enable it: + self.db.addJail(self.jail) + jails = self.db.getJailNames() + self.assertIn(len(jails) == 1 and self.jail.name, jails) + jails = self.db.getJailNames(enabled=True) + self.assertIn(len(jails) == 1 and self.jail.name, jails) + jails = self.db.getJailNames(enabled=False) + self.assertTrue(len(jails) == 0) + def testPurge(self): if Fail2BanDb is None: # pragma: no cover return diff --git a/fail2ban/tests/misctestcase.py b/fail2ban/tests/misctestcase.py index 48074d53..5f7d853d 100644 --- a/fail2ban/tests/misctestcase.py +++ b/fail2ban/tests/misctestcase.py @@ -23,6 +23,7 @@ __license__ = "GPL" import logging import os +import re import sys import unittest import tempfile @@ -32,6 +33,8 @@ import datetime from glob import glob from StringIO import StringIO +from utils import LogCaptureTestCase, logSys as DefLogSys + from ..helpers import formatExceptionInfo, mbasename, TraceBack, FormatterWithTraceBack, getLogger from ..helpers import splitwords from ..server.datetemplate import DatePatternRegex @@ -130,7 +133,7 @@ class SetupTest(unittest.TestCase): % (sys.executable, self.setup)) -class TestsUtilsTest(unittest.TestCase): +class TestsUtilsTest(LogCaptureTestCase): def testmbasename(self): self.assertEqual(mbasename("sample.py"), 'sample') @@ -165,12 +168,88 @@ class TestsUtilsTest(unittest.TestCase): if not ('fail2ban-testcases' in s): # we must be calling it from setup or nosetests but using at least # nose's core etc - self.assertTrue('>' in s, msg="no '>' in %r" % s) + self.assertIn('>', s) elif not ('coverage' in s): # There is only "fail2ban-testcases" in this case, no true traceback - self.assertFalse('>' in s, msg="'>' present in %r" % s) + self.assertNotIn('>', s) - self.assertTrue(':' in s, msg="no ':' in %r" % s) + self.assertIn(':', s) + + def _testAssertionErrorRE(self, regexp, fun, *args, **kwargs): + self.assertRaisesRegexp(AssertionError, regexp, fun, *args, **kwargs) + + def testExtendedAssertRaisesRE(self): + ## test _testAssertionErrorRE several fail cases: + def _key_err(msg): + raise KeyError(msg) + self.assertRaises(KeyError, + self._testAssertionErrorRE, r"^failed$", + _key_err, 'failed') + self.assertRaises(AssertionError, + self._testAssertionErrorRE, r"^failed$", + self.fail, '__failed__') + self._testAssertionErrorRE(r'failed.* does not match .*__failed__', + lambda: self._testAssertionErrorRE(r"^failed$", + self.fail, '__failed__') + ) + ## no exception in callable: + self.assertRaises(AssertionError, + self._testAssertionErrorRE, r"", int, 1) + self._testAssertionErrorRE(r'0 AssertionError not raised X.* does not match .*AssertionError not raised', + lambda: self._testAssertionErrorRE(r"^0 AssertionError not raised X$", + lambda: self._testAssertionErrorRE(r"", int, 1)) + ) + + def testExtendedAssertMethods(self): + ## assertIn, assertNotIn positive case: + self.assertIn('a', ['a', 'b', 'c', 'd']) + self.assertIn('a', ('a', 'b', 'c', 'd',)) + self.assertIn('a', 'cba') + self.assertIn('a', (c for c in 'cba' if c != 'b')) + self.assertNotIn('a', ['b', 'c', 'd']) + self.assertNotIn('a', ('b', 'c', 'd',)) + self.assertNotIn('a', 'cbd') + self.assertNotIn('a', (c.upper() for c in 'cba' if c != 'b')) + ## assertIn, assertNotIn negative case: + self._testAssertionErrorRE(r"'a' unexpectedly found in 'cba'", + self.assertNotIn, 'a', 'cba') + self._testAssertionErrorRE(r"1 unexpectedly found in \[0, 1, 2\]", + self.assertNotIn, 1, xrange(3)) + self._testAssertionErrorRE(r"'A' unexpectedly found in \['C', 'A'\]", + self.assertNotIn, 'A', (c.upper() for c in 'cba' if c != 'b')) + self._testAssertionErrorRE(r"'a' was not found in 'xyz'", + self.assertIn, 'a', 'xyz') + self._testAssertionErrorRE(r"5 was not found in \[0, 1, 2\]", + self.assertIn, 5, xrange(3)) + self._testAssertionErrorRE(r"'A' was not found in \['C', 'B'\]", + self.assertIn, 'A', (c.upper() for c in 'cba' if c != 'a')) + ## assertLogged, assertNotLogged positive case: + logSys = DefLogSys + self.pruneLog() + logSys.debug('test "xyz"') + self.assertLogged('test "xyz"') + self.assertLogged('test', 'xyz', all=True) + self.assertNotLogged('test', 'zyx', all=False) + self.assertNotLogged('test_zyx', 'zyx', all=True) + self.assertLogged('test', 'zyx', all=False) + self.pruneLog() + logSys.debug('xxxx "xxx"') + self.assertNotLogged('test "xyz"') + self.assertNotLogged('test', 'xyz', all=False) + self.assertNotLogged('test', 'xyz', 'zyx', all=True) + ## assertLogged, assertNotLogged negative case: + self.pruneLog() + logSys.debug('test "xyz"') + self._testAssertionErrorRE(r"All of the .* were found present in the log", + self.assertNotLogged, 'test "xyz"') + self._testAssertionErrorRE(r"was found in the log", + self.assertNotLogged, 'test', 'xyz', all=True) + self._testAssertionErrorRE(r"was not found in the log", + self.assertLogged, 'test', 'zyx', all=True) + self._testAssertionErrorRE(r"None among .* was found in the log", + self.assertLogged, 'test_zyx', 'zyx', all=False) + self._testAssertionErrorRE(r"All of the .* were found present in the log", + self.assertNotLogged, 'test', 'xyz', all=False) def testFormatterWithTraceBack(self): strout = StringIO() diff --git a/fail2ban/tests/servertestcase.py b/fail2ban/tests/servertestcase.py index 07e10c7d..21e2d784 100644 --- a/fail2ban/tests/servertestcase.py +++ b/fail2ban/tests/servertestcase.py @@ -228,7 +228,7 @@ class Transmitter(TransmitterBase): time.sleep(1) self.assertEqual( self.transm.proceed(["stop", self.jailName]), (0, None)) - self.assertTrue(self.jailName not in self.server._Server__jails) + self.assertNotIn(self.jailName, self.server._Server__jails) def testStartStopAllJail(self): self.server.addJail("TestJail2", "auto") @@ -242,8 +242,8 @@ class Transmitter(TransmitterBase): time.sleep(0.1) self.assertEqual(self.transm.proceed(["stop", "all"]), (0, None)) time.sleep(1) - self.assertTrue(self.jailName not in self.server._Server__jails) - self.assertTrue("TestJail2" not in self.server._Server__jails) + self.assertNotIn(self.jailName, self.server._Server__jails) + self.assertNotIn("TestJail2", self.server._Server__jails) def testJailIdle(self): self.assertEqual( diff --git a/fail2ban/tests/utils.py b/fail2ban/tests/utils.py index 8fc78683..e091c935 100644 --- a/fail2ban/tests/utils.py +++ b/fail2ban/tests/utils.py @@ -22,6 +22,7 @@ __author__ = "Yaroslav Halchenko" __copyright__ = "Copyright (c) 2013 Yaroslav Halchenko" __license__ = "GPL" +import itertools import logging import os import re @@ -208,16 +209,45 @@ def gatherTests(regexps=None, no_network=False): return tests -# forwards compatibility of unittest.TestCase for some early python versions -if not hasattr(unittest.TestCase, 'assertIn'): - def __assertIn(self, a, b, msg=None): - if a not in b: # pragma: no cover - self.fail(msg or "%r was not found in %r" % (a, b)) - unittest.TestCase.assertIn = __assertIn - def __assertNotIn(self, a, b, msg=None): - if a in b: # pragma: no cover - self.fail(msg or "%r was found in %r" % (a, b)) - unittest.TestCase.assertNotIn = __assertNotIn +# +# Forwards compatibility of unittest.TestCase for some early python versions +# + +if not hasattr(unittest.TestCase, 'assertRaisesRegexp'): + def assertRaisesRegexp(self, exccls, regexp, fun, *args, **kwargs): + try: + fun(*args, **kwargs) + except exccls as e: + if re.search(regexp, e.message) is None: + self.fail('\"%s\" does not match \"%s\"' % (regexp, e.message)) + else: + self.fail('%s not raised' % getattr(exccls, '__name__')) + unittest.TestCase.assertRaisesRegexp = assertRaisesRegexp + +# always custom following methods, because we use atm better version of both (support generators) +if True: ## if not hasattr(unittest.TestCase, 'assertIn'): + def assertIn(self, a, b, msg=None): + bb = b + wrap = False + if msg is None and hasattr(b, '__iter__') and not isinstance(b, basestring): + b, bb = itertools.tee(b) + wrap = True + if a not in b: + if wrap: bb = list(bb) + msg = msg or "%r was not found in %r" % (a, bb) + self.fail(msg) + unittest.TestCase.assertIn = assertIn + def assertNotIn(self, a, b, msg=None): + bb = b + wrap = False + if msg is None and hasattr(b, '__iter__') and not isinstance(b, basestring): + b, bb = itertools.tee(b) + wrap = True + if a in b: + if wrap: bb = list(bb) + msg = msg or "%r unexpectedly found in %r" % (a, bb) + self.fail(msg) + unittest.TestCase.assertNotIn = assertNotIn class LogCaptureTestCase(unittest.TestCase): @@ -241,6 +271,7 @@ class LogCaptureTestCase(unittest.TestCase): def tearDown(self): """Call after every test case.""" # print "O: >>%s<<" % self._log.getvalue() + self.pruneLog() logSys = getLogger("fail2ban") logSys.handlers = self._old_handlers logSys.level = self._old_level @@ -248,7 +279,7 @@ class LogCaptureTestCase(unittest.TestCase): def _is_logged(self, s): return s in self._log.getvalue() - def assertLogged(self, *s): + def assertLogged(self, *s, **kwargs): """Assert that one of the strings was logged Preferable to assertTrue(self._is_logged(..))) @@ -258,14 +289,23 @@ class LogCaptureTestCase(unittest.TestCase): ---------- s : string or list/set/tuple of strings Test should succeed if string (or any of the listed) is present in the log + all : boolean (default False) if True should fail if any of s not logged """ logged = self._log.getvalue() - for s_ in s: - if s_ in logged: - return - raise AssertionError("None among %r was found in the log: %r" % (s, logged)) + if not kwargs.get('all', False): + # at least one entry should be found: + for s_ in s: + if s_ in logged: + return + if True: # pragma: no cover + self.fail("None among %r was found in the log: ===\n%s===" % (s, logged)) + else: + # each entry should be found: + for s_ in s: + if s_ not in logged: # pragma: no cover + self.fail("%r was not found in the log: ===\n%s===" % (s_, logged)) - def assertNotLogged(self, *s): + def assertNotLogged(self, *s, **kwargs): """Assert that strings were not logged Parameters @@ -273,13 +313,22 @@ class LogCaptureTestCase(unittest.TestCase): s : string or list/set/tuple of strings Test should succeed if the string (or at least one of the listed) is not present in the log + all : boolean (default False) if True should fail if any of s logged """ logged = self._log.getvalue() - for s_ in s: - if s_ not in logged: - return - raise AssertionError("All of the %r were found present in the log: %r" % (s, logged)) + if not kwargs.get('all', False): + for s_ in s: + if s_ not in logged: + return + if True: # pragma: no cover + self.fail("All of the %r were found present in the log: ===\n%s===" % (s, logged)) + else: + for s_ in s: + if s_ in logged: # pragma: no cover + self.fail("%r was found in the log: ===\n%s===" % (s_, logged)) + def pruneLog(self): + self._log.truncate(0) def getLog(self): return self._log.getvalue()