diff --git a/fail2ban/tests/filtertestcase.py b/fail2ban/tests/filtertestcase.py index fe64e55e..b6faf836 100644 --- a/fail2ban/tests/filtertestcase.py +++ b/fail2ban/tests/filtertestcase.py @@ -695,6 +695,32 @@ class LogFileMonitor(LogCaptureTestCase): self.assertEqual(self.filter.failManager.getFailTotal(), 3) +class CommonMonitorTestCase(unittest.TestCase): + + def setUp(self): + """Call before every test case.""" + self._failTotal = 0 + + def waitFailTotal(self, count, delay=1.): + """Wait up to `delay` sec to assure that expected failure `count` reached + """ + ret = Utils.wait_for( + lambda: self.filter.failManager.getFailTotal() >= self._failTotal + count and self.jail.isFilled(), + _maxWaitTime(delay)) + self._failTotal += count + return ret + + def isFilled(self, delay=1.): + """Wait up to `delay` sec to assure that it was modified or not + """ + return Utils.wait_for(self.jail.isFilled, _maxWaitTime(delay)) + + def isEmpty(self, delay=_maxWaitTime(5)): + """Wait up to `delay` sec to assure that it empty again + """ + return Utils.wait_for(self.jail.isEmpty, _maxWaitTime(delay)) + + def get_monitor_failures_testcase(Filter_): """Generator of TestCase's for different filters/backends """ @@ -703,11 +729,12 @@ def get_monitor_failures_testcase(Filter_): testclass_name = tempfile.mktemp( 'fail2ban', 'monitorfailures_%s' % (Filter_.__name__,)) - class MonitorFailures(unittest.TestCase): + class MonitorFailures(CommonMonitorTestCase): count = 0 def setUp(self): """Call before every test case.""" + super(MonitorFailures, self).setUp() setUpMyTime() self.filter = self.name = 'NA' self.name = '%s-%d' % (testclass_name, self.count) @@ -737,11 +764,6 @@ def get_monitor_failures_testcase(Filter_): #time.sleep(0.2) # Give FS time to ack the removal pass - def isFilled(self, delay=1.): - """Wait up to `delay` sec to assure that it was modified or not - """ - return Utils.wait_for(self.jail.isFilled, _maxWaitTime(delay)) - def _sleep_4_poll(self): # Since FilterPoll relies on time stamps and some # actions might be happening too fast in the tests, @@ -749,10 +771,6 @@ def get_monitor_failures_testcase(Filter_): if isinstance(self.filter, FilterPoll): Utils.wait_for(self.filter.isAlive, _maxWaitTime(5)) - def isEmpty(self, delay=_maxWaitTime(5)): - # shorter wait time for not modified status - return Utils.wait_for(self.jail.isEmpty, _maxWaitTime(delay)) - def assert_correct_last_attempt(self, failures, count=None): self.assertTrue(self.isFilled(10)) # give Filter a chance to react _assert_correct_last_attempt(self, self.jail, failures, count=count) @@ -906,13 +924,13 @@ def get_monitor_failures_journal_testcase(Filter_): # pragma: systemd no cover """Generator of TestCase's for journal based filters/backends """ - class MonitorJournalFailures(unittest.TestCase): + class MonitorJournalFailures(CommonMonitorTestCase): def setUp(self): """Call before every test case.""" + super(MonitorJournalFailures, self).setUp() self.test_file = os.path.join(TEST_FILES_DIR, "testcase-journal.log") self.jail = DummyJail() self.filter = Filter_(self.jail) - self._failTotal = 0 # UUID used to ensure that only meeages generated # as part of this test are picked up by the filter self.test_uuid = str(uuid.uuid4()) @@ -940,24 +958,6 @@ def get_monitor_failures_journal_testcase(Filter_): # pragma: systemd no cover return "MonitorJournalFailures%s(%s)" \ % (Filter_, hasattr(self, 'name') and self.name or 'tempfile') - def waitFailTotal(self, count, delay=1.): - """Wait up to `delay` sec to assure that expected failure `count` reached - """ - ret = Utils.wait_for( - lambda: self.filter.failManager.getFailTotal() >= self._failTotal + count and self.jail.isFilled(), - _maxWaitTime(delay)) - self._failTotal += count - return ret - - def isFilled(self, delay=1.): - """Wait up to `delay` sec to assure that it was modified or not - """ - return Utils.wait_for(self.jail.isFilled, _maxWaitTime(delay)) - - def isEmpty(self, delay=_maxWaitTime(5)): - # shorter wait time for not modified status - return Utils.wait_for(self.jail.isEmpty, _maxWaitTime(delay)) - def assert_correct_ban(self, test_ip, test_attempts): self.assertTrue(self.waitFailTotal(test_attempts, 10)) # give Filter a chance to react ticket = self.jail.getFailTicket()