meantime commit: code review, simplification, pythonization, etc. (test cases passed)

unnecessarily code aggravation with explicit converting reverted - implicit converting inside internal functions if not IPAddr object;
pull/1414/head
sebres 2016-05-04 19:44:52 +02:00
parent 85b895178b
commit afe1f73af2
7 changed files with 212 additions and 86 deletions

View File

@ -181,7 +181,7 @@ class Actions(JailThread, Mapping):
def getBanTime(self): def getBanTime(self):
return self.__banManager.getBanTime() return self.__banManager.getBanTime()
def removeBannedIP(self, ipstr): def removeBannedIP(self, ip):
"""Removes banned IP calling actions' unban method """Removes banned IP calling actions' unban method
Remove a banned IP now, rather than waiting for it to expire, Remove a banned IP now, rather than waiting for it to expire,
@ -189,16 +189,14 @@ class Actions(JailThread, Mapping):
Parameters Parameters
---------- ----------
ipstr : str ip : str or IPAddr
The IP address string to unban The IP address to unban
Raises Raises
------ ------
ValueError ValueError
If `ip` is not banned If `ip` is not banned
""" """
# Create new IPAddr object from IP string
ip = IPAddr(ipstr)
# Always delete ip from database (also if currently not banned) # Always delete ip from database (also if currently not banned)
if self._jail.database is not None: if self._jail.database is not None:
self._jail.database.delBan(self._jail, ip) self._jail.database.delBan(self._jail, ip)

View File

@ -32,7 +32,6 @@ from threading import RLock
from .mytime import MyTime from .mytime import MyTime
from .ticket import FailTicket from .ticket import FailTicket
from .filter import IPAddr
from ..helpers import getLogger from ..helpers import getLogger
# Gets the instance of the logger. # Gets the instance of the logger.
@ -412,18 +411,19 @@ class Fail2BanDb(object):
ticket : BanTicket ticket : BanTicket
Ticket of the ban to be added. Ticket of the ban to be added.
""" """
ip = str(ticket.getIP())
try: try:
del self._bansMergedCache[(ticket.getIP(), jail)] del self._bansMergedCache[(ip, jail)]
except KeyError: except KeyError:
pass pass
try: try:
del self._bansMergedCache[(ticket.getIP(), None)] del self._bansMergedCache[(ip, None)]
except KeyError: except KeyError:
pass pass
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
cur.execute( cur.execute(
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)", "INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
(jail.name, ticket.getIP().ntoa(), int(round(ticket.getTime())), (jail.name, ip, int(round(ticket.getTime())),
ticket.getData())) ticket.getData()))
@commitandrollback @commitandrollback
@ -437,7 +437,7 @@ class Fail2BanDb(object):
ip : str ip : str
IP to be removed. IP to be removed.
""" """
queryArgs = (jail.name, ip.ntoa()); queryArgs = (jail.name, str(ip));
cur.execute( cur.execute(
"DELETE FROM bans WHERE jail = ? AND ip = ?", "DELETE FROM bans WHERE jail = ? AND ip = ?",
queryArgs); queryArgs);
@ -455,7 +455,7 @@ class Fail2BanDb(object):
queryArgs.append(MyTime.time() - bantime) queryArgs.append(MyTime.time() - bantime)
if ip is not None: if ip is not None:
query += " AND ip=?" query += " AND ip=?"
queryArgs.append(ip.ntoa()) queryArgs.append(ip)
query += " ORDER BY ip, timeofban desc" query += " ORDER BY ip, timeofban desc"
return cur.execute(query, queryArgs) return cur.execute(query, queryArgs)
@ -471,7 +471,7 @@ class Fail2BanDb(object):
Ban time in seconds, such that bans returned would still be Ban time in seconds, such that bans returned would still be
valid now. Negative values are equivalent to `None`. valid now. Negative values are equivalent to `None`.
Default `None`; no limit. Default `None`; no limit.
ip : IPAddr object ip : str
IP Address to filter bans by. Default `None`; all IPs. IP Address to filter bans by. Default `None`; all IPs.
Returns Returns
@ -480,8 +480,7 @@ class Fail2BanDb(object):
List of `Ticket`s for bans stored in database. List of `Ticket`s for bans stored in database.
""" """
tickets = [] tickets = []
for ipstr, timeofban, data in self._getBans(**kwargs): for ip, timeofban, data in self._getBans(**kwargs):
ip = IPAddr(ipstr)
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
tickets.append(FailTicket(ip, timeofban)) tickets.append(FailTicket(ip, timeofban))
tickets[-1].setData(data) tickets[-1].setData(data)
@ -501,7 +500,7 @@ class Fail2BanDb(object):
Ban time in seconds, such that bans returned would still be Ban time in seconds, such that bans returned would still be
valid now. Negative values are equivalent to `None`. valid now. Negative values are equivalent to `None`.
Default `None`; no limit. Default `None`; no limit.
ip : IPAddr object ip : str
IP Address to filter bans by. Default `None`; all IPs. IP Address to filter bans by. Default `None`; all IPs.
Returns Returns
@ -522,8 +521,6 @@ class Fail2BanDb(object):
ticket = None ticket = None
results = list(self._getBans(ip=ip, jail=jail, bantime=bantime)) results = list(self._getBans(ip=ip, jail=jail, bantime=bantime))
# Convert IP strings to IPAddr objects
results = map(lambda i:(IPAddr(i[0]),)+i[1:], results)
if results: if results:
prev_banip = results[0][0] prev_banip = results[0][0]
matches = [] matches = []

