From 817852f0831fbaa55098d78a00e832bcadbfc7d7 Mon Sep 17 00:00:00 2001 From: sebres Date: Fri, 17 Jul 2015 18:34:13 +0200 Subject: [PATCH] test cases extended, code review (+ python 3.x compatibility); database test cases extended - enable deleted (disabled) jail in addJail; --- MANIFEST | 1 + fail2ban/server/database.py | 8 +- fail2ban/server/filter.py | 4 +- fail2ban/server/server.py | 2 +- fail2ban/server/ticket.py | 21 ++-- fail2ban/server/utils.py | 5 +- fail2ban/tests/banmanagertestcase.py | 51 +++++++- fail2ban/tests/databasetestcase.py | 21 +++- fail2ban/tests/filtertestcase.py | 56 ++++++++- fail2ban/tests/tickettestcase.py | 176 +++++++++++++++++++++++++++ fail2ban/tests/utils.py | 4 + 11 files changed, 325 insertions(+), 24 deletions(-) create mode 100644 fail2ban/tests/tickettestcase.py diff --git a/MANIFEST b/MANIFEST index 7306cc41..fb70bb4b 100644 --- a/MANIFEST +++ b/MANIFEST @@ -331,6 +331,7 @@ fail2ban/tests/misctestcase.py fail2ban/tests/samplestestcase.py fail2ban/tests/servertestcase.py fail2ban/tests/sockettestcase.py +fail2ban/tests/tickettestcase.py fail2ban/tests/utils.py fail2ban/version.py files/bash-completion diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index c6e8c95c..3b419ed3 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -302,7 +302,7 @@ class Fail2BanDb(object): cur.execute("UPDATE jails SET enabled=0") @commitandrollback - def getJailNames(self, cur): + def getJailNames(self, cur, enabled=None): """Get name of jails in database. Currently only used for testing purposes. @@ -312,7 +312,11 @@ class Fail2BanDb(object): set Set of jail names. """ - cur.execute("SELECT name FROM jails") + if enabled is None: + cur.execute("SELECT name FROM jails") + else: + cur.execute("SELECT name FROM jails WHERE enabled=%s" % + (int(enabled),)) return set(row[0] for row in cur.fetchmany()) @commitandrollback diff --git a/fail2ban/server/filter.py b/fail2ban/server/filter.py index 7b917676..1360c60d 100644 --- a/fail2ban/server/filter.py +++ b/fail2ban/server/filter.py @@ -951,8 +951,8 @@ class DNSUtils: IP_CRE = re.compile("^(?:\d{1,3}\.){3}\d{1,3}$") # todo: make configurable the expired time and max count of cache entries: - CACHE_dnsToIp = Utils.Cache(maxCount=1000, maxTime=60*60) - CACHE_ipToName = Utils.Cache(maxCount=1000, maxTime=60*60) + CACHE_dnsToIp = Utils.Cache(maxCount=1000, maxTime=5*60) + CACHE_ipToName = Utils.Cache(maxCount=1000, maxTime=5*60) @staticmethod def dnsToIp(dns): diff --git a/fail2ban/server/server.py b/fail2ban/server/server.py index 8cf6ff94..8dc1b761 100644 --- a/fail2ban/server/server.py +++ b/fail2ban/server/server.py @@ -44,7 +44,7 @@ logSys = getLogger(__name__) try: from .database import Fail2BanDb -except ImportError: +except ImportError: # pragma: no cover # Dont print error here, as database may not even be used Fail2BanDb = None diff --git a/fail2ban/server/ticket.py b/fail2ban/server/ticket.py index 49ebf9ea..e856b66d 100644 --- a/fail2ban/server/ticket.py +++ b/fail2ban/server/ticket.py @@ -146,16 +146,17 @@ class Ticket: # return default if not exists: if not self._data: return default - # return filtered by lambda/function: - if callable(key): - # todo: if support >= 2.7 only: - # return {k:v for k,v in self._data.iteritems() if key(k)} - return dict([(k,v) for k,v in self._data.iteritems() if key(k)]) - # return filtered by keys: - if hasattr(key, '__iter__'): - # todo: if support >= 2.7 only: - # return {k:v for k,v in self._data.iteritems() if k in key} - return dict([(k,v) for k,v in self._data.iteritems() if k in key]) + if not isinstance(key,(str,unicode,type(None),int,float,bool,complex)): + # return filtered by lambda/function: + if callable(key): + # todo: if support >= 2.7 only: + # return {k:v for k,v in self._data.iteritems() if key(k)} + return dict([(k,v) for k,v in self._data.iteritems() if key(k)]) + # return filtered by keys: + if hasattr(key, '__iter__'): + # todo: if support >= 2.7 only: + # return {k:v for k,v in self._data.iteritems() if k in key} + return dict([(k,v) for k,v in self._data.iteritems() if k in key]) # return single value of data: return self._data.get(key, default) diff --git a/fail2ban/server/utils.py b/fail2ban/server/utils.py index 7e69ddca..262b303d 100644 --- a/fail2ban/server/utils.py +++ b/fail2ban/server/utils.py @@ -52,7 +52,10 @@ class Utils(): class Cache(dict): - def __init__(self, maxCount=1000, maxTime=60*60): + def __init__(self, *args, **kwargs): + self.setOptions(*args, **kwargs) + + def setOptions(self, maxCount=1000, maxTime=60): self.maxCount = maxCount self.maxTime = maxTime diff --git a/fail2ban/tests/banmanagertestcase.py b/fail2ban/tests/banmanagertestcase.py index a2d399b3..9e865a1b 100644 --- a/fail2ban/tests/banmanagertestcase.py +++ b/fail2ban/tests/banmanagertestcase.py @@ -35,27 +35,74 @@ class AddFailure(unittest.TestCase): """Call before every test case.""" self.__ticket = BanTicket('193.168.0.128', 1167605999.0) self.__banManager = BanManager() - self.assertTrue(self.__banManager.addBanTicket(self.__ticket)) def tearDown(self): """Call after every test case.""" pass def testAdd(self): + self.assertTrue(self.__banManager.addBanTicket(self.__ticket)) self.assertEqual(self.__banManager.size(), 1) - + self.assertEqual(self.__banManager.getBanTotal(), 1) + self.__banManager.setBanTotal(0) + self.assertEqual(self.__banManager.getBanTotal(), 0) + def testAddDuplicate(self): + self.assertTrue(self.__banManager.addBanTicket(self.__ticket)) self.assertFalse(self.__banManager.addBanTicket(self.__ticket)) self.assertEqual(self.__banManager.size(), 1) + def testAddDuplicateWithTime(self): + # add again a duplicate : + # 1) with newer start time and the same ban time + # 2) with same start time and longer ban time + # 3) with permanent ban time (-1) + for tnew, btnew in ( + (1167605999.0 + 100, None), + (1167605999.0, 24*60*60), + (1167605999.0, -1), + ): + ticket1 = BanTicket('193.168.0.128', 1167605999.0) + ticket2 = BanTicket('193.168.0.128', tnew) + if btnew is not None: + ticket2.setBanTime(btnew) + self.assertTrue(self.__banManager.addBanTicket(ticket1)) + self.assertFalse(self.__banManager.addBanTicket(ticket2)) + self.assertEqual(self.__banManager.size(), 1) + # pop ticket and check it was prolonged : + banticket = self.__banManager.getTicketByIP(ticket2.getIP()) + self.assertEqual(banticket.getTime(), ticket2.getTime()) + self.assertEqual(banticket.getTime(), ticket2.getTime()) + self.assertEqual(banticket.getBanTime(), ticket2.getBanTime(self.__banManager.getBanTime())) + def testInListOK(self): + self.assertTrue(self.__banManager.addBanTicket(self.__ticket)) ticket = BanTicket('193.168.0.128', 1167605999.0) self.assertTrue(self.__banManager._inBanList(ticket)) def testInListNOK(self): + self.assertTrue(self.__banManager.addBanTicket(self.__ticket)) ticket = BanTicket('111.111.1.111', 1167605999.0) self.assertFalse(self.__banManager._inBanList(ticket)) + def testUnban(self): + btime = self.__banManager.getBanTime() + self.assertTrue(self.__banManager.addBanTicket(self.__ticket)) + self.assertTrue(self.__banManager._inBanList(self.__ticket)) + self.assertEqual(self.__banManager.unBanList(self.__ticket.getTime() + btime + 1), [self.__ticket]) + self.assertEqual(self.__banManager.size(), 0) + + def testUnbanPermanent(self): + btime = self.__banManager.getBanTime() + self.__banManager.setBanTime(-1) + try: + self.assertTrue(self.__banManager.addBanTicket(self.__ticket)) + self.assertTrue(self.__banManager._inBanList(self.__ticket)) + self.assertEqual(self.__banManager.unBanList(self.__ticket.getTime() + btime + 1), []) + self.assertEqual(self.__banManager.size(), 1) + finally: + self.__banManager.setBanTime(btime) + class StatusExtendedCymruInfo(unittest.TestCase): def setUp(self): diff --git a/fail2ban/tests/databasetestcase.py b/fail2ban/tests/databasetestcase.py index 20baa847..3f0e4c10 100644 --- a/fail2ban/tests/databasetestcase.py +++ b/fail2ban/tests/databasetestcase.py @@ -124,7 +124,7 @@ class DatabaseTest(LogCaptureTestCase): self.jail = DummyJail() self.db.addJail(self.jail) self.assertTrue( - self.jail.name in self.db.getJailNames(), + self.jail.name in self.db.getJailNames(True), "Jail not added to database") def testAddLog(self): @@ -332,6 +332,25 @@ class DatabaseTest(LogCaptureTestCase): actions._Actions__checkBan() self.assertLogged("ban ainfo %s, %s, %s, %s" % (True, True, True, True)) + def testDelAndAddJail(self): + self.testAddJail() # Add jail + # Delete jail (just disabled it): + self.db.delJail(self.jail) + jails = self.db.getJailNames() + self.assertTrue(len(jails) == 1 and self.jail.name in jails) + jails = self.db.getJailNames(enabled=False) + self.assertTrue(len(jails) == 1 and self.jail.name in jails) + jails = self.db.getJailNames(enabled=True) + self.assertTrue(len(jails) == 0) + # Add it again - should just enable it: + self.db.addJail(self.jail) + jails = self.db.getJailNames() + self.assertTrue(len(jails) == 1 and self.jail.name in jails) + jails = self.db.getJailNames(enabled=True) + self.assertTrue(len(jails) == 1 and self.jail.name in jails) + jails = self.db.getJailNames(enabled=False) + self.assertTrue(len(jails) == 0) + def testPurge(self): if Fail2BanDb is None: # pragma: no cover return diff --git a/fail2ban/tests/filtertestcase.py b/fail2ban/tests/filtertestcase.py index c211d642..0174d0ec 100644 --- a/fail2ban/tests/filtertestcase.py +++ b/fail2ban/tests/filtertestcase.py @@ -575,18 +575,18 @@ def get_monitor_failures_testcase(Filter_): def isFilled(self, delay=1.): """Wait up to `delay` sec to assure that it was modified or not """ - return Utils.wait_for(lambda: self.jail.isFilled(), delay) + return Utils.wait_for(self.jail.isFilled, delay) def _sleep_4_poll(self): # Since FilterPoll relies on time stamps and some # actions might be happening too fast in the tests, # sleep a bit to guarantee reliable time stamps if isinstance(self.filter, FilterPoll): - Utils.wait_for(lambda: self.filter.isAlive(), 4*Utils.DEFAULT_SLEEP_TIME) + Utils.wait_for(self.filter.isAlive, 4*Utils.DEFAULT_SLEEP_TIME) def isEmpty(self, delay=4*Utils.DEFAULT_SLEEP_TIME): # shorter wait time for not modified status - return Utils.wait_for(lambda: self.jail.isEmpty(), delay) + return Utils.wait_for(self.jail.isEmpty, delay) def assert_correct_last_attempt(self, failures, count=None): self.assertTrue(self.isFilled(20)) # give Filter a chance to react @@ -777,11 +777,11 @@ def get_monitor_failures_journal_testcase(Filter_): # pragma: systemd no cover def isFilled(self, delay=1.): """Wait up to `delay` sec to assure that it was modified or not """ - return Utils.wait_for(lambda: self.jail.isFilled(), delay) + return Utils.wait_for(self.jail.isFilled, delay) def isEmpty(self, delay=4*Utils.DEFAULT_SLEEP_TIME): # shorter wait time for not modified status - return Utils.wait_for(lambda: self.jail.isEmpty(), delay) + return Utils.wait_for(self.jail.isEmpty, delay) def assert_correct_ban(self, test_ip, test_attempts): self.assertTrue(self.isFilled(10)) # give Filter a chance to react @@ -1058,6 +1058,52 @@ class GetFailures(unittest.TestCase): class DNSUtilsTests(unittest.TestCase): + def testCache(self): + c = Utils.Cache(maxCount=5, maxTime=60) + # not available : + self.assertTrue(c.get('a') is None) + self.assertEqual(c.get('a', 'test'), 'test') + # exact 5 elements : + for i in xrange(5): + c.set(i, i) + for i in xrange(5): + self.assertEqual(c.get(i), i) + + def testCacheMaxSize(self): + c = Utils.Cache(maxCount=5, maxTime=60) + # exact 5 elements : + for i in xrange(5): + c.set(i, i) + self.assertEqual([c.get(i) for i in xrange(5)], [i for i in xrange(5)]) + self.assertFalse(-1 in [c.get(i, -1) for i in xrange(5)]) + # add one - too many: + c.set(10, i) + # one element should be removed : + self.assertTrue(-1 in [c.get(i, -1) for i in xrange(5)]) + # test max size (not expired): + for i in xrange(10): + c.set(i, 1) + self.assertEqual(len(c), 5) + + def testCacheMaxTime(self): + # test max time (expired, timeout reached) : + c = Utils.Cache(maxCount=5, maxTime=0.0005) + for i in xrange(10): + c.set(i, 1) + st = time.time() + self.assertTrue(Utils.wait_for(lambda: time.time() >= st + 0.0005, 1)) + # we have still 5 elements (or fewer if too slow test mashine): + self.assertTrue(len(c) <= 5) + # but all that are expiered also: + for i in xrange(10): + self.assertTrue(c.get(i) is None) + # here the whole cache should be empty: + self.assertEqual(len(c), 0) + + + +class DNSUtilsNetworkTests(unittest.TestCase): + def setUp(self): """Call before every test case.""" unittest.F2B.SkipIfNoNetwork() diff --git a/fail2ban/tests/tickettestcase.py b/fail2ban/tests/tickettestcase.py new file mode 100644 index 00000000..68a44bb5 --- /dev/null +++ b/fail2ban/tests/tickettestcase.py @@ -0,0 +1,176 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*- +# vi: set ft=python sts=4 ts=4 sw=4 noet : + +# This file is part of Fail2Ban. +# +# Fail2Ban is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# Fail2Ban is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Fail2Ban; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + + +__author__ = "Serg G. Brester (sebres)" +__copyright__ = "Copyright (c) 2015 Serg G. Brester, 2015- Fail2Ban Contributors" +__license__ = "GPL" + +from ..server.mytime import MyTime +import unittest + +from ..server.ticket import Ticket, FailTicket, BanTicket + + +class TicketTests(unittest.TestCase): + + def testTicket(self): + + tm = MyTime.time() + matches = ['first', 'second'] + matches2 = ['first', 'second'] + matches3 = ['first', 'second', 'third'] + + # Ticket + t = Ticket('193.168.0.128', tm, matches) + self.assertEqual(t.getIP(), '193.168.0.128') + self.assertEqual(t.getTime(), tm) + self.assertEqual(t.getMatches(), matches2) + t.setAttempt(2) + self.assertEqual(t.getAttempt(), 2) + t.setBanCount(10) + self.assertEqual(t.getBanCount(), 10) + # default ban time (from manager): + self.assertEqual(t.getBanTime(60*60), 60*60) + self.assertFalse(t.isTimedOut(tm + 60 + 1, 60*60)) + self.assertTrue(t.isTimedOut(tm + 60*60 + 1, 60*60)) + t.setBanTime(60) + self.assertEqual(t.getBanTime(60*60), 60) + self.assertEqual(t.getBanTime(), 60) + self.assertFalse(t.isTimedOut(tm)) + self.assertTrue(t.isTimedOut(tm + 60 + 1)) + # permanent : + t.setBanTime(-1) + self.assertFalse(t.isTimedOut(tm + 60 + 1)) + t.setBanTime(60) + + # BanTicket + tm = MyTime.time() + matches = ['first', 'second'] + ft = FailTicket('193.168.0.128', tm, matches) + ft.setBanTime(60*60) + self.assertEqual(ft.getIP(), '193.168.0.128') + self.assertEqual(ft.getTime(), tm) + self.assertEqual(ft.getMatches(), matches2) + ft.setAttempt(2) + self.assertEqual(ft.getAttempt(), 2) + # retry is max of set retry and failures: + self.assertEqual(ft.getRetry(), 2) + ft.setRetry(1) + self.assertEqual(ft.getRetry(), 2) + ft.setRetry(3) + self.assertEqual(ft.getRetry(), 3) + ft.inc() + self.assertEqual(ft.getAttempt(), 3) + self.assertEqual(ft.getRetry(), 4) + self.assertEqual(ft.getMatches(), matches2) + # with 1 match, 1 failure and factor 10 (retry count) : + ft.inc(['third'], 1, 10) + self.assertEqual(ft.getAttempt(), 4) + self.assertEqual(ft.getRetry(), 14) + self.assertEqual(ft.getMatches(), matches3) + # last time (ignore if smaller as time): + self.assertEqual(ft.getLastTime(), tm) + ft.setLastTime(tm-60) + self.assertEqual(ft.getTime(), tm) + self.assertEqual(ft.getLastTime(), tm) + ft.setLastTime(tm+60) + self.assertEqual(ft.getTime(), tm+60) + self.assertEqual(ft.getLastTime(), tm+60) + ft.setData('country', 'DE') + self.assertEqual(ft.getData(), + {'matches': ['first', 'second', 'third'], 'failures': 4, 'country': 'DE'}) + + # copy all from another ticket: + ft2 = FailTicket(ticket=ft) + self.assertEqual(ft, ft2) + self.assertEqual(ft.getData(), ft2.getData()) + self.assertEqual(ft2.getAttempt(), 4) + self.assertEqual(ft2.getRetry(), 14) + self.assertEqual(ft2.getMatches(), matches3) + self.assertEqual(ft2.getTime(), ft.getTime()) + self.assertEqual(ft2.getLastTime(), ft.getLastTime()) + self.assertEqual(ft2.getBanTime(), ft.getBanTime()) + + def testTicketData(self): + t = BanTicket('193.168.0.128', None, ['first', 'second']) + # expand data (no overwrites, matches are available) : + t.setData('region', 'Hamburg', 'country', 'DE', 'city', 'Hamburg') + self.assertEqual( + t.getData(), + {'matches': ['first', 'second'], 'failures':0, 'region': 'Hamburg', 'country': 'DE', 'city': 'Hamburg'}) + # at once as dict (single argument, overwrites it completelly, no more matches/failures) : + t.setData({'region': None, 'country': 'FR', 'city': 'Paris'},) + self.assertEqual( + t.getData(), + {'city': 'Paris', 'country': 'FR'}) + # at once as dict (overwrites it completelly, no more matches/failures) : + t.setData({'region': 'Hamburg', 'country': 'DE', 'city': None}) + self.assertEqual( + t.getData(), + {'region': 'Hamburg', 'country': 'DE'}) + self.assertEqual( + t.getData('region'), + 'Hamburg') + self.assertEqual( + t.getData('country'), + 'DE') + # again, named arguments: + t.setData(region='Bremen', city='Bremen') + self.assertEqual(t.getData(), + {'region': 'Bremen', 'country': 'DE', 'city': 'Bremen'}) + # again, but as args (key value pair): + t.setData('region', 'Brandenburg', 'city', 'Berlin') + self.assertEqual( + t.getData('region'), + 'Brandenburg') + self.assertEqual( + t.getData('city'), + 'Berlin') + self.assertEqual( + t.getData(), + {'city':'Berlin', 'region': 'Brandenburg', 'country': 'DE'}) + # interator filter : + self.assertEqual( + t.getData(('city', 'country')), + {'city':'Berlin', 'country': 'DE'}) + # callable filter : + self.assertEqual( + t.getData(lambda k: k.upper() == 'COUNTRY'), + {'country': 'DE'}) + # remove one data entry: + t.setData('city', None) + self.assertEqual( + t.getData(), + {'region': 'Brandenburg', 'country': 'DE'}) + # default if not available: + self.assertEqual( + t.getData('city', 'Unknown'), + 'Unknown') + # add continent : + t.setData('continent', 'Europe') + # again, but as argument list (overwrite new only, leave continent unchanged) : + t.setData(*['country', 'RU', 'region', 'Moscow']) + self.assertEqual( + t.getData(), + {'continent': 'Europe', 'country': 'RU', 'region': 'Moscow'}) + # clear: + t.setData({}) + self.assertEqual(t.getData(), {}) + self.assertEqual(t.getData('anything', 'default'), 'default') diff --git a/fail2ban/tests/utils.py b/fail2ban/tests/utils.py index f2b6f37b..32bfba50 100644 --- a/fail2ban/tests/utils.py +++ b/fail2ban/tests/utils.py @@ -115,6 +115,7 @@ def gatherTests(regexps=None, opts=None): # avoid circular imports from . import banmanagertestcase from . import clientreadertestcase + from . import tickettestcase from . import failmanagertestcase from . import filtertestcase from . import servertestcase @@ -149,6 +150,8 @@ def gatherTests(regexps=None, opts=None): tests.addTest(unittest.makeSuite(servertestcase.LoggingTests)) tests.addTest(unittest.makeSuite(actiontestcase.CommandActionTest)) tests.addTest(unittest.makeSuite(actionstestcase.ExecuteActions)) + # Ticket, BanTicket, FailTicket + tests.addTest(unittest.makeSuite(tickettestcase.TicketTests)) # FailManager tests.addTest(unittest.makeSuite(failmanagertestcase.AddFailure)) # BanManager @@ -186,6 +189,7 @@ def gatherTests(regexps=None, opts=None): tests.addTest(unittest.makeSuite(filtertestcase.IgnoreIPDNS)) tests.addTest(unittest.makeSuite(filtertestcase.GetFailures)) tests.addTest(unittest.makeSuite(filtertestcase.DNSUtilsTests)) + tests.addTest(unittest.makeSuite(filtertestcase.DNSUtilsNetworkTests)) tests.addTest(unittest.makeSuite(filtertestcase.JailTests)) # DateDetector