meantime commit: code review, simplification, pythonization, etc. (test cases passed)

unnecessarily code aggravation with explicit converting reverted - implicit converting inside internal functions if not IPAddr object;
pull/1414/head
sebres 2016-05-04 19:44:52 +02:00
parent 85b895178b
commit afe1f73af2
7 changed files with 212 additions and 86 deletions

View File

@ -181,7 +181,7 @@ class Actions(JailThread, Mapping):
def getBanTime(self):
return self.__banManager.getBanTime()
def removeBannedIP(self, ipstr):
def removeBannedIP(self, ip):
"""Removes banned IP calling actions' unban method
Remove a banned IP now, rather than waiting for it to expire,
@ -189,16 +189,14 @@ class Actions(JailThread, Mapping):
Parameters
----------
ipstr : str
The IP address string to unban
ip : str or IPAddr
The IP address to unban
Raises
------
ValueError
If `ip` is not banned
"""
# Create new IPAddr object from IP string
ip = IPAddr(ipstr)
# Always delete ip from database (also if currently not banned)
if self._jail.database is not None:
self._jail.database.delBan(self._jail, ip)

View File

@ -32,7 +32,6 @@ from threading import RLock
from .mytime import MyTime
from .ticket import FailTicket
from .filter import IPAddr
from ..helpers import getLogger
# Gets the instance of the logger.
@ -412,18 +411,19 @@ class Fail2BanDb(object):
ticket : BanTicket
Ticket of the ban to be added.
"""
ip = str(ticket.getIP())
try:
del self._bansMergedCache[(ticket.getIP(), jail)]
del self._bansMergedCache[(ip, jail)]
except KeyError:
pass
try:
del self._bansMergedCache[(ticket.getIP(), None)]
del self._bansMergedCache[(ip, None)]
except KeyError:
pass
#TODO: Implement data parts once arbitrary match keys completed
cur.execute(
"INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)",
(jail.name, ticket.getIP().ntoa(), int(round(ticket.getTime())),
(jail.name, ip, int(round(ticket.getTime())),
ticket.getData()))
@commitandrollback
@ -437,7 +437,7 @@ class Fail2BanDb(object):
ip : str
IP to be removed.
"""
queryArgs = (jail.name, ip.ntoa());
queryArgs = (jail.name, str(ip));
cur.execute(
"DELETE FROM bans WHERE jail = ? AND ip = ?",
queryArgs);
@ -455,7 +455,7 @@ class Fail2BanDb(object):
queryArgs.append(MyTime.time() - bantime)
if ip is not None:
query += " AND ip=?"
queryArgs.append(ip.ntoa())
queryArgs.append(ip)
query += " ORDER BY ip, timeofban desc"
return cur.execute(query, queryArgs)
@ -471,7 +471,7 @@ class Fail2BanDb(object):
Ban time in seconds, such that bans returned would still be
valid now. Negative values are equivalent to `None`.
Default `None`; no limit.
ip : IPAddr object
ip : str
IP Address to filter bans by. Default `None`; all IPs.
Returns
@ -480,8 +480,7 @@ class Fail2BanDb(object):
List of `Ticket`s for bans stored in database.
"""
tickets = []
for ipstr, timeofban, data in self._getBans(**kwargs):
ip = IPAddr(ipstr)
for ip, timeofban, data in self._getBans(**kwargs):
#TODO: Implement data parts once arbitrary match keys completed
tickets.append(FailTicket(ip, timeofban))
tickets[-1].setData(data)
@ -501,7 +500,7 @@ class Fail2BanDb(object):
Ban time in seconds, such that bans returned would still be
valid now. Negative values are equivalent to `None`.
Default `None`; no limit.
ip : IPAddr object
ip : str
IP Address to filter bans by. Default `None`; all IPs.
Returns
@ -522,8 +521,6 @@ class Fail2BanDb(object):
ticket = None
results = list(self._getBans(ip=ip, jail=jail, bantime=bantime))
# Convert IP strings to IPAddr objects
results = map(lambda i:(IPAddr(i[0]),)+i[1:], results)
if results:
prev_banip = results[0][0]
matches = []

View File

@ -54,6 +54,15 @@ class FailManager:
with self.__lock:
return self.__failTotal
def getFailCount(self):
# may be slow on large list of failures, should be used for test purposes only...
with self.__lock:
return len(self.__failList), sum([f.getRetry() for f in self.__failList.values()])
def getFailTotal(self):
with self.__lock:
return self.__failTotal
def setMaxRetry(self, value):
self.__maxRetry = value

View File

@ -306,19 +306,15 @@ class Filter(JailThread):
def getIgnoreCommand(self):
return self.__ignoreCommand
##
# create new IPAddr object from IP address string
def newIP(self, ipstr):
return IPAddr(ipstr)
##
# Ban an IP - http://blogs.buanzo.com.ar/2009/04/fail2ban-patch-ban-ip-address-manually.html
# Arturo 'Buanzo' Busleiman <buanzo@buanzo.com.ar>
#
# to enable banip fail2ban-client BAN command
def addBannedIP(self, ipstr):
ip = IPAddr(ipstr)
def addBannedIP(self, ip):
if not isinstance(ip, IPAddr):
ip = IPAddr(ip)
if self.inIgnoreIPList(ip):
logSys.warning('Requested to manually ban an ignored IP %s. User knows best. Proceeding to ban it.' % ip)
@ -358,11 +354,11 @@ class Filter(JailThread):
ip = IPAddr(s[0], s[1])
# log and append to ignore list
logSys.debug("Add " + ip + " to ignore list")
logSys.debug("Add %r to ignore list (%r, %r)", ip, s[0], s[1])
self.__ignoreIpList.append(ip)
def delIgnoreIP(self, ip):
logSys.debug("Remove " + ip + " from ignore list")
logSys.debug("Remove %r from ignore list", ip)
self.__ignoreIpList.remove(ip)
def logIgnoreIp(self, ip, log_ignore, ignore_source="unknown source"):
@ -384,18 +380,9 @@ class Filter(JailThread):
if not isinstance(ip, IPAddr):
ip = IPAddr(ip)
for net in self.__ignoreIpList:
# if it isn't a valid IP address, try DNS resolution
if not net.isValidIP() and net.getRaw() != "":
# Check if IP in DNS
ips = DNSUtils.dnsToIp(net.getRaw())
if ip in ips:
self.logIgnoreIp(ip, log_ignore, ignore_source="dns")
return True
else:
continue
# check if the IP is covered by ignore IP
if ip.isInNet(net):
self.logIgnoreIp(ip, log_ignore, ignore_source="ip")
self.logIgnoreIp(ip, log_ignore, ignore_source=("ip" if net.isValidIP() else "dns"))
return True
if self.__ignoreCommand:
@ -1006,8 +993,6 @@ from .utils import Utils
class DNSUtils:
IP_CRE = re.compile("^(?:\d{1,3}\.){3}\d{1,3}$")
# todo: make configurable the expired time and max count of cache entries:
CACHE_nameToIp = Utils.Cache(maxCount=1000, maxTime=5*60)
CACHE_ipToName = Utils.Cache(maxCount=1000, maxTime=5*60)
@ -1053,17 +1038,6 @@ class DNSUtils:
DNSUtils.CACHE_ipToName.set(ip, v)
return v
@staticmethod
def searchIP(text):
""" Search if an IP address if directly available and return
it.
"""
match = DNSUtils.IP_CRE.match(text)
if match:
return match
else:
return None
@staticmethod
def isValidIP(string):
""" Return true if str is a valid IP
@ -1082,7 +1056,7 @@ class DNSUtils:
ipList = list()
# Search for plain IP
plainIP = IPAddr.searchIP(text)
if not plainIP is None:
if plainIP is not None:
ip = IPAddr(plainIP.group(0))
if ip.isValidIP():
ipList.append(ip)
@ -1122,7 +1096,7 @@ class DNSUtils:
#
# This class contains methods for handling IPv4 and IPv6 addresses.
class IPAddr:
class IPAddr(object):
""" provide functions to handle IPv4 and IPv6 addresses
"""
@ -1136,8 +1110,22 @@ class IPAddr:
valid = False
raw = ""
# todo: make configurable the expired time and max count of cache entries:
CACHE_OBJ = Utils.Cache(maxCount=1000, maxTime=5*60)
def __new__(cls, ipstring, cidr=-1):
# already correct IPAddr
args = (ipstring, cidr)
ip = IPAddr.CACHE_OBJ.get(args)
if ip is not None:
return ip
ip = super(IPAddr, cls).__new__(cls)
ip.__init(ipstring, cidr)
IPAddr.CACHE_OBJ.set(args, ip)
return ip
# object methods
def __init__(self, ipstring, cidr=-1):
def __init(self, ipstring, cidr=-1):
""" initialize IP object by converting IP address string
to binary to integer
"""
@ -1193,7 +1181,9 @@ class IPAddr:
return self.ntoa()
def __eq__(self, other):
other = other if isinstance(other, IPAddr) else IPAddr(other)
if not isinstance(other, IPAddr):
if other is None: return False
other = IPAddr(other)
if not self.valid and not other.valid: return self.raw == other.raw
if not self.valid or not other.valid: return False
if self.addr != other.addr: return False
@ -1202,7 +1192,9 @@ class IPAddr:
return True
def __ne__(self, other):
other = other if isinstance(other, IPAddr) else IPAddr(other)
if not isinstance(other, IPAddr):
if other is None: return True
other = IPAddr(other)
if not self.valid and not other.valid: return self.raw != other.raw
if self.addr != other.addr: return True
if self.family != other.family: return True
@ -1210,7 +1202,9 @@ class IPAddr:
return False
def __lt__(self, other):
other = other if isinstance(other, IPAddr) else IPAddr(other)
if not isinstance(other, IPAddr):
if other is None: return False
other = IPAddr(other)
return self.family < other.family or self.addr < other.addr
def __add__(self, other):
@ -1220,7 +1214,9 @@ class IPAddr:
return "%s%s" % (other, self)
def __hash__(self):
return hash(self.addr)^hash((self.plen<<16)|self.family)
# should be the same as by string (because of possible compare with string):
return hash(self.ntoa())
#return hash(self.addr)^hash((self.plen<<16)|self.family)
def hexdump(self):
""" dump the ip address in as a hex sequence in
@ -1302,6 +1298,11 @@ class IPAddr:
""" returns true if the IP object is in the provided
network (object)
"""
# if it isn't a valid IP address, try DNS resolution
if not net.isValidIP() and net.getRaw() != "":
# Check if IP in DNS
return self in DNSUtils.dnsToIp(net.getRaw())
if self.family != net.family:
return False
@ -1318,18 +1319,27 @@ class IPAddr:
return False
@property
def maskplen(self):
plen = 0
if (hasattr(self, '_maskplen')):
return self._plen
maddr = self.addr
while maddr:
if not (maddr & 0x80000000):
raise ValueError("invalid mask %r, no plen representation" % (self.ntoa(),))
maddr = (maddr << 1) & 0xFFFFFFFFL
plen += 1
self._maskplen = plen
return plen
@staticmethod
def masktoplen(maskstr):
""" converts mask string to prefix length
only used for IPv4 masks
"""
mask = IPAddr(maskstr)
plen = 0
while mask.addr:
mask.addr = (mask.addr << 1) & 0xFFFFFFFFL
plen += 1
return plen
return IPAddr(maskstr).maskplen
@staticmethod

View File

@ -72,6 +72,10 @@ class Ticket:
return False
def setIP(self, value):
# guarantee using IPAddr instead of unicode, str for the IP
if isinstance(value, basestring):
from .filter import IPAddr
value = IPAddr(value)
self.__ip = value
def getIP(self):

View File

@ -28,6 +28,7 @@ import unittest
from ..server import failmanager
from ..server.failmanager import FailManager, FailManagerEmpty
from ..server.filter import IPAddr
from ..server.ticket import FailTicket
@ -140,7 +141,7 @@ class AddFailure(unittest.TestCase):
#ticket = FailTicket('193.168.0.128', None)
ticket = self.__failManager.toBan()
self.assertEqual(ticket.getIP(), "193.168.0.128")
self.assertTrue(isinstance(ticket.getIP(), str))
self.assertTrue(isinstance(ticket.getIP(), (str, IPAddr)))
# finish with rudimentary tests of the ticket
# verify consistent str

View File

@ -38,7 +38,7 @@ except ImportError:
from ..server.jail import Jail
from ..server.filterpoll import FilterPoll
from ..server.filter import Filter, FileFilter, FileContainer, DNSUtils
from ..server.filter import Filter, FileFilter, FileContainer, DNSUtils, IPAddr
from ..server.failmanager import FailManagerEmpty
from ..server.mytime import MyTime
from ..server.utils import Utils
@ -154,19 +154,41 @@ def _assert_correct_last_attempt(utest, filter_, output, count=None):
Test filter to contain target ticket
"""
# one or multiple tickets:
if not isinstance(output[0], (tuple,list)):
tickcount = 1
failcount = (count if count else output[1])
else:
tickcount = len(output)
failcount = (count if count else sum((o[1] for o in output)))
found = []
if isinstance(filter_, DummyJail):
# get fail ticket from jail
found = _ticket_tuple(filter_.getFailTicket())
found.append(_ticket_tuple(filter_.getFailTicket()))
else:
# when we are testing without jails
# wait for failures (up to max time)
Utils.wait_for(
lambda: filter_.failManager.getFailTotal() >= (count if count else output[1]),
lambda: filter_.failManager.getFailCount() >= (tickcount, failcount),
_maxWaitTime(10))
# get fail ticket from filter
found = _ticket_tuple(filter_.failManager.toBan())
# get fail ticket(s) from filter
while tickcount:
try:
found.append(_ticket_tuple(filter_.failManager.toBan()))
except FailManagerEmpty:
break
tickcount -= 1
_assert_equal_entries(utest, found, output, count)
if not isinstance(output[0], (tuple,list)):
utest.assertEqual(len(found), 1)
_assert_equal_entries(utest, found[0], output, count)
else:
# sort by string representation of ip (multiple failures with different ips):
found = sorted(found, key=lambda x: str(x))
output = sorted(output, key=lambda x: str(x))
for f, o in zip(found, output):
_assert_equal_entries(utest, f, o)
def _copy_lines_between_files(in_, fout, n=None, skip=0, mode='a', terminal_line=""):
@ -315,6 +337,10 @@ class IgnoreIP(LogCaptureTestCase):
self.assertFalse(self.filter.inIgnoreIPList('192.168.1.255'))
self.assertFalse(self.filter.inIgnoreIPList('192.168.0.255'))
def testWrongIPMask(self):
self.filter.addIgnoreIP('192.168.1.0/255.255.0.0')
self.assertRaises(ValueError, self.filter.addIgnoreIP, '192.168.1.0/255.255.0.128')
def testIgnoreInProcessLine(self):
setUpMyTime()
self.filter.addIgnoreIP('192.168.1.0/25')
@ -345,16 +371,21 @@ class IgnoreIP(LogCaptureTestCase):
self.assertNotLogged("[%s] Ignore %s by %s" % (self.jail.name, "example.com", "NOT_LOGGED"))
class IgnoreIPDNS(IgnoreIP):
class IgnoreIPDNS(LogCaptureTestCase):
def setUp(self):
"""Call before every test case."""
unittest.F2B.SkipIfNoNetwork()
IgnoreIP.setUp(self)
LogCaptureTestCase.setUp(self)
self.jail = DummyJail()
self.filter = FileFilter(self.jail)
def testIgnoreIPDNSOK(self):
self.filter.addIgnoreIP("www.epfl.ch")
self.assertTrue(self.filter.inIgnoreIPList("128.178.50.12"))
self.filter.addIgnoreIP("example.com")
self.assertTrue(self.filter.inIgnoreIPList("93.184.216.34"))
self.assertTrue(self.filter.inIgnoreIPList("2606:2800:220:1:248:1893:25c8:1946"))
def testIgnoreIPDNSNOK(self):
# Test DNS
@ -1109,18 +1140,18 @@ class GetFailures(LogCaptureTestCase):
_assert_correct_last_attempt(self, self.filter, output)
def testGetFailures04(self):
output = [('212.41.96.186', 4, 1124013600.0),
('212.41.96.185', 4, 1124017198.0)]
# because of not exact time in testcase04.log (no year), we should always use our test time:
self.assertEqual(MyTime.time(), 1124013600)
# should find exact 4 failures for *.186 and 2 failures for *.185
output = (('212.41.96.186', 4, 1124013600.0),
('212.41.96.185', 2, 1124013598.0))
self.filter.setMaxRetry(2)
self.filter.addLogPath(GetFailures.FILENAME_04, autoSeek=0)
self.filter.addFailRegex("Invalid user .* <HOST>")
self.filter.getFailures(GetFailures.FILENAME_04)
try:
for i, out in enumerate(output):
_assert_correct_last_attempt(self, self.filter, out)
except FailManagerEmpty:
pass
_assert_correct_last_attempt(self, self.filter, output)
def testGetFailuresWrongChar(self):
# write wrong utf-8 char:
@ -1160,20 +1191,31 @@ class GetFailures(LogCaptureTestCase):
def testGetFailuresUseDNS(self):
unittest.F2B.SkipIfNoNetwork()
# We should still catch failures with usedns = no ;-)
output_yes = ('93.184.216.34', 2, 1124013539.0,
[u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2',
u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2'])
output_yes = (
('93.184.216.34', 2, 1124013539.0,
[u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2',
u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2']
),
('2606:2800:220:1:248:1893:25c8:1946', 1, 1124013299.0,
[u'Aug 14 11:54:59 i60p295 sshd[12365]: Failed publickey for roehl from example.com port 51332 ssh2']
),
)
output_no = ('93.184.216.34', 1, 1124013539.0,
[u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2'])
output_no = (
('93.184.216.34', 1, 1124013539.0,
[u'Aug 14 11:58:59 i60p295 sshd[12365]: Failed publickey for roehl from ::ffff:93.184.216.34 port 51332 ssh2']
)
)
# Actually no exception would be raised -- it will be just set to 'no'
#self.assertRaises(ValueError,
# FileFilter, None, useDns='wrong_value_for_useDns')
for useDns, output in (('yes', output_yes),
('no', output_no),
('warn', output_yes)):
for useDns, output in (
('yes', output_yes),
('no', output_no),
('warn', output_yes)
):
jail = DummyJail()
filter_ = FileFilter(jail, useDns=useDns)
filter_.active = True
@ -1356,6 +1398,71 @@ class DNSUtilsNetworkTests(unittest.TestCase):
res = DNSUtils.bin2addr(167772160L)
self.assertEqual(res, '10.0.0.0')
def testIPAddr_Equal6(self):
self.assertEqual(
IPAddr('2606:2800:220:1:248:1893::'),
IPAddr('2606:2800:220:1:248:1893:0:0')
)
def testIPAddr_Compare(self):
ip4 = [
IPAddr('93.184.0.1'),
IPAddr('93.184.216.1'),
IPAddr('93.184.216.34')
]
ip6 = [
IPAddr('2606:2800:220:1:248:1893::'),
IPAddr('2606:2800:220:1:248:1893:25c8:0'),
IPAddr('2606:2800:220:1:248:1893:25c8:1946')
]
# ip4
self.assertNotEqual(ip4[0], None)
self.assertTrue(ip4[0] is not None)
self.assertFalse(ip4[0] is None)
self.assertLess(None, ip4[0])
self.assertLess(ip4[0], ip4[1])
self.assertLess(ip4[1], ip4[2])
self.assertEqual(sorted(reversed(ip4)), ip4)
# ip6
self.assertNotEqual(ip6[0], None)
self.assertTrue(ip6[0] is not None)
self.assertFalse(ip6[0] is None)
self.assertLess(None, ip6[0])
self.assertLess(ip6[0], ip6[1])
self.assertLess(ip6[1], ip6[2])
self.assertEqual(sorted(reversed(ip6)), ip6)
# ip4 vs ip6
self.assertNotEqual(ip4[0], ip6[0])
self.assertLess(ip4[0], ip6[0])
self.assertLess(ip4[2], ip6[2])
self.assertEqual(sorted(reversed(ip4+ip6)), ip4+ip6)
# hashing (with string as key):
d={
'93.184.216.34': 'ip4-test',
'2606:2800:220:1:248:1893:25c8:1946': 'ip6-test'
}
d2 = dict([(IPAddr(k), v) for k, v in d.iteritems()])
self.assertTrue(isinstance(d.keys()[0], basestring))
self.assertTrue(isinstance(d2.keys()[0], IPAddr))
self.assertEqual(d.get(ip4[2], ''), 'ip4-test')
self.assertEqual(d.get(ip6[2], ''), 'ip6-test')
self.assertEqual(d2.get(str(ip4[2]), ''), 'ip4-test')
self.assertEqual(d2.get(str(ip6[2]), ''), 'ip6-test')
# compare with string direct:
self.assertEqual(d, d2)
def testIPAddr_CompareDNS(self):
ips = IPAddr('example.com')
self.assertTrue(IPAddr("93.184.216.34").isInNet(ips))
self.assertTrue(IPAddr("2606:2800:220:1:248:1893:25c8:1946").isInNet(ips))
def testIPAddr_Cached(self):
ips = [DNSUtils.dnsToIp('example.com'), DNSUtils.dnsToIp('example.com')]
for ip1, ip2 in zip(ips, ips):
self.assertEqual(id(ip1), id(ip2))
ip1 = IPAddr('93.184.216.34'); ip2 = IPAddr('93.184.216.34'); self.assertEqual(id(ip1), id(ip2))
ip1 = IPAddr('2606:2800:220:1:248:1893:25c8:1946'); ip2 = IPAddr('2606:2800:220:1:248:1893:25c8:1946'); self.assertEqual(id(ip1), id(ip2))
class JailTests(unittest.TestCase):