View File

@ -54,6 +54,15 @@ class FailManager:
with self.__lock: with self.__lock:
return self.__failTotal return self.__failTotal
def getFailCount(self):
# may be slow on large list of failures, should be used for test purposes only...
with self.__lock:
return len(self.__failList), sum([f.getRetry() for f in self.__failList.values()])
def getFailTotal(self):
with self.__lock:
return self.__failTotal
def setMaxRetry(self, value): def setMaxRetry(self, value):
self.__maxRetry = value self.__maxRetry = value

View File

@ -306,19 +306,15 @@ class Filter(JailThread):
def getIgnoreCommand(self): def getIgnoreCommand(self):
return self.__ignoreCommand return self.__ignoreCommand
##
# create new IPAddr object from IP address string
def newIP(self, ipstr):
return IPAddr(ipstr)
## ##
# Ban an IP - http://blogs.buanzo.com.ar/2009/04/fail2ban-patch-ban-ip-address-manually.html # Ban an IP - http://blogs.buanzo.com.ar/2009/04/fail2ban-patch-ban-ip-address-manually.html
# Arturo 'Buanzo' Busleiman <buanzo@buanzo.com.ar> # Arturo 'Buanzo' Busleiman <buanzo@buanzo.com.ar>
# #
# to enable banip fail2ban-client BAN command # to enable banip fail2ban-client BAN command
def addBannedIP(self, ipstr): def addBannedIP(self, ip):
ip = IPAddr(ipstr) if not isinstance(ip, IPAddr):
ip = IPAddr(ip)
if self.inIgnoreIPList(ip): if self.inIgnoreIPList(ip):
logSys.warning('Requested to manually ban an ignored IP %s. User knows best. Proceeding to ban it.' % ip) logSys.warning('Requested to manually ban an ignored IP %s. User knows best. Proceeding to ban it.' % ip)
@ -358,11 +354,11 @@ class Filter(JailThread):
ip = IPAddr(s[0], s[1]) ip = IPAddr(s[0], s[1])
# log and append to ignore list # log and append to ignore list
logSys.debug("Add " + ip + " to ignore list") logSys.debug("Add %r to ignore list (%r, %r)", ip, s[0], s[1])
self.__ignoreIpList.append(ip) self.__ignoreIpList.append(ip)
def delIgnoreIP(self, ip): def delIgnoreIP(self, ip):
logSys.debug("Remove " + ip + " from ignore list") logSys.debug("Remove %r from ignore list", ip)
self.__ignoreIpList.remove(ip) self.__ignoreIpList.remove(ip)
def logIgnoreIp(self, ip, log_ignore, ignore_source="unknown source"): def logIgnoreIp(self, ip, log_ignore, ignore_source="unknown source"):
@ -384,18 +380,9 @@ class Filter(JailThread):
if not isinstance(ip, IPAddr): if not isinstance(ip, IPAddr):
ip = IPAddr(ip) ip = IPAddr(ip)
for net in self.__ignoreIpList: for net in self.__ignoreIpList:
# if it isn't a valid IP address, try DNS resolution
if not net.isValidIP() and net.getRaw() != "":
# Check if IP in DNS
ips = DNSUtils.dnsToIp(net.getRaw())
if ip in ips:
self.logIgnoreIp(ip, log_ignore, ignore_source="dns")
return True
else:
continue
# check if the IP is covered by ignore IP # check if the IP is covered by ignore IP
if ip.isInNet(net): if ip.isInNet(net):
self.logIgnoreIp(ip, log_ignore, ignore_source="ip") self.logIgnoreIp(ip, log_ignore, ignore_source=("ip" if net.isValidIP() else "dns"))
return True return True
if self.__ignoreCommand: if self.__ignoreCommand:
@ -1006,8 +993,6 @@ from .utils import Utils
class DNSUtils: class DNSUtils:
IP_CRE = re.compile("^(?:\d{1,3}\.){3}\d{1,3}$")
# todo: make configurable the expired time and max count of cache entries: # todo: make configurable the expired time and max count of cache entries:
CACHE_nameToIp = Utils.Cache(maxCount=1000, maxTime=5*60) CACHE_nameToIp = Utils.Cache(maxCount=1000, maxTime=5*60)
CACHE_ipToName = Utils.Cache(maxCount=1000, maxTime=5*60) CACHE_ipToName = Utils.Cache(maxCount=1000, maxTime=5*60)
@ -1053,17 +1038,6 @@ class DNSUtils:
DNSUtils.CACHE_ipToName.set(ip, v) DNSUtils.CACHE_ipToName.set(ip, v)
return v return v
@staticmethod
def searchIP(text):
""" Search if an IP address if directly available and return
it.
"""
match = DNSUtils.IP_CRE.match(text)
if match:
return match
else:
return None
@staticmethod @staticmethod
def isValidIP(string): def isValidIP(string):
""" Return true if str is a valid IP """ Return true if str is a valid IP
@ -1082,7 +1056,7 @@ class DNSUtils:
ipList = list() ipList = list()
# Search for plain IP # Search for plain IP
plainIP = IPAddr.searchIP(text) plainIP = IPAddr.searchIP(text)
if not plainIP is None: if plainIP is not None:
ip = IPAddr(plainIP.group(0)) ip = IPAddr(plainIP.group(0))
if ip.isValidIP(): if ip.isValidIP():
ipList.append(ip) ipList.append(ip)
@ -1122,7 +1096,7 @@ class DNSUtils:
# #
# This class contains methods for handling IPv4 and IPv6 addresses. # This class contains methods for handling IPv4 and IPv6 addresses.
class IPAddr: class IPAddr(object):
""" provide functions to handle IPv4 and IPv6 addresses """ provide functions to handle IPv4 and IPv6 addresses
""" """
@ -1136,8 +1110,22 @@ class IPAddr:
valid = False valid = False
raw = "" raw = ""
# todo: make configurable the expired time and max count of cache entries:
CACHE_OBJ = Utils.Cache(maxCount=1000, maxTime=5*60)
def __new__(cls, ipstring, cidr=-1):
# already correct IPAddr
args = (ipstring, cidr)
ip = IPAddr.CACHE_OBJ.get(args)
if ip is not None:
return ip
ip = super(IPAddr, cls).__new__(cls)
ip.__init(ipstring, cidr)
IPAddr.CACHE_OBJ.set(args, ip)
return ip
# object methods # object methods
def __init__(self, ipstring, cidr=-1): def __init(self, ipstring, cidr=-1):
""" initialize IP object by converting IP address string """ initialize IP object by converting IP address string
to binary to integer to binary to integer
""" """
@ -1193,7 +1181,9 @@ class IPAddr:
return self.ntoa() return self.ntoa()
def __eq__(self, other): def __eq__(self, other):
other = other if isinstance(other, IPAddr) else IPAddr(other) if not isinstance(other, IPAddr):
if other is None: return False
other = IPAddr(other)
if not self.valid and not other.valid: return self.raw == other.raw if not self.valid and not other.valid: return self.raw == other.raw
if not self.valid or not other.valid: return False if not self.valid or not other.valid: return False
if self.addr != other.addr: return False if self.addr != other.addr: return False
@ -1202,7 +1192,9 @@ class IPAddr:
return True return True
def __ne__(self, other): def __ne__(self, other):
other = other if isinstance(other, IPAddr) else IPAddr(other) if not isinstance(other, IPAddr):
if other is None: return True
other = IPAddr(other)
if not self.valid and not other.valid: return self.raw != other.raw if not self.valid and not other.valid: return self.raw != other.raw
if self.addr != other.addr: return True if self.addr != other.addr: return True
if self.family != other.family: return True if self.family != other.family: return True
@ -1210,7 +1202,9 @@ class IPAddr:
return False return False
def __lt__(self, other): def __lt__(self, other):
other = other if isinstance(other, IPAddr) else IPAddr(other) if not isinstance(other, IPAddr):
if other is None: return False
other = IPAddr(other)
return self.family < other.family or self.addr < other.addr return self.family < other.family or self.addr < other.addr
def __add__(self, other): def __add__(self, other):
@ -1220,7 +1214,9 @@ class IPAddr:
return "%s%s" % (other, self) return "%s%s" % (other, self)
def __hash__(self): def __hash__(self):
return hash(self.addr)^hash((self.plen<<16)|self.family) # should be the same as by string (because of possible compare with string):
return hash(self.ntoa())
#return hash(self.addr)^hash((self.plen<<16)|self.family)
def hexdump(self): def hexdump(self):
""" dump the ip address in as a hex sequence in """ dump the ip address in as a hex sequence in
@ -1302,6 +1298,11 @@ class IPAddr:
""" returns true if the IP object is in the provided """ returns true if the IP object is in the provided
network (object) network (object)
""" """
# if it isn't a valid IP address, try DNS resolution
if not net.isValidIP() and net.getRaw() != "":
# Check if IP in DNS
return self in DNSUtils.dnsToIp(net.getRaw())
if self.family != net.family: if self.family != net.family:
return False return False
@ -1318,18 +1319,27 @@ class IPAddr:
return False return False
@property
def maskplen(self):
plen = 0
if (hasattr(self, '_maskplen')):
return self._plen
maddr = self.addr
while maddr:
if not (maddr & 0x80000000):
raise ValueError("invalid mask %r, no plen representation" % (self.ntoa(),))
maddr = (maddr << 1) & 0xFFFFFFFFL
plen += 1
self._maskplen = plen
return plen
@staticmethod @staticmethod
def masktoplen(maskstr): def masktoplen(maskstr):
""" converts mask string to prefix length """ converts mask string to prefix length
only used for IPv4 masks only used for IPv4 masks
""" """
mask = IPAddr(maskstr) return IPAddr(maskstr).maskplen
plen = 0
while mask.addr:
mask.addr = (mask.addr << 1) & 0xFFFFFFFFL
plen += 1
return plen
@staticmethod @staticmethod

