diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index 560fbfe5..9f562511 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -293,8 +293,12 @@ class Fail2BanDb(object): Jail to be added to the database. """ cur.execute( - "INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)", + "INSERT OR IGNORE INTO jails(name, enabled) VALUES(?, 1)", (jail.name,)) + if cur.rowcount <= 0: + cur.execute( + "UPDATE jails SET enabled = 1 WHERE name = ? AND enabled != 1", + (jail.name,)) @commitandrollback def delJail(self, cur, jail): @@ -317,7 +321,7 @@ class Fail2BanDb(object): cur.execute("UPDATE jails SET enabled=0") @commitandrollback - def getJailNames(self, cur): + def getJailNames(self, cur, enabled=None): """Get name of jails in database. Currently only used for testing purposes. @@ -327,7 +331,11 @@ class Fail2BanDb(object): set Set of jail names. """ - cur.execute("SELECT name FROM jails") + if enabled is None: + cur.execute("SELECT name FROM jails") + else: + cur.execute("SELECT name FROM jails WHERE enabled=%s" % + (int(enabled),)) return set(row[0] for row in cur.fetchmany()) @commitandrollback diff --git a/fail2ban/tests/utils.py b/fail2ban/tests/utils.py index 7bccf6e2..e091c935 100644 --- a/fail2ban/tests/utils.py +++ b/fail2ban/tests/utils.py @@ -271,6 +271,7 @@ class LogCaptureTestCase(unittest.TestCase): def tearDown(self): """Call after every test case.""" # print "O: >>%s<<" % self._log.getvalue() + self.pruneLog() logSys = getLogger("fail2ban") logSys.handlers = self._old_handlers logSys.level = self._old_level @@ -278,7 +279,7 @@ class LogCaptureTestCase(unittest.TestCase): def _is_logged(self, s): return s in self._log.getvalue() - def assertLogged(self, *s): + def assertLogged(self, *s, **kwargs): """Assert that one of the strings was logged Preferable to assertTrue(self._is_logged(..))) @@ -288,14 +289,23 @@ class LogCaptureTestCase(unittest.TestCase): ---------- s : string or list/set/tuple of strings Test should succeed if string (or any of the listed) is present in the log + all : boolean (default False) if True should fail if any of s not logged """ logged = self._log.getvalue() - for s_ in s: - if s_ in logged: - return - raise AssertionError("None among %r was found in the log: %r" % (s, logged)) + if not kwargs.get('all', False): + # at least one entry should be found: + for s_ in s: + if s_ in logged: + return + if True: # pragma: no cover + self.fail("None among %r was found in the log: ===\n%s===" % (s, logged)) + else: + # each entry should be found: + for s_ in s: + if s_ not in logged: # pragma: no cover + self.fail("%r was not found in the log: ===\n%s===" % (s_, logged)) - def assertNotLogged(self, *s): + def assertNotLogged(self, *s, **kwargs): """Assert that strings were not logged Parameters @@ -303,13 +313,22 @@ class LogCaptureTestCase(unittest.TestCase): s : string or list/set/tuple of strings Test should succeed if the string (or at least one of the listed) is not present in the log + all : boolean (default False) if True should fail if any of s logged """ logged = self._log.getvalue() - for s_ in s: - if s_ not in logged: - return - raise AssertionError("All of the %r were found present in the log: %r" % (s, logged)) + if not kwargs.get('all', False): + for s_ in s: + if s_ not in logged: + return + if True: # pragma: no cover + self.fail("All of the %r were found present in the log: ===\n%s===" % (s, logged)) + else: + for s_ in s: + if s_ in logged: # pragma: no cover + self.fail("%r was found in the log: ===\n%s===" % (s_, logged)) + def pruneLog(self): + self._log.truncate(0) def getLog(self): return self._log.getvalue()