diff --git a/fail2ban/server/actions.py b/fail2ban/server/actions.py index 428d1ccf..1b72af91 100644 --- a/fail2ban/server/actions.py +++ b/fail2ban/server/actions.py @@ -181,7 +181,7 @@ class Actions(JailThread, Mapping): def getBanTime(self): return self.__banManager.getBanTime() - def removeBannedIP(self, ipstr): + def removeBannedIP(self, ip): """Removes banned IP calling actions' unban method Remove a banned IP now, rather than waiting for it to expire, @@ -189,16 +189,14 @@ class Actions(JailThread, Mapping): Parameters ---------- - ipstr : str - The IP address string to unban + ip : str or IPAddr + The IP address to unban Raises ------ ValueError If `ip` is not banned """ - # Create new IPAddr object from IP string - ip = IPAddr(ipstr) # Always delete ip from database (also if currently not banned) if self._jail.database is not None: self._jail.database.delBan(self._jail, ip) diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index e5e51992..6a3d87c3 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -32,7 +32,6 @@ from threading import RLock from .mytime import MyTime from .ticket import FailTicket -from .filter import IPAddr from ..helpers import getLogger # Gets the instance of the logger. @@ -412,18 +411,19 @@ class Fail2BanDb(object): ticket : BanTicket Ticket of the ban to be added. """ + ip = str(ticket.getIP()) try: - del self._bansMergedCache[(ticket.getIP(), jail)] + del self._bansMergedCache[(ip, jail)] except KeyError: pass try: - del self._bansMergedCache[(ticket.getIP(), None)] + del self._bansMergedCache[(ip, None)] except KeyError: pass #TODO: Implement data parts once arbitrary match keys completed cur.execute( "INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)", - (jail.name, ticket.getIP().ntoa(), int(round(ticket.getTime())), + (jail.name, ip, int(round(ticket.getTime())), ticket.getData())) @commitandrollback @@ -437,7 +437,7 @@ class Fail2BanDb(object): ip : str IP to be removed. """ - queryArgs = (jail.name, ip.ntoa()); + queryArgs = (jail.name, str(ip)); cur.execute( "DELETE FROM bans WHERE jail = ? AND ip = ?", queryArgs); @@ -455,7 +455,7 @@ class Fail2BanDb(object): queryArgs.append(MyTime.time() - bantime) if ip is not None: query += " AND ip=?" - queryArgs.append(ip.ntoa()) + queryArgs.append(ip) query += " ORDER BY ip, timeofban desc" return cur.execute(query, queryArgs) @@ -471,7 +471,7 @@ class Fail2BanDb(object): Ban time in seconds, such that bans returned would still be valid now. Negative values are equivalent to `None`. Default `None`; no limit. - ip : IPAddr object + ip : str IP Address to filter bans by. Default `None`; all IPs. Returns @@ -480,8 +480,7 @@ class Fail2BanDb(object): List of `Ticket`s for bans stored in database. """ tickets = [] - for ipstr, timeofban, data in self._getBans(**kwargs): - ip = IPAddr(ipstr) + for ip, timeofban, data in self._getBans(**kwargs): #TODO: Implement data parts once arbitrary match keys completed tickets.append(FailTicket(ip, timeofban)) tickets[-1].setData(data) @@ -501,7 +500,7 @@ class Fail2BanDb(object): Ban time in seconds, such that bans returned would still be valid now. Negative values are equivalent to `None`. Default `None`; no limit. - ip : IPAddr object + ip : str IP Address to filter bans by. Default `None`; all IPs. Returns @@ -522,8 +521,6 @@ class Fail2BanDb(object): ticket = None results = list(self._getBans(ip=ip, jail=jail, bantime=bantime)) - # Convert IP strings to IPAddr objects - results = map(lambda i:(IPAddr(i[0]),)+i[1:], results) if results: prev_banip = results[0][0] matches = [] diff --git a/fail2ban/server/failmanager.py b/fail2ban/server/failmanager.py index ae97b36a..b342b280 100644 --- a/fail2ban/server/failmanager.py +++ b/fail2ban/server/failmanager.py @@ -54,6 +54,15 @@ class FailManager: with self.__lock: return self.__failTotal + def getFailCount(self): + # may be slow on large list of failures, should be used for test purposes only... + with self.__lock: + return len(self.__failList), sum([f.getRetry() for f in self.__failList.values()]) + + def getFailTotal(self): + with self.__lock: + return self.__failTotal + def setMaxRetry(self, value): self.__maxRetry = value diff --git a/fail2ban/server/filter.py b/fail2ban/server/filter.py index 371a13d7..04f34821 100644 --- a/fail2ban/server/filter.py +++ b/fail2ban/server/filter.py @@ -306,19 +306,15 @@ class Filter(JailThread): def getIgnoreCommand(self): return self.__ignoreCommand - ## - # create new IPAddr object from IP address string - def newIP(self, ipstr): - return IPAddr(ipstr) - ## # Ban an IP - http://blogs.buanzo.com.ar/2009/04/fail2ban-patch-ban-ip-address-manually.html # Arturo 'Buanzo' Busleiman # # to enable banip fail2ban-client BAN command - def addBannedIP(self, ipstr): - ip = IPAddr(ipstr) + def addBannedIP(self, ip): + if not isinstance(ip, IPAddr): + ip = IPAddr(ip) if self.inIgnoreIPList(ip): logSys.warning('Requested to manually ban an ignored IP %s. User knows best. Proceeding to ban it.' % ip) @@ -358,11 +354,11 @@ class Filter(JailThread): ip = IPAddr(s[0], s[1]) # log and append to ignore list - logSys.debug("Add " + ip + " to ignore list") + logSys.debug("Add %r to ignore list (%r, %r)", ip, s[0], s[1]) self.__ignoreIpList.append(ip) def delIgnoreIP(self, ip): - logSys.debug("Remove " + ip + " from ignore list") + logSys.debug("Remove %r from ignore list", ip) self.__ignoreIpList.remove(ip) def logIgnoreIp(self, ip, log_ignore, ignore_source="unknown source"): @@ -384,18 +380,9 @@ class Filter(JailThread): if not isinstance(ip, IPAddr): ip = IPAddr(ip) for net in self.__ignoreIpList: - # if it isn't a valid IP address, try DNS resolution - if not net.isValidIP() and net.getRaw() != "": - # Check if IP in DNS - ips = DNSUtils.dnsToIp(net.getRaw()) - if ip in ips: - self.logIgnoreIp(ip, log_ignore, ignore_source="dns") - return True - else: - continue # check if the IP is covered by ignore IP if ip.isInNet(net): - self.logIgnoreIp(ip, log_ignore, ignore_source="ip") + self.logIgnoreIp(ip, log_ignore, ignore_source=("ip" if net.isValidIP() else "dns")) return True if self.__ignoreCommand: @@ -1006,8 +993,6 @@ from .utils import Utils class DNSUtils: - IP_CRE = re.compile("^(?:\d{1,3}\.){3}\d{1,3}$") - # todo: make configurable the expired time and max count of cache entries: CACHE_nameToIp = Utils.Cache(maxCount=1000, maxTime=5*60) CACHE_ipToName = Utils.Cache(maxCount=1000, maxTime=5*60) @@ -1053,17 +1038,6 @@ class DNSUtils: DNSUtils.CACHE_ipToName.set(ip, v) return v - @staticmethod - def searchIP(text): - """ Search if an IP address if directly available and return - it. - """ - match = DNSUtils.IP_CRE.match(text) - if match: - return match - else: - return None - @staticmethod def isValidIP(string): """ Return true if str is a valid IP @@ -1082,7 +1056,7 @@ class DNSUtils: ipList = list() # Search for plain IP plainIP = IPAddr.searchIP(text) - if not plainIP is None: + if plainIP is not None: ip = IPAddr(plainIP.group(0)) if ip.isValidIP(): ipList.append(ip) @@ -1122,7 +1096,7 @@ class DNSUtils: # # This class contains methods for handling IPv4 and IPv6 addresses. -class IPAddr: +class IPAddr(object): """ provide functions to handle IPv4 and IPv6 addresses """ @@ -1136,8 +1110,22 @@ class IPAddr: valid = False raw = "" + # todo: make configurable the expired time and max count of cache entries: + CACHE_OBJ = Utils.Cache(maxCount=1000, maxTime=5*60) + + def __new__(cls, ipstring, cidr=-1): + # already correct IPAddr + args = (ipstring, cidr) + ip = IPAddr.CACHE_OBJ.get(args) + if ip is not None: + return ip + ip = super(IPAddr, cls).__new__(cls) + ip.__init(ipstring, cidr) + IPAddr.CACHE_OBJ.set(args, ip) + return ip + # object methods - def __init__(self, ipstring, cidr=-1): + def __init(self, ipstring, cidr=-1): """ initialize IP object by converting IP address string to binary to integer """ @@ -1193,7 +1181,9 @@ class IPAddr: return self.ntoa() def __eq__(self, other): - other = other if isinstance(other, IPAddr) else IPAddr(other) + if not isinstance(other, IPAddr): + if other is None: return False + other = IPAddr(other) if not self.valid and not other.valid: return self.raw == other.raw if not self.valid or not other.valid: return False if self.addr != other.addr: return False @@ -1202,7 +1192,9 @@ class IPAddr: return True def __ne__(self, other): - other = other if isinstance(other, IPAddr) else IPAddr(other) + if not isinstance(other, IPAddr): + if other is None: return True + other = IPAddr(other) if not self.valid and not other.valid: return self.raw != other.raw if self.addr != other.addr: return True if self.family != other.family: return True @@ -1210,7 +1202,9 @@ class IPAddr: return False def __lt__(self, other): - other = other if isinstance(other, IPAddr) else IPAddr(other) + if not isinstance(other, IPAddr): + if other is None: return False + other = IPAddr(other) return self.family < other.family or self.addr < other.addr def __add__(self, other): @@ -1220,7 +1214,9 @@ class IPAddr: return "%s%s" % (other, self) def __hash__(self): - return hash(self.addr)^hash((self.plen<<16)|self.family) + # should be the same as by string (because of possible compare with string): + return hash(self.ntoa()) + #return hash(self.addr)^hash((self.plen<<16)|self.family) def hexdump(self): """ dump the ip address in as a hex sequence in @@ -1302,6 +1298,11 @@ class IPAddr: """ returns true if the IP object is in the provided network (object) """ + # if it isn't a valid IP address, try DNS resolution + if not net.isValidIP() and net.getRaw() != "": + # Check if IP in DNS + return self in DNSUtils.dnsToIp(net.getRaw()) + if self.family != net.family: return False @@ -1318,18 +1319,27 @@ class IPAddr: return False + @property + def maskplen(self): + plen = 0 + if (hasattr(self, '_maskplen')): + return self._plen + maddr = self.addr + while maddr: + if not (maddr & 0x80000000): + raise ValueError("invalid mask %r, no plen representation" % (self.ntoa(),)) + maddr = (maddr << 1) & 0xFFFFFFFFL + plen += 1 + self._maskplen = plen + return plen + @staticmethod def masktoplen(maskstr): """ converts mask string to prefix length only used for IPv4 masks """ - mask = IPAddr(maskstr) - plen = 0 - while mask.addr: - mask.addr = (mask.addr << 1) & 0xFFFFFFFFL - plen += 1 - return plen + return IPAddr(maskstr).maskplen @staticmethod diff --git a/fail2ban/server/ticket.py b/fail2ban/server/ticket.py index 7b7eb908..3307a6c9 100644 --- a/fail2ban/server/ticket.py +++ b/fail2ban/server/ticket.py @@ -72,6 +72,10 @@ class Ticket: return False def setIP(self, value): + # guarantee using IPAddr instead of unicode, str for the IP + if isinstance(value, basestring): + from .filter import IPAddr + value = IPAddr(value) self.__ip = value def getIP(self): diff --git a/fail2ban/tests/failmanagertestcase.py b/fail2ban/tests/failmanagertestcase.py index a8a71723..36bc87a3 100644 --- a/fail2ban/tests/failmanagertestcase.py +++ b/fail2ban/tests/failmanagertestcase.py @@ -28,6 +28,7 @@ import unittest from ..server import failmanager from ..server.failmanager import FailManager, FailManagerEmpty +from ..server.filter import IPAddr from ..server.ticket import FailTicket @@ -140,7 +141,7 @@ class AddFailure(unittest.TestCase): #ticket = FailTicket('193.168.0.128', None) ticket = self.__failManager.toBan() self.assertEqual(ticket.getIP(), "193.168.0.128") - self.assertTrue(isinstance(ticket.getIP(), str)) + self.assertTrue(isinstance(ticket.getIP(), (str, IPAddr))) # finish with rudimentary tests of the ticket # verify consistent str diff --git a/fail2ban/tests/filtertestcase.py b/fail2ban/tests/filtertestcase.py index b5a772d9..7198185a 100644 --- a/fail2ban/tests/filtertestcase.py +++ b/fail2ban/tests/filtertestcase.py @@ -38,7 +38,7 @@ except ImportError: from ..server.jail import Jail from ..server.filterpoll import FilterPoll -from ..server.filter import Filter, FileFilter, FileContainer, DNSUtils +from ..server.filter import Filter, FileFilter, FileContainer, DNSUtils, IPAddr from ..server.failmanager import FailManagerEmpty from ..server.mytime import MyTime from ..server.utils import Utils @@ -154,19 +154,41 @@ def _assert_correct_last_attempt(utest, filter_, output, count=None): Test filter to contain target ticket """ + # one or multiple tickets: + if not isinstance(output[0], (tuple,list)): + tickcount = 1 + failcount = (count if count else output[1]) + else: + tickcount = len(output) + failcount = (count if count else sum((o[1] for o in output))) + + found = [] if isinstance(filter_, DummyJail): # get fail ticket from jail - found = _ticket_tuple(filter_.getFailTicket()) + found.append(_ticket_tuple(filter_.getFailTicket())) else: # when we are testing without jails # wait for failures (up to max time) Utils.wait_for( - lambda: filter_.failManager.getFailTotal() >= (count if count else output[1]), + lambda: filter_.failManager.getFailCount() >= (tickcount, failcount), _maxWaitTime(10)) - # get fail ticket from filter - found = _ticket_tuple(filter_.failManager.toBan()) + # get fail ticket(s) from filter + while tickcount: + try: + found.append(_ticket_tuple(filter_.failManager.toBan())) + except FailManagerEmpty: + break + tickcount -= 1 - _assert_equal_entries(utest, found, output, count) + if not isinstance(output[0], (tuple,list)): + utest.assertEqual(len(found), 1) + _assert_equal_entries(utest, found[0], output, count) + else: + # sort by string representation of ip (multiple failures with different ips): + found = sorted(found, key=lambda x: str(x)) + output = sorted(output, key=lambda x: str(x)) + for f, o in zip(found, output): + _assert_equal_entries(utest, f, o) def _copy_lines_between_files(in_, fout, n=None, skip=0, mode='a', terminal_line=""): @@ -315,6 +337,10 @@ class IgnoreIP(LogCaptureTestCase): self.assertFalse(self.filter.inIgnoreIPList('192.168.1.255')) self.assertFalse(self.filter.inIgnoreIPList('192.168.0.255')) + def testWrongIPMask(self): + self.filter.addIgnoreIP('192.168.1.0/255.255.0.0') + self.assertRaises(ValueError, self.filter.addIgnoreIP, '192.168.1.0/255.255.0.128') + def testIgnoreInProcessLine(self): setUpMyTime() self.filter.addIgnoreIP('192.168.1.0/25') @@ -345,16 +371,21 @@ class IgnoreIP(LogCaptureTestCase): self.assertNotLogged("[%s] Ignore %s by %s" % (self.jail.name, "example.com", "NOT_LOGGED")) -class IgnoreIPDNS(IgnoreIP): +class IgnoreIPDNS(LogCaptureTestCase): def setUp(self): """Call before every test case.""" unittest.F2B.SkipIfNoNetwork() - IgnoreIP.setUp(self) + LogCaptureTestCase.setUp(self) + self.jail = DummyJail() + self.filter = FileFilter(self.jail) def testIgnoreIPDNSOK(self): self.filter.addIgnoreIP("www.epfl.ch") self.assertTrue(self.filter.inIgnoreIPList("128.178.50.12")) + self.filter.addIgnoreIP("example.com") + self.assertTrue(self.filter.inIgnoreIPList("93.184.216.34")) + self.assertTrue(self.filter.inIgnoreIPList("2606:2800:220:1:248:1893:25c8:1946")) def testIgnoreIPDNSNOK(self): # Test DNS @@ -1109,18 +1140,18 @@ class GetFailures(LogCaptureTestCase): _assert_correct_last_attempt(self, self.filter, output) def testGetFailures04(self): - output = [('212.41.96.186', 4, 1124013600.0), - ('212.41.96.185', 4, 1124017198.0)] + # because of not exact time in testcase04.log (no year), we should always use our test time: + self.assertEqual(MyTime.time(), 1124013600) + # should find exact 4 failures for *.186 and 2 failures for *.185 + output = (('212.41.96.186', 4, 1124013600.0), + ('212.41.96.185', 2, 1124013598.0)) + self.filter.setMaxRetry(2) self.filter.addLogPath(GetFailures.FILENAME_04, autoSeek=0) self.filter.addFailRegex("Invalid user .* ") self.filter.getFailures(GetFailures.FILENAME_04) - try: - for i, out in enumerate(output): - _assert_correct_last_attempt(self, self.filter, out) - except FailManagerEmpty: - pass + _assert_correct_last_attempt(self, self.filter, output) def testGetFailuresWrongChar(self): # write wrong utf-8 char: @@ -1160,20 +1191,31 @@ class GetFailures(LogCaptureTestCase): def testGetFailuresUseDNS(self): unittest.F2B.SkipIfNoNetwork() # We should still catch failures with usedns = no ;-) - output_yes = ('93.184.216.34', 2, 1124013539.0, - [u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2', - u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2']) + output_yes = ( + ('93.184.216.34', 2, 1124013539.0, + [u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2', + u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2'] + ), + ('2606:2800:220:1:248:1893:25c8:1946', 1, 1124013299.0, + [u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2'] + ), + ) - output_no = ('93.184.216.34', 1, 1124013539.0, - [u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2']) + output_no = ( + ('93.184.216.34', 1, 1124013539.0, + [u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2'] + ) + ) # Actually no exception would be raised -- it will be just set to 'no' #self.assertRaises(ValueError, # FileFilter, None, useDns='wrong_value_for_useDns') - for useDns, output in (('yes', output_yes), - ('no', output_no), - ('warn', output_yes)): + for useDns, output in ( + ('yes', output_yes), + ('no', output_no), + ('warn', output_yes) + ): jail = DummyJail() filter_ = FileFilter(jail, useDns=useDns) filter_.active = True @@ -1356,6 +1398,71 @@ class DNSUtilsNetworkTests(unittest.TestCase): res = DNSUtils.bin2addr(167772160L) self.assertEqual(res, '10.0.0.0') + def testIPAddr_Equal6(self): + self.assertEqual( + IPAddr('2606:2800:220:1:248:1893::'), + IPAddr('2606:2800:220:1:248:1893:0:0') + ) + + def testIPAddr_Compare(self): + ip4 = [ + IPAddr('93.184.0.1'), + IPAddr('93.184.216.1'), + IPAddr('93.184.216.34') + ] + ip6 = [ + IPAddr('2606:2800:220:1:248:1893::'), + IPAddr('2606:2800:220:1:248:1893:25c8:0'), + IPAddr('2606:2800:220:1:248:1893:25c8:1946') + ] + # ip4 + self.assertNotEqual(ip4[0], None) + self.assertTrue(ip4[0] is not None) + self.assertFalse(ip4[0] is None) + self.assertLess(None, ip4[0]) + self.assertLess(ip4[0], ip4[1]) + self.assertLess(ip4[1], ip4[2]) + self.assertEqual(sorted(reversed(ip4)), ip4) + # ip6 + self.assertNotEqual(ip6[0], None) + self.assertTrue(ip6[0] is not None) + self.assertFalse(ip6[0] is None) + self.assertLess(None, ip6[0]) + self.assertLess(ip6[0], ip6[1]) + self.assertLess(ip6[1], ip6[2]) + self.assertEqual(sorted(reversed(ip6)), ip6) + # ip4 vs ip6 + self.assertNotEqual(ip4[0], ip6[0]) + self.assertLess(ip4[0], ip6[0]) + self.assertLess(ip4[2], ip6[2]) + self.assertEqual(sorted(reversed(ip4+ip6)), ip4+ip6) + # hashing (with string as key): + d={ + '93.184.216.34': 'ip4-test', + '2606:2800:220:1:248:1893:25c8:1946': 'ip6-test' + } + d2 = dict([(IPAddr(k), v) for k, v in d.iteritems()]) + self.assertTrue(isinstance(d.keys()[0], basestring)) + self.assertTrue(isinstance(d2.keys()[0], IPAddr)) + self.assertEqual(d.get(ip4[2], ''), 'ip4-test') + self.assertEqual(d.get(ip6[2], ''), 'ip6-test') + self.assertEqual(d2.get(str(ip4[2]), ''), 'ip4-test') + self.assertEqual(d2.get(str(ip6[2]), ''), 'ip6-test') + # compare with string direct: + self.assertEqual(d, d2) + + def testIPAddr_CompareDNS(self): + ips = IPAddr('example.com') + self.assertTrue(IPAddr("93.184.216.34").isInNet(ips)) + self.assertTrue(IPAddr("2606:2800:220:1:248:1893:25c8:1946").isInNet(ips)) + + def testIPAddr_Cached(self): + ips = [DNSUtils.dnsToIp('example.com'), DNSUtils.dnsToIp('example.com')] + for ip1, ip2 in zip(ips, ips): + self.assertEqual(id(ip1), id(ip2)) + ip1 = IPAddr('93.184.216.34'); ip2 = IPAddr('93.184.216.34'); self.assertEqual(id(ip1), id(ip2)) + ip1 = IPAddr('2606:2800:220:1:248:1893:25c8:1946'); ip2 = IPAddr('2606:2800:220:1:248:1893:25c8:1946'); self.assertEqual(id(ip1), id(ip2)) + class JailTests(unittest.TestCase):