View File

@ -72,6 +72,10 @@ class Ticket:
return False return False
def setIP(self, value): def setIP(self, value):
# guarantee using IPAddr instead of unicode, str for the IP
if isinstance(value, basestring):
from .filter import IPAddr
value = IPAddr(value)
self.__ip = value self.__ip = value
def getIP(self): def getIP(self):

View File

@ -28,6 +28,7 @@ import unittest
from ..server import failmanager from ..server import failmanager
from ..server.failmanager import FailManager, FailManagerEmpty from ..server.failmanager import FailManager, FailManagerEmpty
from ..server.filter import IPAddr
from ..server.ticket import FailTicket from ..server.ticket import FailTicket
@ -140,7 +141,7 @@ class AddFailure(unittest.TestCase):
#ticket = FailTicket('193.168.0.128', None) #ticket = FailTicket('193.168.0.128', None)
ticket = self.__failManager.toBan() ticket = self.__failManager.toBan()
self.assertEqual(ticket.getIP(), "193.168.0.128") self.assertEqual(ticket.getIP(), "193.168.0.128")
self.assertTrue(isinstance(ticket.getIP(), str)) self.assertTrue(isinstance(ticket.getIP(), (str, IPAddr)))
# finish with rudimentary tests of the ticket # finish with rudimentary tests of the ticket
# verify consistent str # verify consistent str

