diff --git a/fail2ban/server/ipdns.py b/fail2ban/server/ipdns.py index aca0f6a6..75e21a31 100644 --- a/fail2ban/server/ipdns.py +++ b/fail2ban/server/ipdns.py @@ -154,17 +154,18 @@ class DNSUtils: # try find cached own hostnames (this tuple-key cannot be used elsewhere): key = ('self','hostname', fqdn) name = DNSUtils.CACHE_ipToName.get(key) + if name is not None: + return name # get it using different ways (hostname, fully-qualified or vice versa): - if name is None: - name = '' - for hostname in ( - (getfqdn, socket.gethostname) if fqdn else (socket.gethostname, getfqdn) - ): - try: - name = hostname() - break - except Exception as e: # pragma: no cover - logSys.warning("Retrieving own hostnames failed: %s", e) + name = '' + for hostname in ( + (getfqdn, socket.gethostname) if fqdn else (socket.gethostname, getfqdn) + ): + try: + name = hostname() + break + except Exception as e: # pragma: no cover + logSys.warning("Retrieving own hostnames failed: %s", e) # cache and return : DNSUtils.CACHE_ipToName.set(key, name) return name @@ -177,11 +178,12 @@ class DNSUtils: """Get own host names of self""" # try find cached own hostnames: names = DNSUtils.CACHE_ipToName.get(DNSUtils._getSelfNames_key) + if names is not None: + return names # get it using different ways (a set with names of localhost, hostname, fully qualified): - if names is None: - names = set([ - 'localhost', DNSUtils.getHostname(False), DNSUtils.getHostname(True) - ]) - set(['']) # getHostname can return '' + names = set([ + 'localhost', DNSUtils.getHostname(False), DNSUtils.getHostname(True) + ]) - set(['']) # getHostname can return '' # cache and return : DNSUtils.CACHE_ipToName.set(DNSUtils._getSelfNames_key, names) return names @@ -194,14 +196,19 @@ class DNSUtils: """Get own IP addresses of self""" # to find cached own IPs: ips = DNSUtils.CACHE_nameToIp.get(DNSUtils._getSelfIPs_key) - # get it using different ways (a set with IPs of localhost, hostname, fully qualified): - if ips is None: - ips = set() - for hostname in DNSUtils.getSelfNames(): - try: - ips |= set(DNSUtils.textToIp(hostname, 'yes')) - except Exception as e: # pragma: no cover - logSys.warning("Retrieving own IPs of %s failed: %s", hostname, e) + if ips is not None: + return ips + # firstly try to obtain from network interfaces if possible (implemented for this platform): + try: + ips = IPAddrSet([a for ni, a in DNSUtils._NetworkInterfacesAddrs()]) + except: + ips = IPAddrSet() + # extend it using different ways (a set with IPs of localhost, hostname, fully qualified): + for hostname in DNSUtils.getSelfNames(): + try: + ips |= IPAddrSet(DNSUtils.textToIp(hostname, 'yes')) + except Exception as e: # pragma: no cover + logSys.warning("Retrieving own IPs of %s failed: %s", hostname, e) # cache and return : DNSUtils.CACHE_nameToIp.set(DNSUtils._getSelfIPs_key, ips) return ips @@ -586,6 +593,9 @@ class IPAddr(object): """ return isinstance(ip, IPAddr) and (ip == self or ip.isInNet(self)) + def __contains__(self, ip): + return self.contains(ip) + # Pre-calculated map: addr to maskplen def __getMaskMap(): m6 = (1 << 128)-1 @@ -635,3 +645,111 @@ class IPAddr(object): # An IPv4 compatible IPv6 to be reused IPAddr.IP6_4COMPAT = IPAddr("::ffff:0:0", 96) + + +class IPAddrSet(set): + + def __contains__(self, ip): + if not isinstance(ip, IPAddr): ip = IPAddr(ip) + # IP can be found directly or IP is in each subnet: + return set.__contains__(self, ip) or any(n.contains(ip) for n in self) + + +def _NetworkInterfacesAddrs(): + + # Closure implementing lazy load modules and libc and define _NetworkInterfacesAddrs on demand: + # Currently tested on Linux only (TODO: implement for MacOS, Solaris, etc) + + from ctypes import ( + Structure, Union, POINTER, + pointer, get_errno, cast, + c_ushort, c_byte, c_void_p, c_char_p, c_uint, c_int, c_uint16, c_uint32 + ) + import ctypes.util + import ctypes + + class struct_sockaddr(Structure): + _fields_ = [ + ('sa_family', c_ushort), + ('sa_data', c_byte * 14),] + + class struct_sockaddr_in(Structure): + _fields_ = [ + ('sin_family', c_ushort), + ('sin_port', c_uint16), + ('sin_addr', c_byte * 4)] + + class struct_sockaddr_in6(Structure): + _fields_ = [ + ('sin6_family', c_ushort), + ('sin6_port', c_uint16), + ('sin6_flowinfo', c_uint32), + ('sin6_addr', c_byte * 16), + ('sin6_scope_id', c_uint32)] + + class union_ifa_ifu(Union): + _fields_ = [ + ('ifu_broadaddr', POINTER(struct_sockaddr)), + ('ifu_dstaddr', POINTER(struct_sockaddr)),] + + class struct_ifaddrs(Structure): + pass + struct_ifaddrs._fields_ = [ + ('ifa_next', POINTER(struct_ifaddrs)), + ('ifa_name', c_char_p), + ('ifa_flags', c_uint), + ('ifa_addr', POINTER(struct_sockaddr)), + ('ifa_netmask', POINTER(struct_sockaddr)), + ('ifa_ifu', union_ifa_ifu), + ('ifa_data', c_void_p),] + + libc = ctypes.CDLL(ctypes.util.find_library('c')) + + def ifap_iter(ifap): + ifa = ifap.contents + while True: + yield ifa + if not ifa.ifa_next: + break + ifa = ifa.ifa_next.contents + + def getfamaddr(ifa): + sa = ifa.ifa_addr.contents + fam = sa.sa_family + if fam == socket.AF_INET: + sa = cast(pointer(sa), POINTER(struct_sockaddr_in)).contents + addr = socket.inet_ntop(fam, sa.sin_addr) + nm = ifa.ifa_netmask.contents + if nm is not None and nm.sa_family == socket.AF_INET: + nm = cast(pointer(nm), POINTER(struct_sockaddr_in)).contents + addr += '/'+socket.inet_ntop(fam, nm.sin_addr) + return IPAddr(addr) + elif fam == socket.AF_INET6: + sa = cast(pointer(sa), POINTER(struct_sockaddr_in6)).contents + addr = socket.inet_ntop(fam, sa.sin6_addr) + nm = ifa.ifa_netmask.contents + if nm is not None and nm.sa_family == socket.AF_INET6: + nm = cast(pointer(nm), POINTER(struct_sockaddr_in6)).contents + addr += '/'+socket.inet_ntop(fam, nm.sin6_addr) + return IPAddr(addr) + return None + + def _NetworkInterfacesAddrs(): + ifap = POINTER(struct_ifaddrs)() + result = libc.getifaddrs(pointer(ifap)) + if result != 0: + raise OSError(get_errno()) + del result + try: + for ifa in ifap_iter(ifap): + name = ifa.ifa_name.decode("UTF-8") + addr = getfamaddr(ifa) + if addr: + yield name, addr + finally: + libc.freeifaddrs(ifap) + + DNSUtils._NetworkInterfacesAddrs = staticmethod(_NetworkInterfacesAddrs); + return _NetworkInterfacesAddrs() + +DNSUtils._NetworkInterfacesAddrs = staticmethod(_NetworkInterfacesAddrs); diff --git a/fail2ban/tests/filtertestcase.py b/fail2ban/tests/filtertestcase.py index 017e54ec..24f5272e 100644 --- a/fail2ban/tests/filtertestcase.py +++ b/fail2ban/tests/filtertestcase.py @@ -40,7 +40,7 @@ from ..server.jail import Jail from ..server.filterpoll import FilterPoll from ..server.filter import FailTicket, Filter, FileFilter, FileContainer from ..server.failmanager import FailManagerEmpty -from ..server.ipdns import asip, getfqdn, DNSUtils, IPAddr +from ..server.ipdns import asip, getfqdn, DNSUtils, IPAddr, IPAddrSet from ..server.mytime import MyTime from ..server.utils import Utils, uni_decode from .databasetestcase import getFail2BanDb @@ -2333,6 +2333,38 @@ class DNSUtilsNetworkTests(unittest.TestCase): ip1 = IPAddr('93.184.216.34'); ip2 = IPAddr('93.184.216.34'); self.assertEqual(id(ip1), id(ip2)) ip1 = IPAddr('2606:2800:220:1:248:1893:25c8:1946'); ip2 = IPAddr('2606:2800:220:1:248:1893:25c8:1946'); self.assertEqual(id(ip1), id(ip2)) + def test_IPAddrSet(self): + ips = IPAddrSet([IPAddr('192.0.2.1/27'), IPAddr('2001:DB8::/32')]) + self.assertTrue(IPAddr('192.0.2.1') in ips) + self.assertTrue(IPAddr('192.0.2.31') in ips) + self.assertFalse(IPAddr('192.0.2.32') in ips) + self.assertTrue(IPAddr('2001:DB8::1') in ips) + self.assertTrue(IPAddr('2001:0DB8:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF') in ips) + self.assertFalse(IPAddr('2001:DB9::') in ips) + # self IPs must be a set too (cover different mechanisms to obtain own IPs): + for cov in ('ni', 'dns', 'last'): + _org_NetworkInterfacesAddrs = None + if cov == 'dns': # mock-up _NetworkInterfacesAddrs like it's not implemented (raises error) + _org_NetworkInterfacesAddrs = DNSUtils._NetworkInterfacesAddrs + def _tmp_NetworkInterfacesAddrs(): + raise NotImplementedError(); + DNSUtils._NetworkInterfacesAddrs = staticmethod(_tmp_NetworkInterfacesAddrs) + try: + ips = DNSUtils.getSelfIPs() + # print('*****', ips) + if ips: + ip = IPAddr('127.0.0.1') + self.assertEqual(ip in ips, any(ip in n for n in ips)) + ip = IPAddr('127.0.0.2') + self.assertEqual(ip in ips, any(ip in n for n in ips)) + ip = IPAddr('::1') + self.assertEqual(ip in ips, any(ip in n for n in ips)) + finally: + if _org_NetworkInterfacesAddrs: + DNSUtils._NetworkInterfacesAddrs = staticmethod(_org_NetworkInterfacesAddrs) + if cov != 'last': + DNSUtils.CACHE_nameToIp.unset(DNSUtils._getSelfIPs_key) + def testFQDN(self): unittest.F2B.SkipIfNoNetwork() sname = DNSUtils.getHostname(fqdn=False)