View File

@ -38,7 +38,7 @@ except ImportError:
from ..server.jail import Jail from ..server.jail import Jail
from ..server.filterpoll import FilterPoll from ..server.filterpoll import FilterPoll
from ..server.filter import Filter, FileFilter, FileContainer, DNSUtils from ..server.filter import Filter, FileFilter, FileContainer, DNSUtils, IPAddr
from ..server.failmanager import FailManagerEmpty from ..server.failmanager import FailManagerEmpty
from ..server.mytime import MyTime from ..server.mytime import MyTime
from ..server.utils import Utils from ..server.utils import Utils
@ -154,19 +154,41 @@ def _assert_correct_last_attempt(utest, filter_, output, count=None):
Test filter to contain target ticket Test filter to contain target ticket
""" """
# one or multiple tickets:
if not isinstance(output[0], (tuple,list)):
tickcount = 1
failcount = (count if count else output[1])
else:
tickcount = len(output)
failcount = (count if count else sum((o[1] for o in output)))
found = []
if isinstance(filter_, DummyJail): if isinstance(filter_, DummyJail):
# get fail ticket from jail # get fail ticket from jail
found = _ticket_tuple(filter_.getFailTicket()) found.append(_ticket_tuple(filter_.getFailTicket()))
else: else:
# when we are testing without jails # when we are testing without jails
# wait for failures (up to max time) # wait for failures (up to max time)
Utils.wait_for( Utils.wait_for(
lambda: filter_.failManager.getFailTotal() >= (count if count else output[1]), lambda: filter_.failManager.getFailCount() >= (tickcount, failcount),
_maxWaitTime(10)) _maxWaitTime(10))
# get fail ticket from filter # get fail ticket(s) from filter
found = _ticket_tuple(filter_.failManager.toBan()) while tickcount:
try:
found.append(_ticket_tuple(filter_.failManager.toBan()))
except FailManagerEmpty:
break
tickcount -= 1
_assert_equal_entries(utest, found, output, count) if not isinstance(output[0], (tuple,list)):
utest.assertEqual(len(found), 1)
_assert_equal_entries(utest, found[0], output, count)
else:
# sort by string representation of ip (multiple failures with different ips):
found = sorted(found, key=lambda x: str(x))
output = sorted(output, key=lambda x: str(x))
for f, o in zip(found, output):
_assert_equal_entries(utest, f, o)
def _copy_lines_between_files(in_, fout, n=None, skip=0, mode='a', terminal_line=""): def _copy_lines_between_files(in_, fout, n=None, skip=0, mode='a', terminal_line=""):
@ -315,6 +337,10 @@ class IgnoreIP(LogCaptureTestCase):
self.assertFalse(self.filter.inIgnoreIPList('192.168.1.255')) self.assertFalse(self.filter.inIgnoreIPList('192.168.1.255'))
self.assertFalse(self.filter.inIgnoreIPList('192.168.0.255')) self.assertFalse(self.filter.inIgnoreIPList('192.168.0.255'))
def testWrongIPMask(self):
self.filter.addIgnoreIP('192.168.1.0/255.255.0.0')
self.assertRaises(ValueError, self.filter.addIgnoreIP, '192.168.1.0/255.255.0.128')
def testIgnoreInProcessLine(self): def testIgnoreInProcessLine(self):
setUpMyTime() setUpMyTime()
self.filter.addIgnoreIP('192.168.1.0/25') self.filter.addIgnoreIP('192.168.1.0/25')
@ -345,16 +371,21 @@ class IgnoreIP(LogCaptureTestCase):
self.assertNotLogged("[%s] Ignore %s by %s" % (self.jail.name, "example.com", "NOT_LOGGED")) self.assertNotLogged("[%s] Ignore %s by %s" % (self.jail.name, "example.com", "NOT_LOGGED"))
class IgnoreIPDNS(IgnoreIP): class IgnoreIPDNS(LogCaptureTestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
unittest.F2B.SkipIfNoNetwork() unittest.F2B.SkipIfNoNetwork()
IgnoreIP.setUp(self) LogCaptureTestCase.setUp(self)
self.jail = DummyJail()
self.filter = FileFilter(self.jail)
def testIgnoreIPDNSOK(self): def testIgnoreIPDNSOK(self):
self.filter.addIgnoreIP("www.epfl.ch") self.filter.addIgnoreIP("www.epfl.ch")
self.assertTrue(self.filter.inIgnoreIPList("128.178.50.12")) self.assertTrue(self.filter.inIgnoreIPList("128.178.50.12"))
self.filter.addIgnoreIP("example.com")
self.assertTrue(self.filter.inIgnoreIPList("93.184.216.34"))
self.assertTrue(self.filter.inIgnoreIPList("2606:2800:220:1:248:1893:25c8:1946"))
def testIgnoreIPDNSNOK(self): def testIgnoreIPDNSNOK(self):
# Test DNS # Test DNS
@ -1109,18 +1140,18 @@ class GetFailures(LogCaptureTestCase):
_assert_correct_last_attempt(self, self.filter, output) _assert_correct_last_attempt(self, self.filter, output)
def testGetFailures04(self): def testGetFailures04(self):
output = [('212.41.96.186', 4, 1124013600.0), # because of not exact time in testcase04.log (no year), we should always use our test time:
('212.41.96.185', 4, 1124017198.0)] self.assertEqual(MyTime.time(), 1124013600)
# should find exact 4 failures for *.186 and 2 failures for *.185
output = (('212.41.96.186', 4, 1124013600.0),
('212.41.96.185', 2, 1124013598.0))
self.filter.setMaxRetry(2)
self.filter.addLogPath(GetFailures.FILENAME_04, autoSeek=0) self.filter.addLogPath(GetFailures.FILENAME_04, autoSeek=0)
self.filter.addFailRegex("Invalid user .* <HOST>") self.filter.addFailRegex("Invalid user .* <HOST>")
self.filter.getFailures(GetFailures.FILENAME_04) self.filter.getFailures(GetFailures.FILENAME_04)
try: _assert_correct_last_attempt(self, self.filter, output)
for i, out in enumerate(output):
_assert_correct_last_attempt(self, self.filter, out)
except FailManagerEmpty:
pass
def testGetFailuresWrongChar(self): def testGetFailuresWrongChar(self):
# write wrong utf-8 char: # write wrong utf-8 char:
@ -1160,20 +1191,31 @@ class GetFailures(LogCaptureTestCase):
def testGetFailuresUseDNS(self): def testGetFailuresUseDNS(self):
unittest.F2B.SkipIfNoNetwork() unittest.F2B.SkipIfNoNetwork()
# We should still catch failures with usedns = no ;-) # We should still catch failures with usedns = no ;-)
output_yes = ('93.184.216.34', 2, 1124013539.0, output_yes = (
('93.184.216.34', 2, 1124013539.0,
[u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2', [u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2',
u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2']) u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2']
),
('2606:2800:220:1:248:1893:25c8:1946', 1, 1124013299.0,
[u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2']
),
)
output_no = ('93.184.216.34', 1, 1124013539.0, output_no = (
[u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2']) ('93.184.216.34', 1, 1124013539.0,
[u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2']
)
)
# Actually no exception would be raised -- it will be just set to 'no' # Actually no exception would be raised -- it will be just set to 'no'
#self.assertRaises(ValueError, #self.assertRaises(ValueError,
# FileFilter, None, useDns='wrong_value_for_useDns') # FileFilter, None, useDns='wrong_value_for_useDns')
for useDns, output in (('yes', output_yes), for useDns, output in (
('yes', output_yes),
('no', output_no), ('no', output_no),
('warn', output_yes)): ('warn', output_yes)
):
jail = DummyJail() jail = DummyJail()
filter_ = FileFilter(jail, useDns=useDns) filter_ = FileFilter(jail, useDns=useDns)
filter_.active = True filter_.active = True
@ -1356,6 +1398,71 @@ class DNSUtilsNetworkTests(unittest.TestCase):
res = DNSUtils.bin2addr(167772160L) res = DNSUtils.bin2addr(167772160L)
self.assertEqual(res, '10.0.0.0') self.assertEqual(res, '10.0.0.0')
def testIPAddr_Equal6(self):
self.assertEqual(
IPAddr('2606:2800:220:1:248:1893::'),
IPAddr('2606:2800:220:1:248:1893:0:0')
)
def testIPAddr_Compare(self):
ip4 = [
IPAddr('93.184.0.1'),
IPAddr('93.184.216.1'),
IPAddr('93.184.216.34')
]
ip6 = [
IPAddr('2606:2800:220:1:248:1893::'),
IPAddr('2606:2800:220:1:248:1893:25c8:0'),
IPAddr('2606:2800:220:1:248:1893:25c8:1946')
]
# ip4
self.assertNotEqual(ip4[0], None)
self.assertTrue(ip4[0] is not None)
self.assertFalse(ip4[0] is None)
self.assertLess(None, ip4[0])
self.assertLess(ip4[0], ip4[1])
self.assertLess(ip4[1], ip4[2])
self.assertEqual(sorted(reversed(ip4)), ip4)
# ip6
self.assertNotEqual(ip6[0], None)
self.assertTrue(ip6[0] is not None)
self.assertFalse(ip6[0] is None)
self.assertLess(None, ip6[0])
self.assertLess(ip6[0], ip6[1])
self.assertLess(ip6[1], ip6[2])
self.assertEqual(sorted(reversed(ip6)), ip6)
# ip4 vs ip6
self.assertNotEqual(ip4[0], ip6[0])
self.assertLess(ip4[0], ip6[0])
self.assertLess(ip4[2], ip6[2])
self.assertEqual(sorted(reversed(ip4+ip6)), ip4+ip6)
# hashing (with string as key):
d={
'93.184.216.34': 'ip4-test',
'2606:2800:220:1:248:1893:25c8:1946': 'ip6-test'
}
d2 = dict([(IPAddr(k), v) for k, v in d.iteritems()])
self.assertTrue(isinstance(d.keys()[0], basestring))
self.assertTrue(isinstance(d2.keys()[0], IPAddr))
self.assertEqual(d.get(ip4[2], ''), 'ip4-test')
self.assertEqual(d.get(ip6[2], ''), 'ip6-test')
self.assertEqual(d2.get(str(ip4[2]), ''), 'ip4-test')
self.assertEqual(d2.get(str(ip6[2]), ''), 'ip6-test')
# compare with string direct:
self.assertEqual(d, d2)
def testIPAddr_CompareDNS(self):
ips = IPAddr('example.com')
self.assertTrue(IPAddr("93.184.216.34").isInNet(ips))
self.assertTrue(IPAddr("2606:2800:220:1:248:1893:25c8:1946").isInNet(ips))
def testIPAddr_Cached(self):
ips = [DNSUtils.dnsToIp('example.com'), DNSUtils.dnsToIp('example.com')]
for ip1, ip2 in zip(ips, ips):
self.assertEqual(id(ip1), id(ip2))
ip1 = IPAddr('93.184.216.34'); ip2 = IPAddr('93.184.216.34'); self.assertEqual(id(ip1), id(ip2))
ip1 = IPAddr('2606:2800:220:1:248:1893:25c8:1946'); ip2 = IPAddr('2606:2800:220:1:248:1893:25c8:1946'); self.assertEqual(id(ip1), id(ip2))
class JailTests(unittest.TestCase): class JailTests(unittest.TestCase):