diff --git a/ChangeLog b/ChangeLog index 956ffc33..50e8e643 100644 --- a/ChangeLog +++ b/ChangeLog @@ -108,9 +108,14 @@ TODO: implementing of options resp. other tasks from PR #1346 - `` - failure identifier (if raw resp. failures without IP address) - `` - PTR reversed representation of IP address - `` - host name of the IP address + - `` - ban count of this offender if known as bad (started by 1 for unknown) + - `` - current ban-time of the ticket (prolongation can be retarded up to 10 sec.) - `` - interpolates to the corresponding filter group capture `...` - `` - fully-qualified name of host (the same as `$(hostname -f)`) - `` - short hostname (the same as `$(uname -n)`) +* Introduced new action command `actionprolong` to prolong ban-time (e. g. set new timeout if expected); + Several actions (like ipset, etc.) rewritten using net logic with `actionprolong`. + Note: because ban-time is dynamic, it was removed from jail.conf as timeout argument (check jail.local). * Allow to use filter options by `fail2ban-regex`, example: fail2ban-regex text.log "sshd[mode=aggressive]" * Samples test case factory extended with filter options - dict in JSON to control @@ -184,6 +189,9 @@ ver. 0.10.0-alpha-1 (2016/07/14) - ipv6-support-etc * testSocket: sporadical bug repaired - wait for server thread starts a socket (listener) * testExecuteTimeoutWithNastyChildren: sporadical bug repaired - wait for pid file inside bash, kill tree in any case (gh-1155) +* purge database will be executed now (within observer). +* restoring currently banned ip after service restart fixed + (now < timeofban + bantime), ignore old log failures (already banned) * Fixed high-load of pyinotify-backend, see https://github.com/fail2ban/fail2ban/issues/885#issuecomment-248964591 * Database: stability fix - repack cursor iterator as long as locked @@ -221,6 +229,9 @@ ver. 0.10.0-alpha-1 (2016/07/14) - ipv6-support-etc - new conditional section functionality used in config resp. includes: - [Init?family=inet4] - IPv4 qualified hosts only - [Init?family=inet6] - IPv6 qualified hosts only +* Increment ban time (+ observer) functionality introduced. + Thanks Serg G. Brester (sebres) +* Database functionality extended with bad ips. * New reload functionality (now totally without restart, unbanning/rebanning, etc.), see gh-1557 * Several commands extended and new commands introduced: diff --git a/MANIFEST b/MANIFEST index a6fdae5a..9aceb461 100644 --- a/MANIFEST +++ b/MANIFEST @@ -200,6 +200,7 @@ fail2ban/server/jail.py fail2ban/server/jails.py fail2ban/server/jailthread.py fail2ban/server/mytime.py +fail2ban/server/observer.py fail2ban/server/server.py fail2ban/server/strptime.py fail2ban/server/ticket.py @@ -253,6 +254,7 @@ fail2ban/tests/files/config/apache-auth/digest_wrongrelm/.htpasswd fail2ban/tests/files/config/apache-auth/noentry/.htaccess fail2ban/tests/files/config/apache-auth/README fail2ban/tests/files/database_v1.db +fail2ban/tests/files/database_v2.db fail2ban/tests/files/filter.d/substition.conf fail2ban/tests/files/filter.d/testcase01.conf fail2ban/tests/files/filter.d/testcase-common.conf @@ -351,6 +353,7 @@ fail2ban/tests/files/zzz-sshd-obsolete-multiline.log fail2ban/tests/filtertestcase.py fail2ban/tests/__init__.py fail2ban/tests/misctestcase.py +fail2ban/tests/observertestcase.py fail2ban/tests/samplestestcase.py fail2ban/tests/servertestcase.py fail2ban/tests/sockettestcase.py diff --git a/THANKS b/THANKS index 7861ceb5..c363c76c 100644 --- a/THANKS +++ b/THANKS @@ -111,7 +111,7 @@ Russell Odom SATO Kentaro Sean DuBois Sebastian Arcus -Serg G. Brester +Serg G. Brester (sebres) Sergey Safarov Shaun C. Sireyessire diff --git a/config/action.d/firewallcmd-ipset.conf b/config/action.d/firewallcmd-ipset.conf index 69447627..ecbb3bef 100644 --- a/config/action.d/firewallcmd-ipset.conf +++ b/config/action.d/firewallcmd-ipset.conf @@ -18,7 +18,7 @@ before = firewallcmd-common.conf [Definition] -actionstart = ipset create hash:ip timeout +actionstart = ipset create hash:ip firewall-cmd --direct --add-rule filter 0 -p -m multiport --dports -m set --match-set src -j actionstop = firewall-cmd --direct --remove-rule filter 0 -p -m multiport --dports -m set --match-set src -j @@ -27,6 +27,8 @@ actionstop = firewall-cmd --direct --remove-rule filter 0 -p

timeout -exist +actionprolong = %(actionban)s + actionunban = ipset del -exist [Init] @@ -38,12 +40,6 @@ actionunban = ipset del -exist # chain = INPUT_direct -# Option: bantime -# Notes: specifies the bantime in seconds (handled internally rather than by fail2ban) -# Values: [ NUM ] Default: 600 - -bantime = 600 - ipmset = f2b- [Init?family=inet6] diff --git a/config/action.d/iptables-ipset-proto6-allports.conf b/config/action.d/iptables-ipset-proto6-allports.conf index b761ad8c..a0ede56e 100644 --- a/config/action.d/iptables-ipset-proto6-allports.conf +++ b/config/action.d/iptables-ipset-proto6-allports.conf @@ -26,7 +26,7 @@ before = iptables-common.conf # Notes.: command executed once at the start of Fail2Ban. # Values: CMD # -actionstart = ipset create hash:ip timeout +actionstart = ipset create hash:ip -I -m set --match-set src -j # Option: actionflush @@ -51,6 +51,8 @@ actionstop = -D -m set --match-set src -j timeout -exist +actionprolong = %(actionban)s + # Option: actionunban # Notes.: command executed when unbanning an IP. Take care that the # command is executed with Fail2Ban user rights. @@ -61,12 +63,6 @@ actionunban = ipset del -exist [Init] -# Option: bantime -# Notes: specifies the bantime in seconds (handled internally rather than by fail2ban) -# Values: [ NUM ] Default: 600 -# -bantime = 600 - ipmset = f2b- familyopt = diff --git a/config/action.d/iptables-ipset-proto6.conf b/config/action.d/iptables-ipset-proto6.conf index e337eedf..b13eb711 100644 --- a/config/action.d/iptables-ipset-proto6.conf +++ b/config/action.d/iptables-ipset-proto6.conf @@ -26,7 +26,7 @@ before = iptables-common.conf # Notes.: command executed once at the start of Fail2Ban. # Values: CMD # -actionstart = ipset create hash:ip timeout +actionstart = ipset create hash:ip -I -p -m multiport --dports -m set --match-set src -j # Option: actionflush @@ -51,6 +51,8 @@ actionstop = -D -p -m multiport --dports -m # actionban = ipset add timeout -exist +actionprolong = %(actionban)s + # Option: actionunban # Notes.: command executed when unbanning an IP. Take care that the # command is executed with Fail2Ban user rights. @@ -61,12 +63,6 @@ actionunban = ipset del -exist [Init] -# Option: bantime -# Notes: specifies the bantime in seconds (handled internally rather than by fail2ban) -# Values: [ NUM ] Default: 600 -# -bantime = 600 - ipmset = f2b- familyopt = diff --git a/config/action.d/osx-afctl.conf b/config/action.d/osx-afctl.conf index a319fc6b..a75e5723 100644 --- a/config/action.d/osx-afctl.conf +++ b/config/action.d/osx-afctl.conf @@ -12,5 +12,5 @@ actioncheck = actionban = /usr/libexec/afctl -a -t actionunban = /usr/libexec/afctl -r -[Init] -bantime = 2880 +actionprolong = %(actionunban)s && %(actionban)s + diff --git a/config/action.d/shorewall-ipset-proto6.conf b/config/action.d/shorewall-ipset-proto6.conf index 1ebcfb01..8d80460f 100644 --- a/config/action.d/shorewall-ipset-proto6.conf +++ b/config/action.d/shorewall-ipset-proto6.conf @@ -51,7 +51,7 @@ # Values: CMD # actionstart = if ! ipset -quiet -name list f2b- >/dev/null; - then ipset -quiet -exist create f2b- hash:ip timeout ; + then ipset -quiet -exist create f2b- hash:ip; fi # Option: actionstop @@ -68,6 +68,8 @@ actionstop = ipset flush f2b- # actionban = ipset add f2b- timeout -exist +actionprolong = %(actionban)s + # Option: actionunban # Notes.: command executed when unbanning an IP. Take care that the # command is executed with Fail2Ban user rights. @@ -76,10 +78,3 @@ actionban = ipset add f2b- timeout -exist # actionunban = ipset del f2b- -exist -[Init] - -# Option: bantime -# Notes: specifies the bantime in seconds (handled internally rather than by fail2ban) -# Values: [ NUM ] Default: 600 -# -bantime = 600 diff --git a/config/jail.conf b/config/jail.conf index e3e89ff0..cc887fc5 100644 --- a/config/jail.conf +++ b/config/jail.conf @@ -44,10 +44,47 @@ before = paths-debian.conf # MISCELLANEOUS OPTIONS # +# "bantime.increment" allows to use database for searching of previously banned ip's to increase a +# default ban time using special formula, default it is banTime * 1, 2, 4, 8, 16, 32... +#bantime.increment = true + +# "bantime.rndtime" is the max number of seconds using for mixing with random time +# to prevent "clever" botnets calculate exact time IP can be unbanned again: +#bantime.rndtime = + +# "bantime.maxtime" is the max number of seconds using the ban time can reach (don't grows further) +#bantime.maxtime = + +# "bantime.factor" is a coefficient to calculate exponent growing of the formula or common multiplier, +# default value of factor is 1 and with default value of formula, the ban time +# grows by 1, 2, 4, 8, 16 ... +#bantime.factor = 1 + +# "bantime.formula" used by default to calculate next value of ban time, default value bellow, +# the same ban time growing will be reached by multipliers 1, 2, 4, 8, 16, 32... +#bantime.formula = ban.Time * (1<<(ban.Count if ban.Count<20 else 20)) * banFactor +# +# more aggressive example of formula has the same values only for factor "2.0 / 2.885385" : +#bantime.formula = ban.Time * math.exp(float(ban.Count+1)*banFactor)/math.exp(1*banFactor) + +# "bantime.multipliers" used to calculate next value of ban time instead of formula, coresponding +# previously ban count and given "bantime.factor" (for multipliers default is 1); +# following example grows ban time by 1, 2, 4, 8, 16 ... and if last ban count greater as multipliers count, +# always used last multiplier (64 in example), for factor '1' and original ban time 600 - 10.6 hours +#bantime.multipliers = 1 2 4 8 16 32 64 +# following example can be used for small initial ban time (bantime=60) - it grows more aggressive at begin, +# for bantime=60 the multipliers are minutes and equal: 1 min, 5 min, 30 min, 1 hour, 5 hour, 12 hour, 1 day, 2 day +#bantime.multipliers = 1 5 30 60 300 720 1440 2880 + +# "bantime.overalljails" (if true) specifies the search of IP in the database will be executed +# cross over all jails, if false (dafault), only current jail of the ban IP will be searched +#bantime.overalljails = false + +# -------------------- + # "ignorself" specifies whether the local resp. own IP addresses should be ignored # (default is true). Fail2ban will not ban a host which matches such addresses. #ignorself = true - # "ignoreip" can be a list of IP addresses, CIDR masks or DNS hosts. Fail2ban # will not ban a host which matches an address in this list. Several addresses # can be defined using space (and/or comma) separator. @@ -165,22 +202,22 @@ banaction = iptables-multiport banaction_allports = iptables-allports # The simplest action to take: ban only -action_ = %(banaction)s[name=%(__name__)s, bantime="%(bantime)s", port="%(port)s", protocol="%(protocol)s", chain="%(chain)s"] +action_ = %(banaction)s[name=%(__name__)s, port="%(port)s", protocol="%(protocol)s", chain="%(chain)s"] # ban & send an e-mail with whois report to the destemail. -action_mw = %(banaction)s[name=%(__name__)s, bantime="%(bantime)s", port="%(port)s", protocol="%(protocol)s", chain="%(chain)s"] +action_mw = %(banaction)s[name=%(__name__)s, port="%(port)s", protocol="%(protocol)s", chain="%(chain)s"] %(mta)s-whois[name=%(__name__)s, sender="%(sender)s", dest="%(destemail)s", protocol="%(protocol)s", chain="%(chain)s"] # ban & send an e-mail with whois report and relevant log lines # to the destemail. -action_mwl = %(banaction)s[name=%(__name__)s, bantime="%(bantime)s", port="%(port)s", protocol="%(protocol)s", chain="%(chain)s"] +action_mwl = %(banaction)s[name=%(__name__)s, port="%(port)s", protocol="%(protocol)s", chain="%(chain)s"] %(mta)s-whois-lines[name=%(__name__)s, sender="%(sender)s", dest="%(destemail)s", logpath=%(logpath)s, chain="%(chain)s"] # See the IMPORTANT note in action.d/xarf-login-attack for when to use this action # # ban & send a xarf e-mail to abuse contact of IP address and include relevant log lines # to the destemail. -action_xarf = %(banaction)s[name=%(__name__)s, bantime="%(bantime)s", port="%(port)s", protocol="%(protocol)s", chain="%(chain)s"] +action_xarf = %(banaction)s[name=%(__name__)s, port="%(port)s", protocol="%(protocol)s", chain="%(chain)s"] xarf-login-attack[service=%(__name__)s, sender="%(sender)s", logpath=%(logpath)s, port="%(port)s"] # ban IP on CloudFlare & send an e-mail with whois report and relevant log lines diff --git a/fail2ban/client/actionreader.py b/fail2ban/client/actionreader.py index ace0b898..7ba54d31 100644 --- a/fail2ban/client/actionreader.py +++ b/fail2ban/client/actionreader.py @@ -45,6 +45,7 @@ class ActionReader(DefinitionInitConfigReader): "actioncheck": ["string", None], "actionrepair": ["string", None], "actionban": ["string", None], + "actionprolong": ["string", None], "actionunban": ["string", None], "norestored": ["string", None], } diff --git a/fail2ban/client/csocket.py b/fail2ban/client/csocket.py index 6b478460..e53ca1fd 100644 --- a/fail2ban/client/csocket.py +++ b/fail2ban/client/csocket.py @@ -45,13 +45,13 @@ class CSocket: def __del__(self): self.close(False) - def send(self, msg): + def send(self, msg, nonblocking=False, timeout=None): # Convert every list member to string obj = dumps(map( lambda m: str(m) if not isinstance(m, (list, dict, set)) else m, msg), HIGHEST_PROTOCOL) self.__csock.send(obj + CSPROTO.END) - return self.receive(self.__csock) + return self.receive(self.__csock, nonblocking, timeout) def settimeout(self, timeout): self.__csock.settimeout(timeout if timeout != -1 else self.__deftout) @@ -65,11 +65,13 @@ class CSocket: self.__csock = None @staticmethod - def receive(sock): + def receive(sock, nonblocking=False, timeout=None): msg = CSPROTO.EMPTY + if nonblocking: sock.setblocking(0) + if timeout: sock.settimeout(timeout) while msg.rfind(CSPROTO.END) == -1: - chunk = sock.recv(6) - if chunk == '': + chunk = sock.recv(512) + if chunk in ('', b''): # python 3.x may return b'' instead of '' raise RuntimeError("socket connection broken") msg = msg + chunk return loads(msg) diff --git a/fail2ban/client/jailreader.py b/fail2ban/client/jailreader.py index ce0ed3b6..85a68ce4 100644 --- a/fail2ban/client/jailreader.py +++ b/fail2ban/client/jailreader.py @@ -107,6 +107,13 @@ class JailReader(ConfigReader): ["int", "maxretry", None], ["string", "findtime", None], ["string", "bantime", None], + ["bool", "bantime.increment", None], + ["string", "bantime.factor", None], + ["string", "bantime.formula", None], + ["string", "bantime.multipliers", None], + ["string", "bantime.maxtime", None], + ["string", "bantime.rndtime", None], + ["bool", "bantime.overalljails", None], ["string", "usedns", None], # be sure usedns is before all regex(s) in stream ["string", "failregex", None], ["string", "ignoreregex", None], diff --git a/fail2ban/server/action.py b/fail2ban/server/action.py index d00458ba..f1983d1e 100644 --- a/fail2ban/server/action.py +++ b/fail2ban/server/action.py @@ -149,7 +149,7 @@ class CallingMap(MutableMapping, object): def __len__(self): return len(self.data) - def copy(self): # pargma: no cover + def copy(self): # pragma: no cover return self.__class__(_merge_copy_dicts(self.data, self.storage)) @@ -224,6 +224,10 @@ class ActionBase(object): """ pass + @property + def _prolongable(self): # pragma: no cover - abstract + return False + def unban(self, aInfo): # pragma: no cover - abstract """Executed when a ban expires. @@ -236,6 +240,11 @@ class ActionBase(object): pass +WRAP_CMD_PARAMS = { + 'timeout': 'str2seconds', + 'bantime': 'ignore', +} + class CommandAction(ActionBase): """A action which executes OS shell commands. @@ -306,7 +315,10 @@ class CommandAction(ActionBase): def __setattr__(self, name, value): if not name.startswith('_') and not self.__init and not callable(value): # special case for some pasrameters: - if name in ('timeout', 'bantime'): + wrp = WRAP_CMD_PARAMS.get(name) + if wrp == 'ignore': # ignore (filter) dynamic parameters + return + elif wrp == 'str2seconds': value = str(MyTime.str2seconds(value)) # parameters changed - clear properties and substitution cache: self.__properties = None @@ -434,6 +446,26 @@ class CommandAction(ActionBase): if not self._processCmd('', aInfo): raise RuntimeError("Error banning %(ip)s" % aInfo) + @property + def _prolongable(self): + return (hasattr(self, 'actionprolong') and self.actionprolong + and not str(self.actionprolong).isspace()) + + def prolong(self, aInfo): + """Executes the "actionprolong" command. + + Replaces the tags in the action command with actions properties + and ban information, and executes the resulting command. + + Parameters + ---------- + aInfo : dict + Dictionary which includes information in relation to + the ban. + """ + if not self._processCmd('', aInfo): + raise RuntimeError("Error prolonging %(ip)s" % aInfo) + def unban(self, aInfo): """Executes the "actionunban" command. @@ -498,8 +530,10 @@ class CommandAction(ActionBase): """ return self._executeOperation('', 'reloading') - @staticmethod - def escapeTag(value): + ESCAPE_CRE = re.compile(r"""[\\#&;`|*?~<>^()\[\]{}$'"\n\r]""") + + @classmethod + def escapeTag(cls, value): """Escape characters which may be used for command injection. Parameters @@ -516,12 +550,15 @@ class CommandAction(ActionBase): ----- The following characters are escaped:: - \\#&;`|*?~<>^()[]{}$'" + \\#&;`|*?~<>^()[]{}$'"\n\r """ - for c in '\\#&;`|*?~<>^()[]{}$\'"': - if c in value: - value = value.replace(c, '\\' + c) + _map2c = {'\n': 'n', '\r': 'r'} + def substChar(m): + c = m.group() + return '\\' + _map2c.get(c, c) + + value = cls.ESCAPE_CRE.sub(substChar, value) return value @classmethod @@ -780,7 +817,8 @@ class CommandAction(ActionBase): RuntimeError If command execution times out. """ - logSys.debug(realCmd) + if logSys.getEffectiveLevel() < logging.DEBUG: # pragma: no cover + logSys.log(9, realCmd) if not realCmd: logSys.debug("Nothing to do") return True diff --git a/fail2ban/server/actions.py b/fail2ban/server/actions.py index f8f8d4e1..f940bb45 100644 --- a/fail2ban/server/actions.py +++ b/fail2ban/server/actions.py @@ -34,11 +34,12 @@ try: except ImportError: OrderedDict = dict -from .banmanager import BanManager +from .banmanager import BanManager, BanTicket from .ipdns import DNSUtils from .jailthread import JailThread from .action import ActionBase, CommandAction, CallingMap from .mytime import MyTime +from .observer import Observers from .utils import Utils from ..helpers import getLogger @@ -297,6 +298,8 @@ class Actions(JailThread, Mapping): "fid": lambda self: self.__ticket.getID(), "failures": lambda self: self.__ticket.getAttempt(), "time": lambda self: self.__ticket.getTime(), + "bantime": lambda self: self._getBanTime(), + "bancount": lambda self: self.__ticket.getBanCount(), "matches": lambda self: "\n".join(self.__ticket.getMatches()), # to bypass actions, that should not be executed for restored tickets "restored": lambda self: (1 if self.__ticket.restored else 0), @@ -321,9 +324,14 @@ class Actions(JailThread, Mapping): self.immutable = immutable self.data = data - def copy(self): # pargma: no cover + def copy(self): # pragma: no cover return self.__class__(self.__ticket, self.__jail, self.immutable, self.data.copy()) + def _getBanTime(self): + btime = self.__ticket.getBanTime() + if btime is None: btime = self.__jail.actions.getBanTime() + return btime + def _mi4ip(self, overalljails=False): """Gets bans merged once, a helper for lambda(s), prevents stop of executing action by any exception inside. @@ -389,13 +397,19 @@ class Actions(JailThread, Mapping): ticket = self._jail.getFailTicket() if not ticket: break - bTicket = BanManager.createBanTicket(ticket) + + bTicket = BanTicket.wrap(ticket) + btime = ticket.getBanTime(self.__banManager.getBanTime()) ip = bTicket.getIP() aInfo = self.__getActionInfo(bTicket) reason = {} if self.__banManager.addBanTicket(bTicket, reason=reason): cnt += 1 + # report ticket to observer, to check time should be increased and hereafter observer writes ban to database (asynchronous) + if Observers.Main is not None and not bTicket.restored: + Observers.Main.add('banFound', bTicket, self._jail, btime) logSys.notice("[%s] %sBan %s", self._jail.name, ('' if not bTicket.restored else 'Restore '), ip) + # do actions : for name, action in self._actions.iteritems(): try: if ticket.restored and getattr(action, 'norestored', False): @@ -411,7 +425,10 @@ class Actions(JailThread, Mapping): # after all actions are processed set banned flag: bTicket.banned = True else: - bTicket = reason['ticket'] + if reason.get('expired', 0): + logSys.info('[%s] Ignore %s, expired bantime', self._jail.name, ip) + continue + bTicket = reason.get('ticket', bTicket) # if already banned (otherwise still process some action) if bTicket.banned: # compare time of failure occurrence with time ticket was really banned: @@ -429,6 +446,29 @@ class Actions(JailThread, Mapping): self.__banManager.getBanTotal(), self.__banManager.size(), self._jail.name) return cnt + def _prolongBan(self, ticket): + # prevent to prolong ticket that was removed in-between, + # if it in ban list - ban time already prolonged (and it stays there): + if not self.__banManager._inBanList(ticket): return + # do actions : + aInfo = None + for name, action in self._actions.iteritems(): + try: + if ticket.restored and getattr(action, 'norestored', False): + continue + if not action._prolongable: + continue + if aInfo is None: + aInfo = self.__getActionInfo(ticket) + if not aInfo.immutable: aInfo.reset() + action.prolong(aInfo) + except Exception as e: + logSys.error( + "Failed to execute ban jail '%s' action '%s' " + "info '%r': %s", + self._jail.name, name, aInfo, e, + exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) + def __checkUnBan(self): """Check for IP address to unban. diff --git a/fail2ban/server/asyncserver.py b/fail2ban/server/asyncserver.py index 9cc74658..11c81649 100644 --- a/fail2ban/server/asyncserver.py +++ b/fail2ban/server/asyncserver.py @@ -37,7 +37,7 @@ import traceback from .utils import Utils from ..protocol import CSPROTO -from ..helpers import getLogger,formatExceptionInfo +from ..helpers import logging, getLogger, formatExceptionInfo # Gets the instance of the logger. logSys = getLogger(__name__) @@ -80,22 +80,36 @@ class RequestHandler(asynchat.async_chat): # Deserialize message = loads(message) # Gives the message to the transmitter. - message = self.__transmitter.proceed(message) + if self.__transmitter: + message = self.__transmitter.proceed(message) + else: + message = ['SHUTDOWN'] # Serializes the response. message = dumps(message, HIGHEST_PROTOCOL) # Sends the response to the client. self.push(message + CSPROTO.END) - except Exception as e: # pragma: no cover + except Exception as e: logSys.error("Caught unhandled exception: %r", e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) + # Sends the response to the client. + message = dumps("ERROR: %s" % e, HIGHEST_PROTOCOL) + self.push(message + CSPROTO.END) - + ## + # Handles an communication errors in request. + # def handle_error(self): - e1, e2 = formatExceptionInfo() - logSys.error("Unexpected communication error: %s" % str(e2)) - logSys.error(traceback.format_exc().splitlines()) - self.close() - + try: + e1, e2 = formatExceptionInfo() + logSys.error("Unexpected communication error: %s" % str(e2)) + logSys.error(traceback.format_exc().splitlines()) + # Sends the response to the client. + message = dumps("ERROR: %s" % e2, HIGHEST_PROTOCOL) + self.push(message + CSPROTO.END) + except Exception as e: # pragma: no cover - normally unreachable + pass + self.close_when_done() + def loop(active, timeout=None, use_poll=False): """Custom event loop implementation @@ -119,18 +133,20 @@ def loop(active, timeout=None, use_poll=False): poll(timeout) if errCount: errCount -= 1 - except Exception as e: # pragma: no cover + except Exception as e: if not active(): break errCount += 1 if errCount < 20: - if e.args[0] in (errno.ENOTCONN, errno.EBADF): # (errno.EBADF, 'Bad file descriptor') + # errno.ENOTCONN - 'Socket is not connected' + # errno.EBADF - 'Bad file descriptor' + if e.args[0] in (errno.ENOTCONN, errno.EBADF): # pragma: no cover (too sporadic) logSys.info('Server connection was closed: %s', str(e)) else: logSys.error('Server connection was closed: %s', str(e)) elif errCount == 20: - logSys.info('Too many errors - stop logging connection errors') logSys.exception(e) + logSys.error('Too many errors - stop logging connection errors') ## @@ -158,10 +174,10 @@ class AsyncServer(asyncore.dispatcher): def handle_accept(self): try: conn, addr = self.accept() - except socket.error: + except socket.error: # pragma: no cover logSys.warning("Socket error") return - except TypeError: + except TypeError: # pragma: no cover logSys.warning("Type error") return AsyncServer.__markCloseOnExec(conn) @@ -175,7 +191,7 @@ class AsyncServer(asyncore.dispatcher): # @param sock: socket file. # @param force: remove the socket file if exists. - def start(self, sock, force, use_poll=False): + def start(self, sock, force, timeout=None, use_poll=False): self.__worker = threading.current_thread() self.__sock = sock # Remove socket @@ -191,7 +207,7 @@ class AsyncServer(asyncore.dispatcher): self.set_reuse_addr() try: self.bind(sock) - except Exception: + except Exception: # pragma: no cover raise AsyncServerException("Unable to bind socket %s" % self.__sock) AsyncServer.__markCloseOnExec(self.socket) self.listen(1) @@ -201,12 +217,11 @@ class AsyncServer(asyncore.dispatcher): if self.onstart: self.onstart() # Event loop as long as active: - loop(lambda: self.__loop, use_poll=use_poll) + loop(lambda: self.__loop, timeout=timeout, use_poll=use_poll) self.__active = False # Cleanup all self.stop() - def close(self): stopflg = False if self.__active: @@ -228,6 +243,13 @@ class AsyncServer(asyncore.dispatcher): ## # Stops the communication server. + def stop_communication(self): + logSys.debug("Stop communication") + self.__transmitter = None + + ## + # Stops the server. + def stop(self): self.close() diff --git a/fail2ban/server/banmanager.py b/fail2ban/server/banmanager.py index 1275d3a4..0425db51 100644 --- a/fail2ban/server/banmanager.py +++ b/fail2ban/server/banmanager.py @@ -243,21 +243,6 @@ class BanManager: logSys.exception(e) return [] - ## - # Create a ban ticket. - # - # Create a BanTicket from a FailTicket. The timestamp of the BanTicket - # is the current time. This is a static method. - # @param ticket the FailTicket - # @return a BanTicket - - @staticmethod - def createBanTicket(ticket): - # we should always use correct time to calculate correct end time (ban time is variable now, - # + possible double banning by restore from database and from log file) - # so use as lastTime always time from ticket. - return BanTicket(ticket=ticket) - ## # Add a ban ticket. # @@ -267,6 +252,9 @@ class BanManager: def addBanTicket(self, ticket, reason={}): eob = ticket.getEndOfBanTime(self.__banTime) + if eob < MyTime.time(): + reason['expired'] = 1 + return False with self.__lock: # check already banned fid = ticket.getID() @@ -288,6 +276,7 @@ class BanManager: # not yet banned - add new one: self.__banList[fid] = ticket self.__banTotal += 1 + ticket.incrBanCount() # correct next unban time: if self.__nextUnbanTime > eob: self.__nextUnbanTime = eob diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index f4f9b6c2..ba71210e 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -126,7 +126,7 @@ class Fail2BanDb(object): filename purgeage """ - __version__ = 2 + __version__ = 4 # Note all _TABLE_* strings must end in ';' for py26 compatibility _TABLE_fail2banDb = "CREATE TABLE fail2banDb(version INTEGER);" _TABLE_jails = "CREATE TABLE jails(" \ @@ -153,6 +153,8 @@ class Fail2BanDb(object): "jail TEXT NOT NULL, " \ "ip TEXT, " \ "timeofban INTEGER NOT NULL, " \ + "bantime INTEGER NOT NULL, " \ + "bancount INTEGER NOT NULL default 1, " \ "data JSON, " \ "FOREIGN KEY(jail) REFERENCES jails(name) " \ ");" \ @@ -160,8 +162,21 @@ class Fail2BanDb(object): "CREATE INDEX bans_jail_ip ON bans(jail, ip);" \ "CREATE INDEX bans_ip ON bans(ip);" \ + _TABLE_bips = "CREATE TABLE bips(" \ + "ip TEXT NOT NULL, " \ + "jail TEXT NOT NULL, " \ + "timeofban INTEGER NOT NULL, " \ + "bantime INTEGER NOT NULL, " \ + "bancount INTEGER NOT NULL default 1, " \ + "data JSON, " \ + "PRIMARY KEY(ip, jail), " \ + "FOREIGN KEY(jail) REFERENCES jails(name) " \ + ");" \ + "CREATE INDEX bips_timeofban ON bips(timeofban);" \ + "CREATE INDEX bips_ip ON bips(ip);" \ - def __init__(self, filename, purgeAge=24*60*60): + + def __init__(self, filename, purgeAge=24*60*60, outDatedFactor=3): self.maxEntries = 50 try: self._lock = RLock() @@ -170,6 +185,7 @@ class Fail2BanDb(object): detect_types=sqlite3.PARSE_DECLTYPES) self._dbFilename = filename self._purgeAge = purgeAge + self._outDatedFactor = outDatedFactor; self._bansMergedCache = {} @@ -257,6 +273,8 @@ class Fail2BanDb(object): cur.executescript(Fail2BanDb._TABLE_logs) # Bans cur.executescript(Fail2BanDb._TABLE_bans) + # BIPs (bad ips) + cur.executescript(Fail2BanDb._TABLE_bips) cur.execute("SELECT version FROM fail2banDb LIMIT 1") return cur.fetchone()[0] @@ -285,6 +303,20 @@ class Fail2BanDb(object): "UPDATE fail2banDb SET version = 2;" "COMMIT;" % Fail2BanDb._TABLE_logs) + if version < 3: + cur.executescript("BEGIN TRANSACTION;" + "CREATE TEMPORARY TABLE bans_temp AS SELECT jail, ip, timeofban, 600 as bantime, 1 as bancount, data FROM bans;" + "DROP TABLE bans;" + "%s;" + "INSERT INTO bans SELECT * from bans_temp;" + "DROP TABLE bans_temp;" + "COMMIT;" % Fail2BanDb._TABLE_bans) + if version < 4: + cur.executescript("BEGIN TRANSACTION;" + "%s;" + "UPDATE fail2banDb SET version = 4;" + "COMMIT;" % Fail2BanDb._TABLE_bips) + cur.execute("SELECT version FROM fail2banDb LIMIT 1") return cur.fetchone()[0] @@ -445,8 +477,12 @@ class Fail2BanDb(object): pass #TODO: Implement data parts once arbitrary match keys completed cur.execute( - "INSERT INTO bans(jail, ip, timeofban, data) VALUES(?, ?, ?, ?)", - (jail.name, ip, int(round(ticket.getTime())), + "INSERT INTO bans(jail, ip, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)", + (jail.name, ip, int(round(ticket.getTime())), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount(), + ticket.getData())) + cur.execute( + "INSERT OR REPLACE INTO bips(ip, jail, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)", + (ip, jail.name, int(round(ticket.getTime())), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount(), ticket.getData())) @commitandrollback @@ -461,6 +497,9 @@ class Fail2BanDb(object): IP to be removed. """ queryArgs = (jail.name, str(ip)); + cur.execute( + "DELETE FROM bips WHERE jail = ? AND ip = ?", + queryArgs) cur.execute( "DELETE FROM bans WHERE jail = ? AND ip = ?", queryArgs); @@ -581,18 +620,43 @@ class Fail2BanDb(object): self._bansMergedCache[cacheKey] = tickets if ip is None else ticket return tickets if ip is None else ticket + @commitandrollback + def getBan(self, cur, ip, jail=None, forbantime=None, overalljails=None, fromtime=None): + ip = str(ip) + if not overalljails: + query = "SELECT bancount, timeofban, bantime FROM bips" + else: + query = "SELECT sum(bancount), max(timeofban), sum(bantime) FROM bips" + query += " WHERE ip = ?" + queryArgs = [ip] + if not overalljails and jail is not None: + query += " AND jail=?" + queryArgs.append(jail.name) + if forbantime is not None: + query += " AND timeofban > ?" + queryArgs.append(MyTime.time() - forbantime) + if fromtime is not None: + query += " AND timeofban > ?" + queryArgs.append(fromtime) + if overalljails or jail is None: + query += " GROUP BY ip ORDER BY timeofban DESC LIMIT 1" + cur = self._db.cursor() + return cur.execute(query, queryArgs) + def _getCurrentBans(self, cur, jail = None, ip = None, forbantime=None, fromtime=None): if fromtime is None: fromtime = MyTime.time() queryArgs = [] if jail is not None: - query = "SELECT ip, timeofban, data FROM bans WHERE jail=?" + query = "SELECT ip, timeofban, bantime, bancount, data FROM bips WHERE jail=?" queryArgs.append(jail.name) else: - query = "SELECT ip, max(timeofban), data FROM bans WHERE 1" + query = "SELECT ip, max(timeofban), bantime, bancount, data FROM bips WHERE 1" if ip is not None: query += " AND ip=?" queryArgs.append(ip) + query += " AND (timeofban + bantime > ? OR bantime = -1)" + queryArgs.append(fromtime) if forbantime not in (None, -1): # not specified or persistent (all) query += " AND timeofban > ?" queryArgs.append(fromtime - forbantime) @@ -601,23 +665,49 @@ class Fail2BanDb(object): cur = self._db.cursor() return cur.execute(query, queryArgs) - def getCurrentBans(self, jail = None, ip = None, forbantime=None, fromtime=None): + @commitandrollback + def getCurrentBans(self, cur, jail = None, ip = None, forbantime=None, fromtime=None): tickets = [] ticket = None - with self._lock: - results = list(self._getCurrentBans(self._db.cursor(), - jail=jail, ip=ip, forbantime=forbantime, fromtime=fromtime)) - - if results: - for banip, timeofban, data in results: - # logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data) - ticket = FailTicket(banip, timeofban, data=data) - # logSys.debug('restored ticket: %r', ticket) - tickets.append(ticket) + for ticket in self._getCurrentBans(cur, jail=jail, ip=ip, + forbantime=forbantime, fromtime=fromtime + ): + # can produce unpack error (database may return sporadical wrong-empty row): + try: + banip, timeofban, bantime, bancount, data = ticket + # additionally check for empty values: + if banip is None or banip == "": # pragma: no cover + raise ValueError('unexpected value %r' % (banip,)) + except ValueError as e: # pragma: no cover + logSys.debug("get current bans: ignore row %r - %s", ticket, e) + continue + # logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data) + ticket = FailTicket(banip, timeofban, data=data) + # logSys.debug('restored ticket: %r', ticket) + ticket.setBanTime(bantime) + ticket.setBanCount(bancount) + tickets.append(ticket) return tickets if ip is None else ticket + def _cleanjails(self, cur): + """Remove empty jails jails and log files from database. + """ + cur.execute( + "DELETE FROM jails WHERE enabled = 0 " + "AND NOT EXISTS(SELECT * FROM bans WHERE jail = jails.name) " + "AND NOT EXISTS(SELECT * FROM bips WHERE jail = jails.name)") + + def _purge_bips(self, cur): + """Purge old bad ips (jails and log files from database). + Currently it is timed out IP, whose time since last ban is several times out-dated (outDatedFactor is default 3). + Permanent banned ips will be never removed. + """ + cur.execute( + "DELETE FROM bips WHERE timeofban < ? and bantime != -1 and (timeofban + (bantime * ?)) < ?", + (int(MyTime.time()) - self._purgeAge, self._outDatedFactor, int(MyTime.time()) - self._purgeAge)) + @commitandrollback def purge(self, cur): """Purge old bans, jails and log files from database. @@ -626,7 +716,6 @@ class Fail2BanDb(object): cur.execute( "DELETE FROM bans WHERE timeofban < ?", (MyTime.time() - self._purgeAge, )) - cur.execute( - "DELETE FROM jails WHERE enabled = 0 " - "AND NOT EXISTS(SELECT * FROM bans WHERE jail = jails.name)") + self._purge_bips(cur) + self._cleanjails(cur) diff --git a/fail2ban/server/failmanager.py b/fail2ban/server/failmanager.py index ee4b049d..6ce9b74e 100644 --- a/fail2ban/server/failmanager.py +++ b/fail2ban/server/failmanager.py @@ -27,7 +27,7 @@ __license__ = "GPL" from threading import Lock import logging -from .ticket import FailTicket +from .ticket import FailTicket, BanTicket from ..helpers import getLogger, BgService # Gets the instance of the logger. @@ -75,7 +75,7 @@ class FailManager: def getMaxTime(self): return self.__maxTime - def addFailure(self, ticket, count=1): + def addFailure(self, ticket, count=1, observed=False): attempts = 1 with self.__lock: fid = ticket.getID() @@ -102,11 +102,14 @@ class FailManager: if len(matches) > self.maxEntries: fData.setMatches(matches[-self.maxEntries:]) except KeyError: + # not found - already banned - prevent to add failure if comes from observer: + if observed or isinstance(ticket, BanTicket): + return # if already FailTicket - add it direct, otherwise create (using copy all ticket data): if isinstance(ticket, FailTicket): fData = ticket; else: - fData = FailTicket(ticket=ticket) + fData = FailTicket.wrap(ticket) if count > ticket.getAttempt(): fData.setRetry(count) self.__failList[fid] = fData diff --git a/fail2ban/server/filter.py b/fail2ban/server/filter.py index c4f29878..be0c8917 100644 --- a/fail2ban/server/filter.py +++ b/fail2ban/server/filter.py @@ -32,6 +32,7 @@ import time from .failmanager import FailManagerEmpty, FailManager from .ipdns import DNSUtils, IPAddr +from .observer import Observers from .ticket import FailTicket from .jailthread import JailThread from .datedetector import DateDetector, validateTimeZone @@ -552,6 +553,9 @@ class Filter(JailThread): ) tick = FailTicket(ip, unixTime, data=fail) self.failManager.addFailure(tick) + # report to observer - failure was found, for possibly increasing of it retry counter (asynchronous) + if Observers.Main is not None: + Observers.Main.add('failureFound', self.failManager, self.jail, tick) # reset (halve) error counter (successfully processed line): if self._errors: self._errors //= 2 diff --git a/fail2ban/server/jail.py b/fail2ban/server/jail.py index 39fdd959..0bb9f6fb 100644 --- a/fail2ban/server/jail.py +++ b/fail2ban/server/jail.py @@ -24,11 +24,14 @@ __copyright__ = "Copyright (c) 2004 Cyril Jaquier, 2011-2012 Lee Clemens, 2012 Y __license__ = "GPL" import logging +import math +import random import Queue from .actions import Actions from ..client.jailreader import JailReader from ..helpers import getLogger, MyTime +from .mytime import MyTime # Gets the instance of the logger. logSys = getLogger(__name__) @@ -76,6 +79,8 @@ class Jail(object): self.__name = name self.__queue = Queue.Queue() self.__filter = None + # Extra parameters for increase ban time + self._banExtra = {}; logSys.info("Creating new jail '%s'" % self.name) if backend is not None: self._setBackend(backend) @@ -194,8 +199,8 @@ class Jail(object): Used by filter to add a failure for banning. """ self.__queue.put(ticket) - if not ticket.restored and self.database is not None: - self.database.addBan(self, ticket) + # add ban to database moved to observer (should previously check not already banned + # and increase ticket time if "bantime.increment" set) def getFailTicket(self): """Get a fail ticket from the jail. @@ -208,15 +213,70 @@ class Jail(object): except Queue.Empty: return False + def setBanTimeExtra(self, opt, value): + # merge previous extra with new option: + be = self._banExtra; + if value == '': + value = None + if value is not None: + be[opt] = value; + elif opt in be: + del be[opt] + logSys.info('Set banTime.%s = %s', opt, value) + if opt == 'increment': + if isinstance(value, str): + be[opt] = value.lower() in ("yes", "true", "ok", "1") + if be.get(opt) and self.database is None: + logSys.warning("ban time increment is not available as long jail database is not set") + if opt in ['maxtime', 'rndtime']: + if not value is None: + be[opt] = MyTime.str2seconds(value) + # prepare formula lambda: + if opt in ['formula', 'factor', 'maxtime', 'rndtime', 'multipliers'] or be.get('evformula', None) is None: + # split multifiers to an array begins with 0 (or empty if not set): + if opt == 'multipliers': + be['evmultipliers'] = [int(i) for i in (value.split(' ') if value is not None and value != '' else [])] + # if we have multifiers - use it in lambda, otherwise compile and use formula within lambda + multipliers = be.get('evmultipliers', []) + banFactor = eval(be.get('factor', "1")) + if len(multipliers): + evformula = lambda ban, banFactor=banFactor: ( + ban.Time * banFactor * multipliers[ban.Count if ban.Count < len(multipliers) else -1] + ) + else: + formula = be.get('formula', 'ban.Time * (1<<(ban.Count if ban.Count<20 else 20)) * banFactor') + formula = compile(formula, '~inline-conf-expr~', 'eval') + evformula = lambda ban, banFactor=banFactor, formula=formula: max(ban.Time, eval(formula)) + # extend lambda with max time : + if not be.get('maxtime', None) is None: + maxtime = be['maxtime'] + evformula = lambda ban, evformula=evformula: min(evformula(ban), maxtime) + # mix lambda with random time (to prevent bot-nets to calculate exact time IP can be unbanned): + if not be.get('rndtime', None) is None: + rndtime = be['rndtime'] + evformula = lambda ban, evformula=evformula: (evformula(ban) + random.random() * rndtime) + # set to extra dict: + be['evformula'] = evformula + #logSys.info('banTimeExtra : %s' % json.dumps(be)) + + def getBanTimeExtra(self, opt=None): + if opt is not None: + return self._banExtra.get(opt, None) + return self._banExtra + def restoreCurrentBans(self): """Restore any previous valid bans from the database. """ try: if self.database is not None: - forbantime = self.actions.getBanTime() + forbantime = None; + # use ban time as search time if we have not enabled a increasing: + if not self.getBanTimeExtra('increment'): + forbantime = self.actions.getBanTime() for ticket in self.database.getCurrentBans(jail=self, forbantime=forbantime): - #logSys.debug('restored ticket: %s', ticket) - if not self.filter.inIgnoreIPList(ticket.getIP(), log_ignore=True): + try: + #logSys.debug('restored ticket: %s', ticket) + if self.filter.inIgnoreIPList(ticket.getIP(), log_ignore=True): continue # mark ticked was restored from database - does not put it again into db: ticket.restored = True # correct start time / ban time (by the same end of ban): @@ -227,11 +287,13 @@ class Jail(object): # ignore obsolete tickets: if btm != -1 and btm <= 0: continue - ticket.setTime(MyTime.time()) - ticket.setBanTime(btm) self.putFailTicket(ticket) + except Exception as e: # pragma: no cover + logSys.error('Restore ticket failed: %s', e, + exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) except Exception as e: # pragma: no cover - logSys.error('%s', e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) + logSys.error('Restore bans failed: %s', e, + exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) def start(self): """Start the jail, by starting filter and actions threads. diff --git a/fail2ban/server/observer.py b/fail2ban/server/observer.py new file mode 100644 index 00000000..92ff8bc6 --- /dev/null +++ b/fail2ban/server/observer.py @@ -0,0 +1,529 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*- +# vi: set ft=python sts=4 ts=4 sw=4 noet : + +# This file is part of Fail2Ban. +# +# Fail2Ban is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# Fail2Ban is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Fail2Ban; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +# Author: Serg G. Brester (sebres) +# +# This module was written as part of ban time increment feature. + +__author__ = "Serg G. Brester (sebres)" +__copyright__ = "Copyright (c) 2014 Serg G. Brester" +__license__ = "GPL" + +import threading +from .jailthread import JailThread +from .failmanager import FailManagerEmpty +import os, logging, time, datetime, math, json, random +import sys +from ..helpers import getLogger +from .mytime import MyTime +from .utils import Utils + +# Gets the instance of the logger. +logSys = getLogger(__name__) + +class ObserverThread(JailThread): + """Handles observing a database, managing bad ips and ban increment. + + Parameters + ---------- + + Attributes + ---------- + daemon + ident + name + status + active : bool + Control the state of the thread. + idle : bool + Control the idle state of the thread. + sleeptime : int + The time the thread sleeps for in the loop. + """ + + # observer is event driven and it sleep organized incremental, so sleep intervals can be shortly: + DEFAULT_SLEEP_INTERVAL = Utils.DEFAULT_SLEEP_INTERVAL / 10 + + def __init__(self): + # init thread + super(ObserverThread, self).__init__(name='Observer') + # before started - idle: + self.idle = True + ## Event queue + self._queue_lock = threading.RLock() + self._queue = [] + ## Event, be notified if anything added to event queue + self._notify = threading.Event() + ## Sleep for max 60 seconds, it possible to specify infinite to always sleep up to notifying via event, + ## but so we can later do some service "events" occurred infrequently directly in main loop of observer (not using queue) + self.sleeptime = 60 + # + self._timers = {} + self._paused = False + self.__db = None + self.__db_purge_interval = 60*60 + # observer is a not main thread: + self.daemon = True + + def __getitem__(self, i): + try: + return self._queue[i] + except KeyError: + raise KeyError("Invalid event index : %s" % i) + + def __delitem__(self, name): + try: + del self._queue[i] + except KeyError: + raise KeyError("Invalid event index: %s" % i) + + def __iter__(self): + return iter(self._queue) + + def __len__(self): + return len(self._queue) + + def __eq__(self, other): # Required for Threading + return False + + def __hash__(self): # Required for Threading + return id(self) + + def add_named_timer(self, name, starttime, *event): + """Add a named timer event to queue will start (and wake) in 'starttime' seconds + + Previous timer event with same name will be canceled and trigger self into + queue after new 'starttime' value + """ + t = self._timers.get(name, None) + if t is not None: + t.cancel() + t = threading.Timer(starttime, self.add, event) + self._timers[name] = t + t.start() + + def add_timer(self, starttime, *event): + """Add a timer event to queue will start (and wake) in 'starttime' seconds + """ + # in testing we should wait (looping) for the possible time drifts: + if MyTime.myTime is not None and starttime: + # test time after short sleep: + t = threading.Timer(Utils.DEFAULT_SLEEP_INTERVAL, self._delayedEvent, + (MyTime.time() + starttime, time.time() + starttime, event) + ) + t.start() + return + # add timer event: + t = threading.Timer(starttime, self.add, event) + t.start() + + def _delayedEvent(self, endMyTime, endTime, event): + if MyTime.time() >= endMyTime or time.time() >= endTime: + self.add_timer(0, *event) + return + # repeat after short sleep: + t = threading.Timer(Utils.DEFAULT_SLEEP_INTERVAL, self._delayedEvent, + (endMyTime, endTime, event) + ) + t.start() + + def pulse_notify(self): + """Notify wakeup (sets /and resets/ notify event) + """ + if not self._paused and self._notify: + self._notify.set() + #self._notify.clear() + + def add(self, *event): + """Add a event to queue and notify thread to wake up. + """ + ## lock and add new event to queue: + with self._queue_lock: + self._queue.append(event) + self.pulse_notify() + + def add_wn(self, *event): + """Add a event to queue withouth notifying thread to wake up. + """ + ## lock and add new event to queue: + with self._queue_lock: + self._queue.append(event) + + def call_lambda(self, l, *args): + l(*args) + + def run(self): + """Main loop for Threading. + + This function is the main loop of the thread. + + Returns + ------- + bool + True when the thread exits nicely. + """ + logSys.info("Observer start...") + ## first time create named timer to purge database each hour (clean old entries) ... + self.add_named_timer('DB_PURGE', self.__db_purge_interval, 'db_purge') + ## Mapping of all possible event types of observer: + __meth = { + # universal lambda: + 'call': self.call_lambda, + # system and service events: + 'db_set': self.db_set, + 'db_purge': self.db_purge, + # service events of observer self: + 'is_alive' : self.isAlive, + 'is_active': self.isActive, + 'start': self.start, + 'stop': self.stop, + 'nop': lambda:(), + 'shutdown': lambda:() + } + try: + ## check it self with sending is_alive event + self.add('is_alive') + ## if we should stop - break a main loop + while self.active: + self.idle = False + ## check events available and execute all events from queue + while not self._paused: + ## lock, check and pop one from begin of queue: + try: + ev = None + with self._queue_lock: + if len(self._queue): + ev = self._queue.pop(0) + if ev is None: + break + ## retrieve method by name + meth = ev[0] + if not callable(ev[0]): meth = __meth.get(meth) or getattr(self, meth) + ## execute it with rest of event as variable arguments + meth(*ev[1:]) + except Exception as e: + #logSys.error('%s', e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) + logSys.error('%s', e, exc_info=True) + ## going sleep, wait for events (in queue) + n = self._notify + if n: + self.idle = True + n.wait(self.sleeptime) + ## wake up - reset signal now (we don't need it so long as we reed from queue) + n.clear() + if self._paused: + continue + else: + ## notify event deleted (shutdown) - just sleep a litle bit (waiting for shutdown events, prevent high cpu usage) + time.sleep(ObserverThread.DEFAULT_SLEEP_INTERVAL) + ## stop by shutdown and empty queue : + if not self.is_full: + break + ## end of main loop - exit + logSys.info("Observer stopped, %s events remaining.", len(self._queue)) + #print("Observer stopped, %s events remaining." % len(self._queue)) + except Exception as e: + logSys.error('Observer stopped after error: %s', e, exc_info=True) + #print("Observer stopped with error: %s" % str(e)) + # clear all events - exit, for possible calls of wait_empty: + with self._queue_lock: + self._queue = [] + self.idle = True + return True + + def isAlive(self): + #logSys.debug("Observer alive...") + return True + + def isActive(self, fromStr=None): + # logSys.info("Observer alive, %s%s", + # 'active' if self.active else 'inactive', + # '' if fromStr is None else (", called from '%s'" % fromStr)) + return self.active + + def start(self): + with self._queue_lock: + if not self.active: + super(ObserverThread, self).start() + + def stop(self): + if self.active and self._notify: + wtime = 5 + logSys.info("Observer stop ... try to end queue %s seconds", wtime) + #print("Observer stop ....") + # just add shutdown job to make possible wait later until full (events remaining) + with self._queue_lock: + self.add_wn('shutdown') + #don't pulse - just set, because we will delete it hereafter (sometimes not wakeup) + n = self._notify + self._notify.set() + #self.pulse_notify() + self._notify = None + # wait max wtime seconds until full (events remaining) + self.wait_empty(wtime) + n.clear() + self.active = False + self.wait_idle(0.5) + + @property + def is_full(self): + with self._queue_lock: + return True if len(self._queue) else False + + def wait_empty(self, sleeptime=None): + """Wait observer is running and returns if observer has no more events (queue is empty) + """ + time.sleep(ObserverThread.DEFAULT_SLEEP_INTERVAL) + if sleeptime is not None: + e = MyTime.time() + sleeptime + # block queue with not operation to be sure all really jobs are executed if nop goes from queue : + if self._notify is not None: + self.add_wn('nop') + if self.is_full and self.idle: + self.pulse_notify() + while self.is_full: + if sleeptime is not None and MyTime.time() > e: + break + time.sleep(ObserverThread.DEFAULT_SLEEP_INTERVAL) + # wait idle to be sure the last queue element is processed (because pop event before processing it) : + self.wait_idle(0.001) + return not self.is_full + + + def wait_idle(self, sleeptime=None): + """Wait observer is running and returns if observer idle (observer sleeps) + """ + time.sleep(ObserverThread.DEFAULT_SLEEP_INTERVAL) + if self.idle: + return True + if sleeptime is not None: + e = MyTime.time() + sleeptime + while not self.idle: + if sleeptime is not None and MyTime.time() > e: + break + time.sleep(ObserverThread.DEFAULT_SLEEP_INTERVAL) + return self.idle + + @property + def paused(self): + return self._paused; + + @paused.setter + def paused(self, pause): + if self._paused == pause: + return + self._paused = pause + # wake after pause ended + self.pulse_notify() + + + @property + def status(self): + """Status of observer to be implemented. [TODO] + """ + return ('', '') + + ## ----------------------------------------- + ## [Async] database service functionality ... + ## ----------------------------------------- + + def db_set(self, db): + self.__db = db + + def db_purge(self): + logSys.info("Purge database event occurred") + if self.__db is not None: + self.__db.purge() + # trigger timer again ... + self.add_named_timer('DB_PURGE', self.__db_purge_interval, 'db_purge') + + ## ----------------------------------------- + ## [Async] ban time increment functionality ... + ## ----------------------------------------- + + def failureFound(self, failManager, jail, ticket): + """ Notify observer a failure for ip was found + + Observer will check ip was known (bad) and possibly increase an retry count + """ + # check jail active : + if not jail.isAlive(): + return + ip = ticket.getIP() + unixTime = ticket.getTime() + logSys.debug("[%s] Observer: failure found %s", jail.name, ip) + # increase retry count for known (bad) ip, corresponding banCount of it (one try will count than 2, 3, 5, 9 ...) : + banCount = 0 + retryCount = 1 + timeOfBan = None + try: + maxRetry = failManager.getMaxRetry() + db = jail.database + if db is not None: + for banCount, timeOfBan, lastBanTime in db.getBan(ip, jail): + banCount = max(banCount, ticket.getBanCount()) + retryCount = ((1 << (banCount if banCount < 20 else 20))/2 + 1) + # if lastBanTime == -1 or timeOfBan + lastBanTime * 2 > MyTime.time(): + # retryCount = maxRetry + break + retryCount = min(retryCount, maxRetry) + # check this ticket already known (line was already processed and in the database and will be restored from there): + if timeOfBan is not None and unixTime <= timeOfBan: + logSys.debug("[%s] Ignore failure %s before last ban %s < %s, restored", + jail.name, ip, unixTime, timeOfBan) + return + # for not increased failures observer should not add it to fail manager, because was already added by filter self + if retryCount <= 1: + return + # retry counter was increased - add it again: + logSys.info("[%s] Found %s, bad - %s, %s # -> %s%s", jail.name, ip, + datetime.datetime.fromtimestamp(unixTime).strftime("%Y-%m-%d %H:%M:%S"), banCount, retryCount, + (', Ban' if retryCount >= maxRetry else '')) + # retryCount-1, because a ticket was already once incremented by filter self + retryCount = failManager.addFailure(ticket, retryCount - 1, True) + ticket.setBanCount(banCount) + # after observe we have increased attempt count, compare it >= maxretry ... + if retryCount >= maxRetry: + # perform the banning of the IP now (again) + # [todo]: this code part will be used multiple times - optimize it later. + try: # pragma: no branch - exception is the only way out + while True: + ticket = failManager.toBan(ip) + jail.putFailTicket(ticket) + except FailManagerEmpty: + failManager.cleanup(MyTime.time()) + + except Exception as e: + logSys.error('%s', e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) + + + class BanTimeIncr: + def __init__(self, banTime, banCount): + self.Time = banTime + self.Count = banCount + + def calcBanTime(self, jail, banTime, banCount): + be = jail.getBanTimeExtra() + return be['evformula'](self.BanTimeIncr(banTime, banCount)) + + def incrBanTime(self, jail, banTime, ticket): + """Check for IP address to increment ban time (if was already banned). + + Returns + ------- + float + new ban time. + """ + # check jail active : + if not jail.isAlive() or not jail.database: + return banTime + be = jail.getBanTimeExtra() + ip = ticket.getIP() + orgBanTime = banTime + # check ip was already banned (increment time of ban): + try: + if banTime > 0 and be.get('increment', False): + # search IP in database and increase time if found: + for banCount, timeOfBan, lastBanTime in \ + jail.database.getBan(ip, jail, overalljails=be.get('overalljails', False)) \ + : + # increment count in ticket (if still not increased from banmanager, test-cases?): + if banCount >= ticket.getBanCount(): + ticket.setBanCount(banCount+1) + logSys.debug('IP %s was already banned: %s #, %s', ip, banCount, timeOfBan); + # calculate new ban time + if banCount > 0: + banTime = be['evformula'](self.BanTimeIncr(banTime, banCount)) + ticket.setBanTime(banTime) + # check current ticket time to prevent increasing for twice read tickets (restored from log file besides database after restart) + if ticket.getTime() > timeOfBan: + logSys.info('[%s] IP %s is bad: %s # last %s - incr %s to %s' % (jail.name, ip, banCount, + datetime.datetime.fromtimestamp(timeOfBan).strftime("%Y-%m-%d %H:%M:%S"), + datetime.timedelta(seconds=int(orgBanTime)), datetime.timedelta(seconds=int(banTime)))); + else: + ticket.restored = True + break + except Exception as e: + logSys.error('%s', e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) + return banTime + + def banFound(self, ticket, jail, btime): + """ Notify observer a ban occured for ip + + Observer will check ip was known (bad) and possibly increase/prolong a ban time + Secondary we will actualize the bans and bips (bad ip) in database + """ + if ticket.restored: # pragma: no cover (normally not resored tickets only) + return + try: + oldbtime = btime + ip = ticket.getIP() + logSys.debug("[%s] Observer: ban found %s, %s", jail.name, ip, btime) + # if not permanent and ban time was not set - check time should be increased: + if btime != -1 and ticket.getBanTime() is None: + btime = self.incrBanTime(jail, btime, ticket) + # if we should prolong ban time: + if btime == -1 or btime > oldbtime: + ticket.setBanTime(btime) + # if not permanent + if btime != -1: + bendtime = ticket.getTime() + btime + logtime = (datetime.timedelta(seconds=int(btime)), + datetime.datetime.fromtimestamp(bendtime).strftime("%Y-%m-%d %H:%M:%S")) + # check ban is not too old : + if bendtime < MyTime.time(): + logSys.debug('Ignore old bantime %s', logtime[1]) + return False + else: + logtime = ('permanent', 'infinite') + # if ban time was prolonged - log again with new ban time: + if btime != oldbtime: + logSys.notice("[%s] Increase Ban %s (%d # %s -> %s)", jail.name, + ip, ticket.getBanCount(), *logtime) + # delayed prolonging ticket via actions that expected this (not later than 10 sec): + logSys.log(5, "[%s] Observer: prolong %s in %s", jail.name, ip, (btime, oldbtime)) + self.add_timer(min(10, max(0, btime - oldbtime - 5)), self.prolongBan, ticket, jail) + # add ticket to database, but only if was not restored (not already read from database): + if jail.database is not None and not ticket.restored: + # add to database always only after ban time was calculated an not yet already banned: + jail.database.addBan(jail, ticket) + except Exception as e: + logSys.error('%s', e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) + + def prolongBan(self, ticket, jail): + """ Notify observer a ban occured for ip + + Observer will check ip was known (bad) and possibly increase/prolong a ban time + Secondary we will actualize the bans and bips (bad ip) in database + """ + try: + btime = ticket.getBanTime() + ip = ticket.getIP() + logSys.debug("[%s] Observer: prolong %s, %s", jail.name, ip, btime) + # prolong ticket via actions that expected this: + jail.actions._prolongBan(ticket) + except Exception as e: + logSys.error('%s', e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) + +# Global observer initial created in server (could be later rewriten via singleton) +class _Observers: + def __init__(self): + self.Main = None + +Observers = _Observers() diff --git a/fail2ban/server/server.py b/fail2ban/server/server.py index e3b22c44..abe6ed61 100644 --- a/fail2ban/server/server.py +++ b/fail2ban/server/server.py @@ -33,6 +33,7 @@ import signal import stat import sys +from .observer import Observers, ObserverThread from .jails import Jails from .filter import FileFilter, JournalFilter from .transmitter import Transmitter @@ -94,7 +95,7 @@ class Server: self.__prev_signals[s] = signal.getsignal(s) signal.signal(s, new) - def start(self, sock, pidfile, force=False, conf={}): + def start(self, sock, pidfile, force=False, observer=True, conf={}): # First set the mask to only allow access to owner os.umask(0077) # Second daemonize before logging etc, because it will close all handles: @@ -144,6 +145,12 @@ class Server: except (OSError, IOError) as e: # pragma: no cover logSys.error("Unable to create PID file: %s", e) + # Create observers and start it: + if observer: + if Observers.Main is None: + Observers.Main = ObserverThread() + Observers.Main.start() + # Start the communication logSys.debug("Starting communication") try: @@ -152,15 +159,22 @@ class Server: self.__asyncServer.start(sock, force) except AsyncServerException as e: logSys.error("Could not start server: %s", e) + # Removes the PID file. try: logSys.debug("Remove PID file %s", pidfile) os.remove(pidfile) except (OSError, IOError) as e: # pragma: no cover logSys.error("Unable to remove PID file: %s", e) - logSys.info("Exiting Fail2ban") + + # Stop (if not yet already executed): + self.quit() def quit(self): + # Give observer a small chance to complete its work before exit + if Observers.Main is not None: + Observers.Main.stop() + # Stop communication first because if jail's unban action # tries to communicate via fail2ban-client we get a lockup # among threads. So the simplest resolution is to stop all @@ -168,8 +182,7 @@ class Server: # are exiting) # See https://github.com/fail2ban/fail2ban/issues/7 if self.__asyncServer is not None: - self.__asyncServer.stop() - self.__asyncServer = None + self.__asyncServer.stop_communication() # Now stop all the jails self.stopAllJail() @@ -190,6 +203,16 @@ class Server: for s, sh in self.__prev_signals.iteritems(): signal.signal(s, sh) + # Stop observer and exit + if Observers.Main is not None: + Observers.Main.stop() + Observers.Main = None + # Stop async + if self.__asyncServer is not None: + self.__asyncServer.stop() + self.__asyncServer = None + logSys.info("Exiting Fail2ban") + # Prevent to call quit twice: self.quit = lambda: False @@ -481,6 +504,12 @@ class Server: def getBanTime(self, name): return self.__jails[name].actions.getBanTime() + + def setBanTimeExtra(self, name, opt, value): + self.__jails[name].setBanTimeExtra(opt, value) + + def getBanTimeExtra(self, name, opt): + return self.__jails[name].getBanTimeExtra(opt) def isStarted(self): return self.__asyncServer is not None and self.__asyncServer.isActive() @@ -604,7 +633,7 @@ class Server: try: handler.flush() handler.close() - except (ValueError, KeyError): # pragma: no cover + except (ValueError, KeyError): # pragma: no cover # Is known to be thrown after logging was shutdown once # with older Pythons -- seems to be safe to ignore there # At least it was still failing on 2.6.2-0ubuntu1 (jaunty) @@ -691,6 +720,8 @@ class Server: logSys.error( "Unable to import fail2ban database module as sqlite " "is not available.") + if Observers.Main is not None: + Observers.Main.db_set(self.__db) def getDatabase(self): return self.__db diff --git a/fail2ban/server/ticket.py b/fail2ban/server/ticket.py index be205303..c1a14cb2 100644 --- a/fail2ban/server/ticket.py +++ b/fail2ban/server/ticket.py @@ -24,8 +24,6 @@ __author__ = "Cyril Jaquier" __copyright__ = "Copyright (c) 2004 Cyril Jaquier" __license__ = "GPL" -import sys - from ..helpers import getLogger from .ipdns import IPAddr from .mytime import MyTime @@ -35,6 +33,7 @@ logSys = getLogger(__name__) class Ticket(object): + __slots__ = ('_ip', '_flags', '_banCount', '_banTime', '_time', '_data', '_retry', '_lastReset') MAX_TIME = 0X7FFFFFFFFFFF ;# 4461763-th year @@ -61,35 +60,44 @@ class Ticket(object): self._data[k] = v if ticket: # ticket available - copy whole information from ticket: - self.__dict__.update(i for i in ticket.__dict__.iteritems() if i[0] in self.__dict__) + self.update(ticket) + #self.__dict__.update(i for i in ticket.__dict__.iteritems() if i[0] in self.__dict__) def __str__(self): - return "%s: ip=%s time=%s #attempts=%d matches=%r" % \ - (self.__class__.__name__.split('.')[-1], self.__ip, self._time, - self._data['failures'], self._data.get('matches', [])) + return "%s: ip=%s time=%s bantime=%s bancount=%s #attempts=%d matches=%r" % \ + (self.__class__.__name__.split('.')[-1], self._ip, self._time, + self._banTime, self._banCount, + self._data['failures'], self._data.get('matches', [])) def __repr__(self): return str(self) def __eq__(self, other): try: - return self.__ip == other.__ip and \ + return self._ip == other._ip and \ round(self._time, 2) == round(other._time, 2) and \ self._data == other._data except AttributeError: return False + def update(self, ticket): + for n in ticket.__slots__: + v = getattr(ticket, n, None) + if v is not None: + setattr(self, n, v) + + def setIP(self, value): # guarantee using IPAddr instead of unicode, str for the IP if isinstance(value, basestring): value = IPAddr(value) - self.__ip = value + self._ip = value def getID(self): - return self._data.get('fid', self.__ip) + return self._data.get('fid', self._ip) def getIP(self): - return self.__ip + return self._ip def setTime(self, value): self._time = value @@ -98,16 +106,17 @@ class Ticket(object): return self._time def setBanTime(self, value): - self._banTime = value; + self._banTime = value def getBanTime(self, defaultBT=None): return (self._banTime if self._banTime is not None else defaultBT) - def setBanCount(self, value): - self._banCount = value; + def setBanCount(self, value, always=False): + if always or value > self._banCount: + self._banCount = value - def incrBanCount(self, value = 1): - self._banCount += value; + def incrBanCount(self, value=1): + self._banCount += value def getBanCount(self): return self._banCount; @@ -205,21 +214,21 @@ class FailTicket(Ticket): def __init__(self, ip=None, time=None, matches=None, data={}, ticket=None): # this class variables: - self.__retry = 0 - self.__lastReset = None + self._retry = 0 + self._lastReset = None # create/copy using default ticket constructor: Ticket.__init__(self, ip, time, matches, data, ticket) # init: if ticket is None: - self.__lastReset = time if time is not None else self.getTime() - if not self.__retry: - self.__retry = self._data['failures']; + self._lastReset = time if time is not None else self.getTime() + if not self._retry: + self._retry = self._data['failures']; def setRetry(self, value): """ Set artificial retry count, normally equal failures / attempt, used in incremental features (BanTimeIncr) to increase retry count for bad IPs """ - self.__retry = value + self._retry = value if not self._data['failures']: self._data['failures'] = 1 if not value: @@ -230,10 +239,10 @@ class FailTicket(Ticket): """ Returns failures / attempt count or artificial retry count increased for bad IPs """ - return max(self.__retry, self._data['failures']) + return max(self._retry, self._data['failures']) def inc(self, matches=None, attempt=1, count=1): - self.__retry += count + self._retry += count self._data['failures'] += attempt if matches: # we should duplicate "matches", because possibly referenced to multiple tickets: @@ -250,15 +259,24 @@ class FailTicket(Ticket): return self._time def getLastReset(self): - return self.__lastReset + return self._lastReset def setLastReset(self, value): - self.__lastReset = value + self._lastReset = value + + @staticmethod + def wrap(o): + o.__class__ = FailTicket + return o ## # Ban Ticket. # # This class extends the Ticket class. It is mainly used by the BanManager. -class BanTicket(Ticket): - pass +class BanTicket(FailTicket): + + @staticmethod + def wrap(o): + o.__class__ = BanTicket + return o diff --git a/fail2ban/server/transmitter.py b/fail2ban/server/transmitter.py index ecc2a138..ecdc6fad 100644 --- a/fail2ban/server/transmitter.py +++ b/fail2ban/server/transmitter.py @@ -278,6 +278,11 @@ class Transmitter: value = command[2] self.__server.setBanTime(name, value) return self.__server.getBanTime(name) + elif command[1].startswith("bantime."): + value = command[2] + opt = command[1][len("bantime."):] + self.__server.setBanTimeExtra(name, opt, value) + return self.__server.getBanTimeExtra(name, opt) elif command[1] == "banip": value = command[2] return self.__server.setBanIP(name,value) @@ -376,6 +381,9 @@ class Transmitter: # Action elif command[1] == "bantime": return self.__server.getBanTime(name) + elif command[1].startswith("bantime."): + opt = command[1][len("bantime."):] + return self.__server.getBanTimeExtra(name, opt) elif command[1] == "actions": return self.__server.getActions(name).keys() elif command[1] == "action": diff --git a/fail2ban/server/utils.py b/fail2ban/server/utils.py index 58363ff0..613f623d 100644 --- a/fail2ban/server/utils.py +++ b/fail2ban/server/utils.py @@ -102,7 +102,7 @@ class Utils(): def unset(self, k): try: del self._cache[k] - except KeyError: # pragme: no cover + except KeyError: # pragma: no cover pass @@ -330,7 +330,7 @@ class Utils(): return e.errno == errno.EPERM else: return True - else: # pragma : no cover (no windows currently supported) + else: # pragma: no cover (no windows currently supported) @staticmethod def pid_exists(pid): import ctypes diff --git a/fail2ban/tests/actionstestcase.py b/fail2ban/tests/actionstestcase.py index 8969db36..279290d1 100644 --- a/fail2ban/tests/actionstestcase.py +++ b/fail2ban/tests/actionstestcase.py @@ -149,7 +149,7 @@ class ExecuteActions(LogCaptureTestCase): "action2", os.path.join(TEST_FILES_DIR, "action.d/action_modifyainfo.py"), {}) - self.__jail.putFailTicket(FailTicket("1.2.3.4", 0)) + self.__jail.putFailTicket(FailTicket("1.2.3.4")) self.__actions._Actions__checkBan() # Will fail if modification of aInfo from first action propagates # to second action, as both delete same key diff --git a/fail2ban/tests/actiontestcase.py b/fail2ban/tests/actiontestcase.py index cbd0aaca..47b266fd 100644 --- a/fail2ban/tests/actiontestcase.py +++ b/fail2ban/tests/actiontestcase.py @@ -206,15 +206,15 @@ class CommandActionTest(LogCaptureTestCase): self.assertEqual( self.__action.replaceTag("", {'matches': "some >char< should \< be[ escap}ed&\n"}), - "some \\>char\\< should \\\\\\< be\\[ escap\\}ed\\&\n") + "some \\>char\\< should \\\\\\< be\\[ escap\\}ed\\&\\n") self.assertEqual( self.__action.replaceTag("", {'ipmatches': "some >char< should \< be[ escap}ed&\n"}), - "some \\>char\\< should \\\\\\< be\\[ escap\\}ed\\&\n") + "some \\>char\\< should \\\\\\< be\\[ escap\\}ed\\&\\n") self.assertEqual( self.__action.replaceTag("", - {'ipjailmatches': "some >char< should \< be[ escap}ed&\n"}), - "some \\>char\\< should \\\\\\< be\\[ escap\\}ed\\&\n") + {'ipjailmatches': "some >char< should \< be[ escap}ed&\r\n"}), + "some \\>char\\< should \\\\\\< be\\[ escap\\}ed\\&\\r\\n") # Recursive aInfo["ABC"] = "" diff --git a/fail2ban/tests/banmanagertestcase.py b/fail2ban/tests/banmanagertestcase.py index 33cd2dac..4ddb9ce6 100644 --- a/fail2ban/tests/banmanagertestcase.py +++ b/fail2ban/tests/banmanagertestcase.py @@ -26,6 +26,8 @@ __license__ = "GPL" import unittest +from .utils import setUpMyTime, tearDownMyTime + from ..server.banmanager import BanManager from ..server.ticket import BanTicket @@ -33,12 +35,14 @@ class AddFailure(unittest.TestCase): def setUp(self): """Call before every test case.""" super(AddFailure, self).setUp() + setUpMyTime() self.__ticket = BanTicket('193.168.0.128', 1167605999.0) self.__banManager = BanManager() def tearDown(self): """Call after every test case.""" super(AddFailure, self).tearDown() + tearDownMyTime() def testAdd(self): self.assertTrue(self.__banManager.addBanTicket(self.__ticket)) @@ -93,6 +97,25 @@ class AddFailure(unittest.TestCase): self.assertTrue(self.__banManager.addBanTicket(self.__ticket)) ticket = BanTicket('111.111.1.111', 1167605999.0) self.assertFalse(self.__banManager._inBanList(ticket)) + + def testBanTimeIncr(self): + ticket = BanTicket(self.__ticket.getIP(), self.__ticket.getTime()) + ## increase twice and at end permanent, check time/count increase: + c = 0 + for i in (1000, 2000, -1): + self.__banManager.addBanTicket(self.__ticket); c += 1 + ticket.setBanTime(i) + self.assertFalse(self.__banManager.addBanTicket(ticket)); # no incr of c (already banned) + self.assertEqual(str(self.__banManager.getTicketByID(ticket.getIP())), + "BanTicket: ip=%s time=%s bantime=%s bancount=%s #attempts=0 matches=[]" % (ticket.getIP(), ticket.getTime(), i, c)) + ## after permanent, it should remain permanent ban time (-1): + self.__banManager.addBanTicket(self.__ticket); c += 1 + ticket.setBanTime(-1) + self.assertFalse(self.__banManager.addBanTicket(ticket)); # no incr of c (already banned) + ticket.setBanTime(1000) + self.assertFalse(self.__banManager.addBanTicket(ticket)); # no incr of c (already banned) + self.assertEqual(str(self.__banManager.getTicketByID(ticket.getIP())), + "BanTicket: ip=%s time=%s bantime=%s bancount=%s #attempts=0 matches=[]" % (ticket.getIP(), ticket.getTime(), -1, c)) def testUnban(self): btime = self.__banManager.getBanTime() @@ -137,6 +160,7 @@ class StatusExtendedCymruInfo(unittest.TestCase): """Call before every test case.""" super(StatusExtendedCymruInfo, self).setUp() unittest.F2B.SkipIfNoNetwork() + setUpMyTime() self.__ban_ip = "93.184.216.34" self.__asn = "15133" self.__country = "EU" @@ -148,6 +172,7 @@ class StatusExtendedCymruInfo(unittest.TestCase): def tearDown(self): """Call after every test case.""" super(StatusExtendedCymruInfo, self).tearDown() + tearDownMyTime() available = True, None diff --git a/fail2ban/tests/clientreadertestcase.py b/fail2ban/tests/clientreadertestcase.py index c3a10c36..bb8f1415 100644 --- a/fail2ban/tests/clientreadertestcase.py +++ b/fail2ban/tests/clientreadertestcase.py @@ -612,8 +612,6 @@ class JailsReaderTest(LogCaptureTestCase): # all must have some actionban defined self.assertTrue(actionReader._opts.get('actionban', '').strip(), msg="Action file %r is lacking actionban" % actionConfig) - self.assertIn('Init', actionReader.sections(), - msg="Action file %r is lacking [Init] section" % actionConfig) def testReadStockJailConf(self): jails = JailsReader(basedir=CONFIG_DIR, share_config=CONFIG_DIR_SHARE_CFG) # we are running tests from root project dir atm diff --git a/fail2ban/tests/databasetestcase.py b/fail2ban/tests/databasetestcase.py index cbfc1517..11ee661e 100644 --- a/fail2ban/tests/databasetestcase.py +++ b/fail2ban/tests/databasetestcase.py @@ -124,6 +124,33 @@ class DatabaseTest(LogCaptureTestCase): self.assertRaises(NotImplementedError, self.db.updateDb, Fail2BanDb.__version__ + 1) os.remove(self.db._dbBackupFilename) + def testUpdateDb2(self): + if Fail2BanDb is None or self.db.filename == ':memory:': # pragma: no cover + return + shutil.copyfile( + os.path.join(TEST_FILES_DIR, 'database_v2.db'), self.dbFilename) + self.db = Fail2BanDb(self.dbFilename) + self.assertEqual(self.db.getJailNames(), set(['pam-generic'])) + self.assertEqual(self.db.getLogPaths(), set(['/var/log/auth.log'])) + bans = self.db.getBans() + self.assertEqual(len(bans), 2) + # compare first ticket completely: + ticket = FailTicket("1.2.3.7", 1417595494, [ + u'Dec 3 09:31:08 f2btest test:auth[27658]: pam_unix(test:auth): authentication failure; logname= uid=0 euid=0 tty=test ruser= rhost=1.2.3.7', + u'Dec 3 09:31:32 f2btest test:auth[27671]: pam_unix(test:auth): authentication failure; logname= uid=0 euid=0 tty=test ruser= rhost=1.2.3.7', + u'Dec 3 09:31:34 f2btest test:auth[27673]: pam_unix(test:auth): authentication failure; logname= uid=0 euid=0 tty=test ruser= rhost=1.2.3.7' + ]) + ticket.setAttempt(3) + self.assertEqual(bans[0], ticket) + # second ban found also: + self.assertEqual(bans[1].getIP(), "1.2.3.8") + # updated ? + self.assertEqual(self.db.updateDb(Fail2BanDb.__version__), Fail2BanDb.__version__) + # further update should fail: + self.assertRaises(NotImplementedError, self.db.updateDb, Fail2BanDb.__version__ + 1) + # clean: + os.remove(self.db._dbBackupFilename) + def testAddJail(self): if Fail2BanDb is None: # pragma: no cover return @@ -367,10 +394,17 @@ class DatabaseTest(LogCaptureTestCase): tickets = self.db.getCurrentBans(jail=self.jail, forbantime=15, fromtime=MyTime.time() + MyTime.str2seconds("1year")) self.assertEqual(len(tickets), 0) - # persistent bantime (-1), so never expired: + # persistent bantime (-1), so never expired (but no persistent tickets): tickets = self.db.getCurrentBans(jail=self.jail, forbantime=-1, fromtime=MyTime.time() + MyTime.str2seconds("1year")) - self.assertEqual(len(tickets), 2) + self.assertEqual(len(tickets), 0) + # add persistent one: + ticket.setBanTime(-1) + self.db.addBan(self.jail, ticket) + # persistent bantime (-1), so never expired (1 persistent ticket): + tickets = self.db.getCurrentBans(jail=self.jail, forbantime=-1, + fromtime=MyTime.time() + MyTime.str2seconds("1year")) + self.assertEqual(len(tickets), 1) def testActionWithDB(self): # test action together with database functionality @@ -381,8 +415,9 @@ class DatabaseTest(LogCaptureTestCase): "action_checkainfo", os.path.join(TEST_FILES_DIR, "action.d/action_checkainfo.py"), {}) - ticket = FailTicket("1.2.3.4", MyTime.time(), ['test', 'test']) + ticket = FailTicket("1.2.3.4") ticket.setAttempt(5) + ticket.setMatches(['test', 'test']) self.jail.putFailTicket(ticket) actions._Actions__checkBan() self.assertLogged("ban ainfo %s, %s, %s, %s" % (True, True, True, True)) diff --git a/fail2ban/tests/dummyjail.py b/fail2ban/tests/dummyjail.py index 19f97f4e..c7c139e3 100644 --- a/fail2ban/tests/dummyjail.py +++ b/fail2ban/tests/dummyjail.py @@ -28,6 +28,11 @@ from ..server.jail import Jail from ..server.actions import Actions +class DummyActions(Actions): + def checkBan(self): + return self._Actions__checkBan() + + class DummyJail(Jail): """A simple 'jail' to suck in all the tickets generated by Filter's """ @@ -36,7 +41,7 @@ class DummyJail(Jail): self.queue = [] super(DummyJail, self).__init__(name='DummyJail', backend=backend) self.__db = None - self.__actions = Actions(self) + self.__actions = DummyActions(self) def __len__(self): with self.lock: diff --git a/fail2ban/tests/fail2banclienttestcase.py b/fail2ban/tests/fail2banclienttestcase.py index caacf63c..260fd3a3 100644 --- a/fail2ban/tests/fail2banclienttestcase.py +++ b/fail2ban/tests/fail2banclienttestcase.py @@ -43,7 +43,8 @@ from .. import protocol from ..server import server from ..server.mytime import MyTime from ..server.utils import Utils -from .utils import LogCaptureTestCase, logSys as DefLogSys, with_tmpdir, shutil, logging +from .utils import LogCaptureTestCase, logSys as DefLogSys, with_tmpdir, shutil, logging, \ + TEST_NOW, tearDownMyTime from ..helpers import getLogger @@ -80,6 +81,35 @@ fail2banclient.output = \ fail2banserver.output = \ protocol.output = _test_output +def _time_shift(shift): + # jump to the future (+shift minutes): + logSys.debug("===>>> time shift + %s min", shift) + MyTime.setTime(MyTime.time() + shift*60) + + +Observers = server.Observers + +def _observer_wait_idle(): + """Helper to wait observer becomes idle""" + if Observers.Main is not None: + Observers.Main.wait_empty(MID_WAITTIME) + Observers.Main.wait_idle(MID_WAITTIME / 5) + +def _observer_wait_before_incrban(cond, timeout=MID_WAITTIME): + """Helper to block observer before increase bantime until some condition gets true""" + if Observers.Main is not None: + # switch ban handler: + _obs_banFound = Observers.Main.banFound + def _banFound(*args, **kwargs): + # restore original handler: + Observers.Main.banFound = _obs_banFound + # wait for: + logSys.debug(' [Observer::banFound] *** observer blocked for test') + Utils.wait_for(cond, timeout) + logSys.debug(' [Observer::banFound] +++ observer runs again') + # original banFound: + _obs_banFound(*args, **kwargs) + Observers.Main.banFound = _banFound # # Mocking .exit so we could test its correct operation. @@ -309,6 +339,7 @@ def with_foreground_server_thread(startextra={}): # so don't kill (same process) - if success, just wait for end of worker: if phase.get('end', None): th.join() + tearDownMyTime() return wrapper return _deco_wrapper @@ -335,6 +366,7 @@ class Fail2banClientServerBase(LogCaptureTestCase): server.DEF_LOGTARGET = SRV_DEF_LOGTARGET server.DEF_LOGLEVEL = SRV_DEF_LOGLEVEL LogCaptureTestCase.tearDown(self) + tearDownMyTime() @staticmethod def _test_exit(code=0): @@ -948,6 +980,8 @@ class Fail2banServerTest(Fail2banClientServerBase): "[test-jail2] Found 192.0.2.3", "[test-jail2] Ban 192.0.2.3", all=True) + # if observer available wait for it becomes idle (write all tickets to db): + _observer_wait_idle() # rotate logs: _write_file(test1log, "w+") @@ -1151,3 +1185,106 @@ class Fail2banServerTest(Fail2banClientServerBase): self.assertLogged( "Jail 'test-jail1' stopped", "Jail 'test-jail1' started", all=True) + + @with_foreground_server_thread() + def testServerObserver(self, tmp, startparams): + cfg = pjoin(tmp, "config") + test1log = pjoin(tmp, "test1.log") + + os.mkdir(pjoin(cfg, "action.d")) + def _write_action_cfg(actname="test-action1", prolong=True): + fn = pjoin(cfg, "action.d", "%s.conf" % actname) + _write_file(fn, "w", + "[DEFAULT]", + "", + "[Definition]", + "actionban = printf %%s \"[%(name)s] %(actname)s: ++ ban -c -t : \"", \ + "actionprolong = printf %%s \"[%(name)s] %(actname)s: ++ prolong -c -t : \"" \ + if prolong else "", + "actionunban = printf %%b '[%(name)s] %(actname)s: -- unban '", + ) + if unittest.F2B.log_level <= logging.DEBUG: # pragma: no cover + _out_file(fn) + + def _write_jail_cfg(backend="polling"): + _write_file(pjoin(cfg, "jail.conf"), "w", + "[INCLUDES]", "", + "[DEFAULT]", "", + "usedns = no", + "maxretry = 3", + "findtime = 1m", + "bantime = 5m", + "bantime.increment = true", + "datepattern = {^LN-BEG}EPOCH", + "", + "[test-jail1]", "backend = " + backend, "filter =", + "action = test-action1[name='%(__name__)s']", + " test-action2[name='%(__name__)s']", + "logpath = " + test1log, + "failregex = ^\s*failure 401|403 from :\s*.*$", + "enabled = true", + "", + ) + if unittest.F2B.log_level <= logging.DEBUG: # pragma: no cover + _out_file(pjoin(cfg, "jail.conf")) + + # create test config: + _write_action_cfg(actname="test-action1", prolong=False) + _write_action_cfg(actname="test-action2", prolong=True) + _write_jail_cfg() + + _write_file(test1log, "w") + # initial start: + self.pruneLog("[test-phase 0) time-0]") + self.execSuccess(startparams, "reload") + # generate bad ip: + _write_file(test1log, "w+", *( + (str(int(MyTime.time())) + " failure 401 from 192.0.2.11: I'm bad \"hacker\" `` $(echo test)",) * 3 + )) + # wait for ban: + _observer_wait_idle() + self.assertLogged( + "stdout: '[test-jail1] test-action1: ++ ban 192.0.2.11 -c 1 -t 300 : ", + "stdout: '[test-jail1] test-action2: ++ ban 192.0.2.11 -c 1 -t 300 : ", + all=True, wait=MID_WAITTIME) + # wait for observer idle (write all tickets to db): + _observer_wait_idle() + + self.pruneLog("[test-phase 1) time+10m]") + # jump to the future (+10 minutes): + _time_shift(10) + _observer_wait_idle() + self.assertLogged( + "stdout: '[test-jail1] test-action1: -- unban 192.0.2.11", + "stdout: '[test-jail1] test-action2: -- unban 192.0.2.11", + "0 ticket(s) in 'test-jail1'", + all=True, wait=MID_WAITTIME) + _observer_wait_idle() + + self.pruneLog("[test-phase 2) time+10m]") + # following tests are time-related - observer can prolong ticket (increase ban-time) + # before banning, so block it here before banFound called, prolong case later: + wakeObs = False + _observer_wait_before_incrban(lambda: wakeObs) + # write again (IP already bad): + _write_file(test1log, "w+", *( + (str(int(MyTime.time())) + " failure 401 from 192.0.2.11: I'm very bad \"hacker\" `` $(echo test)",) * 2 + )) + # wait for ban: + self.assertLogged( + "stdout: '[test-jail1] test-action1: ++ ban 192.0.2.11 -c 2 -t 300 : ", + "stdout: '[test-jail1] test-action2: ++ ban 192.0.2.11 -c 2 -t 300 : ", + all=True, wait=MID_WAITTIME) + # unblock observer here and wait it is done: + wakeObs = True + _observer_wait_idle() + + self.pruneLog("[test-phase 2) time+11m]") + # jump to the future (+1 minute): + _time_shift(1) + # wait for observer idle (write all tickets to db): + _observer_wait_idle() + # wait for prolong: + self.assertLogged( + "stdout: '[test-jail1] test-action2: ++ prolong 192.0.2.11 -c 2 -t 600 : ", + all=True, wait=MID_WAITTIME) diff --git a/fail2ban/tests/failmanagertestcase.py b/fail2ban/tests/failmanagertestcase.py index 18f2c545..ad89ec76 100644 --- a/fail2ban/tests/failmanagertestcase.py +++ b/fail2ban/tests/failmanagertestcase.py @@ -151,10 +151,10 @@ class AddFailure(unittest.TestCase): ticket_repr = repr(ticket) self.assertEqual( ticket_str, - 'FailTicket: ip=193.168.0.128 time=1167605999.0 #attempts=5 matches=[]') + 'FailTicket: ip=193.168.0.128 time=1167605999.0 bantime=None bancount=0 #attempts=5 matches=[]') self.assertEqual( ticket_repr, - 'FailTicket: ip=193.168.0.128 time=1167605999.0 #attempts=5 matches=[]') + 'FailTicket: ip=193.168.0.128 time=1167605999.0 bantime=None bancount=0 #attempts=5 matches=[]') self.assertFalse(not ticket) # and some get/set-ers otherwise not tested ticket.setTime(1000002000.0) @@ -162,7 +162,7 @@ class AddFailure(unittest.TestCase): # and str() adjusted correspondingly self.assertEqual( str(ticket), - 'FailTicket: ip=193.168.0.128 time=1000002000.0 #attempts=5 matches=[]') + 'FailTicket: ip=193.168.0.128 time=1000002000.0 bantime=None bancount=0 #attempts=5 matches=[]') def testbanNOK(self): self._addDefItems() diff --git a/fail2ban/tests/files/database_v1.db b/fail2ban/tests/files/database_v1.db index 20822671..fa2d7bb2 100644 Binary files a/fail2ban/tests/files/database_v1.db and b/fail2ban/tests/files/database_v1.db differ diff --git a/fail2ban/tests/files/database_v2.db b/fail2ban/tests/files/database_v2.db new file mode 100644 index 00000000..8954c8b5 Binary files /dev/null and b/fail2ban/tests/files/database_v2.db differ diff --git a/fail2ban/tests/observertestcase.py b/fail2ban/tests/observertestcase.py new file mode 100644 index 00000000..80e2e2b7 --- /dev/null +++ b/fail2ban/tests/observertestcase.py @@ -0,0 +1,613 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*- +# vi: set ft=python sts=4 ts=4 sw=4 noet : + +# This file is part of Fail2Ban. +# +# Fail2Ban is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# Fail2Ban is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Fail2Ban; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +# Author: Serg G. Brester (sebres) +# + +__author__ = "Serg G. Brester (sebres)" +__copyright__ = "Copyright (c) 2014 Serg G. Brester" +__license__ = "GPL" + +import os +import sys +import unittest +import tempfile +import time + +from ..server.mytime import MyTime +from ..server.ticket import FailTicket, BanTicket +from ..server.failmanager import FailManager +from ..server.observer import Observers, ObserverThread +from ..server.utils import Utils +from .utils import LogCaptureTestCase +from ..server.filter import Filter +from .dummyjail import DummyJail + +from .databasetestcase import getFail2BanDb, Fail2BanDb + + +class BanTimeIncr(LogCaptureTestCase): + + def setUp(self): + """Call before every test case.""" + super(BanTimeIncr, self).setUp() + self.__jail = DummyJail() + self.__jail.calcBanTime = self.calcBanTime + self.Observer = ObserverThread() + + def tearDown(self): + super(BanTimeIncr, self).tearDown() + + def calcBanTime(self, banTime, banCount): + return self.Observer.calcBanTime(self.__jail, banTime, banCount) + + def testDefault(self, multipliers = None): + a = self.__jail; + a.setBanTimeExtra('increment', 'true') + self.assertEqual(a.getBanTimeExtra('increment'), True) + a.setBanTimeExtra('maxtime', '1d') + self.assertEqual(a.getBanTimeExtra('maxtime'), 24*60*60) + a.setBanTimeExtra('rndtime', None) + a.setBanTimeExtra('factor', None) + # tests formulat or multipliers: + a.setBanTimeExtra('multipliers', multipliers) + # test algorithm and max time 24 hours : + self.assertEqual( + [a.calcBanTime(600, i) for i in xrange(1, 11)], + [1200, 2400, 4800, 9600, 19200, 38400, 76800, 86400, 86400, 86400] + ) + # with extra large max time (30 days): + a.setBanTimeExtra('maxtime', '30d') + # using formula the ban time grows always, but using multipliers the growing will stops with last one: + arr = [1200, 2400, 4800, 9600, 19200, 38400, 76800, 153600, 307200, 614400] + if multipliers is not None: + multcnt = len(multipliers.split(' ')) + if multcnt < 11: + arr = arr[0:multcnt-1] + ([arr[multcnt-2]] * (11-multcnt)) + self.assertEqual( + [a.calcBanTime(600, i) for i in xrange(1, 11)], + arr + ) + a.setBanTimeExtra('maxtime', '1d') + # change factor : + a.setBanTimeExtra('factor', '2'); + self.assertEqual( + [a.calcBanTime(600, i) for i in xrange(1, 11)], + [2400, 4800, 9600, 19200, 38400, 76800, 86400, 86400, 86400, 86400] + ) + # factor is float : + a.setBanTimeExtra('factor', '1.33'); + self.assertEqual( + [int(a.calcBanTime(600, i)) for i in xrange(1, 11)], + [1596, 3192, 6384, 12768, 25536, 51072, 86400, 86400, 86400, 86400] + ) + a.setBanTimeExtra('factor', None); + # change max time : + a.setBanTimeExtra('maxtime', '12h') + self.assertEqual( + [a.calcBanTime(600, i) for i in xrange(1, 11)], + [1200, 2400, 4800, 9600, 19200, 38400, 43200, 43200, 43200, 43200] + ) + a.setBanTimeExtra('maxtime', '24h') + ## test randomization - not possibe all 10 times we have random = 0: + a.setBanTimeExtra('rndtime', '5m') + self.assertTrue( + False in [1200 in [a.calcBanTime(600, 1) for i in xrange(10)] for c in xrange(10)] + ) + a.setBanTimeExtra('rndtime', None) + self.assertFalse( + False in [1200 in [a.calcBanTime(600, 1) for i in xrange(10)] for c in xrange(10)] + ) + # restore default: + a.setBanTimeExtra('multipliers', None) + a.setBanTimeExtra('factor', None); + a.setBanTimeExtra('maxtime', '24h') + a.setBanTimeExtra('rndtime', None) + + def testMultipliers(self): + # this multipliers has the same values as default formula, we test stop growing after count 9: + self.testDefault('1 2 4 8 16 32 64 128 256') + # this multipliers has exactly the same values as default formula, test endless growing (stops by count 31 only): + self.testDefault(' '.join([str(1<= (2,7): # pragma: no cover + raise unittest.SkipTest( + "Unable to import fail2ban database module as sqlite is not " + "available.") + elif Fail2BanDb is None: + return + _, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_") + self.db = getFail2BanDb(self.dbFilename) + self.jail = DummyJail() + self.jail.database = self.db + self.Observer = ObserverThread() + Observers.Main = self.Observer + + def tearDown(self): + """Call after every test case.""" + if Fail2BanDb is None: # pragma: no cover + return + # Cleanup + self.Observer.stop() + Observers.Main = None + os.remove(self.dbFilename) + super(BanTimeIncrDB, self).tearDown() + + def incrBanTime(self, ticket, banTime=None): + jail = self.jail; + if banTime is None: + banTime = ticket.getBanTime(jail.actions.getBanTime()) + ticket.setBanTime(None) + incrTime = self.Observer.incrBanTime(jail, banTime, ticket) + #print("!!!!!!!!! banTime: %s, %s, incr: %s " % (banTime, ticket.getBanCount(), incrTime)) + return incrTime + + + def testBanTimeIncr(self): + if Fail2BanDb is None: # pragma: no cover + return + jail = self.jail + self.db.addJail(jail) + # we tests with initial ban time = 10 seconds: + jail.actions.setBanTime(10) + jail.setBanTimeExtra('increment', 'true') + jail.setBanTimeExtra('multipliers', '1 2 4 8 16 32 64 128 256 512 1024 2048') + ip = "127.0.0.2" + # used as start and fromtime (like now but time independence, cause test case can run slow): + stime = int(MyTime.time()) + ticket = FailTicket(ip, stime, []) + # test ticket not yet found + self.assertEqual( + [self.incrBanTime(ticket, 10) for i in xrange(3)], + [10, 10, 10] + ) + # add a ticket banned + ticket.incrBanCount() + self.db.addBan(jail, ticket) + # get a ticket already banned in this jail: + self.assertEqual( + [(banCount, timeOfBan, lastBanTime) for banCount, timeOfBan, lastBanTime in self.db.getBan(ip, jail, None, False)], + [(1, stime, 10)] + ) + # incr time and ban a ticket again : + ticket.setTime(stime + 15) + self.assertEqual(self.incrBanTime(ticket, 10), 20) + self.db.addBan(jail, ticket) + # get a ticket already banned in this jail: + self.assertEqual( + [(banCount, timeOfBan, lastBanTime) for banCount, timeOfBan, lastBanTime in self.db.getBan(ip, jail, None, False)], + [(2, stime + 15, 20)] + ) + # get a ticket already banned in all jails: + self.assertEqual( + [(banCount, timeOfBan, lastBanTime) for banCount, timeOfBan, lastBanTime in self.db.getBan(ip, '', None, True)], + [(2, stime + 15, 20)] + ) + # check other optional parameters of getBan: + self.assertEqual( + [(banCount, timeOfBan, lastBanTime) for banCount, timeOfBan, lastBanTime in self.db.getBan(ip, forbantime=stime, fromtime=stime)], + [(2, stime + 15, 20)] + ) + # search currently banned and 1 day later (nothing should be found): + self.assertEqual( + self.db.getCurrentBans(forbantime=-24*60*60, fromtime=stime), + [] + ) + # search currently banned one ticket for ip: + restored_tickets = self.db.getCurrentBans(ip=ip) + self.assertEqual( + str(restored_tickets), + ('FailTicket: ip=%s time=%s bantime=20 bancount=2 #attempts=0 matches=[]' % (ip, stime + 15)) + ) + # search currently banned anywhere: + restored_tickets = self.db.getCurrentBans(fromtime=stime) + self.assertEqual( + str(restored_tickets), + ('[FailTicket: ip=%s time=%s bantime=20 bancount=2 #attempts=0 matches=[]]' % (ip, stime + 15)) + ) + # search currently banned: + restored_tickets = self.db.getCurrentBans(jail=jail, fromtime=stime) + self.assertEqual( + str(restored_tickets), + ('[FailTicket: ip=%s time=%s bantime=20 bancount=2 #attempts=0 matches=[]]' % (ip, stime + 15)) + ) + # increase ban multiple times: + lastBanTime = 20 + for i in xrange(10): + ticket.setTime(stime + lastBanTime + 5) + banTime = self.incrBanTime(ticket, 10) + self.assertEqual(banTime, lastBanTime * 2) + self.db.addBan(jail, ticket) + lastBanTime = banTime + # increase again, but the last multiplier reached (time not increased): + ticket.setTime(stime + lastBanTime + 5) + banTime = self.incrBanTime(ticket, 10) + self.assertNotEqual(banTime, lastBanTime * 2) + self.assertEqual(banTime, lastBanTime) + self.db.addBan(jail, ticket) + lastBanTime = banTime + # add two tickets from yesterday: one unbanned (bantime already out-dated): + ticket2 = FailTicket(ip+'2', stime-24*60*60, []) + ticket2.setBanTime(12*60*60) + ticket2.incrBanCount() + self.db.addBan(jail, ticket2) + # and one from yesterday also, but still currently banned : + ticket2 = FailTicket(ip+'1', stime-24*60*60, []) + ticket2.setBanTime(36*60*60) + ticket2.incrBanCount() + self.db.addBan(jail, ticket2) + # search currently banned: + restored_tickets = self.db.getCurrentBans(fromtime=stime) + self.assertEqual(len(restored_tickets), 2) + self.assertEqual( + str(restored_tickets[0]), + 'FailTicket: ip=%s time=%s bantime=%s bancount=13 #attempts=0 matches=[]' % (ip, stime + lastBanTime + 5, lastBanTime) + ) + self.assertEqual( + str(restored_tickets[1]), + 'FailTicket: ip=%s time=%s bantime=%s bancount=1 #attempts=0 matches=[]' % (ip+'1', stime-24*60*60, 36*60*60) + ) + # search out-dated (give another fromtime now is -18 hours): + restored_tickets = self.db.getCurrentBans(fromtime=stime-18*60*60) + self.assertEqual(len(restored_tickets), 3) + self.assertEqual( + str(restored_tickets[2]), + 'FailTicket: ip=%s time=%s bantime=%s bancount=1 #attempts=0 matches=[]' % (ip+'2', stime-24*60*60, 12*60*60) + ) + # should be still banned + self.assertFalse(restored_tickets[1].isTimedOut(stime)) + self.assertFalse(restored_tickets[1].isTimedOut(stime)) + # the last should be timed out now + self.assertTrue(restored_tickets[2].isTimedOut(stime)) + self.assertFalse(restored_tickets[2].isTimedOut(stime-18*60*60)) + + # test permanent, create timed out: + ticket=FailTicket(ip+'3', stime-36*60*60, []) + self.assertTrue(ticket.isTimedOut(stime, 600)) + # not timed out - permanent jail: + self.assertFalse(ticket.isTimedOut(stime, -1)) + # not timed out - permanent ticket: + ticket.setBanTime(-1) + self.assertFalse(ticket.isTimedOut(stime, 600)) + self.assertFalse(ticket.isTimedOut(stime, -1)) + # timed out - permanent jail but ticket time (not really used behavior) + ticket.setBanTime(600) + self.assertTrue(ticket.isTimedOut(stime, -1)) + + # get currently banned pis with permanent one: + ticket.setBanTime(-1) + ticket.incrBanCount() + self.db.addBan(jail, ticket) + restored_tickets = self.db.getCurrentBans(fromtime=stime) + self.assertEqual(len(restored_tickets), 3) + self.assertEqual( + str(restored_tickets[2]), + 'FailTicket: ip=%s time=%s bantime=%s bancount=1 #attempts=0 matches=[]' % (ip+'3', stime-36*60*60, -1) + ) + # purge (nothing should be changed): + self.db.purge() + restored_tickets = self.db.getCurrentBans(fromtime=stime) + self.assertEqual(len(restored_tickets), 3) + # set short time and purge again: + ticket.setBanTime(600) + ticket.incrBanCount() + self.db.addBan(jail, ticket) + self.db.purge() + # this old ticket should be removed now: + restored_tickets = self.db.getCurrentBans(fromtime=stime) + self.assertEqual(len(restored_tickets), 2) + self.assertEqual(restored_tickets[0].getIP(), ip) + + # purge remove 1st ip + self.db._purgeAge = -48*60*60 + self.db.purge() + restored_tickets = self.db.getCurrentBans(fromtime=stime) + self.assertEqual(len(restored_tickets), 1) + self.assertEqual(restored_tickets[0].getIP(), ip+'1') + + # this should purge all bans, bips and logs - nothing should be found now + self.db._purgeAge = -240*60*60 + self.db.purge() + restored_tickets = self.db.getCurrentBans(fromtime=stime) + self.assertEqual(restored_tickets, []) + + # two separate jails : + jail1 = DummyJail(backend='polling') + jail1.database = self.db + self.db.addJail(jail1) + jail2 = DummyJail(backend='polling') + jail2.database = self.db + self.db.addJail(jail2) + ticket1 = FailTicket(ip, stime, []) + ticket1.setBanTime(6000) + ticket1.incrBanCount() + self.db.addBan(jail1, ticket1) + ticket2 = FailTicket(ip, stime-6000, []) + ticket2.setBanTime(12000) + ticket2.setBanCount(1) + ticket2.incrBanCount() + self.db.addBan(jail2, ticket2) + restored_tickets = self.db.getCurrentBans(jail=jail1, fromtime=stime) + self.assertEqual(len(restored_tickets), 1) + self.assertEqual( + str(restored_tickets[0]), + 'FailTicket: ip=%s time=%s bantime=%s bancount=1 #attempts=0 matches=[]' % (ip, stime, 6000) + ) + restored_tickets = self.db.getCurrentBans(jail=jail2, fromtime=stime) + self.assertEqual(len(restored_tickets), 1) + self.assertEqual( + str(restored_tickets[0]), + 'FailTicket: ip=%s time=%s bantime=%s bancount=2 #attempts=0 matches=[]' % (ip, stime-6000, 12000) + ) + # get last ban values for this ip separately for each jail: + for row in self.db.getBan(ip, jail1): + self.assertEqual(row, (1, stime, 6000)) + break + for row in self.db.getBan(ip, jail2): + self.assertEqual(row, (2, stime-6000, 12000)) + break + # get max values for this ip (over all jails): + for row in self.db.getBan(ip, overalljails=True): + self.assertEqual(row, (3, stime, 18000)) + break + # test restoring bans from database: + jail1.restoreCurrentBans() + ticket = jail1.getFailTicket() + self.assertTrue(ticket.restored) + self.assertEqual(str(ticket), + 'FailTicket: ip=%s time=%s bantime=%s bancount=1 #attempts=0 matches=[]' % (ip, stime, 6000) + ) + # jail2 does not restore any bans (because all ban tickets should be already expired: stime-6000): + jail2.restoreCurrentBans() + self.assertEqual(jail2.getFailTicket(), False) + + def testObserver(self): + if Fail2BanDb is None: # pragma: no cover + return + jail = self.jail + self.db.addJail(jail) + # we tests with initial ban time = 10 seconds: + jail.actions.setBanTime(10) + jail.setBanTimeExtra('increment', 'true') + # observer / database features: + obs = Observers.Main + obs.start() + obs.db_set(self.db) + # wait for start ready + obs.add('nop') + obs.wait_empty(5) + # purge database right now, but using timer, to test it also: + self.db._purgeAge = -240*60*60 + obs.add_named_timer('DB_PURGE', 0.001, 'db_purge') + # wait for timer ready + obs.wait_idle(0.025) + # wait for ready + obs.add('nop') + obs.wait_empty(5) + + stime = int(MyTime.time()) + # completelly empty ? + tickets = self.db.getBans() + self.assertEqual(tickets, []) + + # add failure: + ip = "127.0.0.2" + ticket = FailTicket(ip, stime-120, []) + failManager = FailManager() + failManager.setMaxRetry(3) + for i in xrange(3): + failManager.addFailure(ticket) + obs.add('failureFound', failManager, jail, ticket) + obs.wait_empty(5) + self.assertEqual(ticket.getBanCount(), 0) + # check still not ban : + self.assertTrue(not jail.getFailTicket()) + # add manually 4th times banned (added to bips - make ip bad): + ticket.setBanCount(4) + self.db.addBan(self.jail, ticket) + restored_tickets = self.db.getCurrentBans(jail=jail, fromtime=stime-120) + self.assertEqual(len(restored_tickets), 1) + # check again, new ticket, new failmanager: + ticket = FailTicket(ip, stime, []) + failManager = FailManager() + failManager.setMaxRetry(3) + # add once only - but bad - should be banned: + failManager.addFailure(ticket) + obs.add('failureFound', failManager, self.jail, ticket) + obs.wait_empty(5) + # wait until ticket transfered from failmanager into jail: + ticket2 = Utils.wait_for(jail.getFailTicket, 10) + # check ticket and failure count: + self.assertTrue(ticket2) + self.assertEqual(ticket2.getRetry(), failManager.getMaxRetry()) + + # wrap FailTicket to BanTicket: + failticket2 = ticket2 + ticket2 = BanTicket.wrap(failticket2) + self.assertEqual(ticket2, failticket2) + # add this ticket to ban (use observer only without ban manager): + obs.add('banFound', ticket2, jail, 10) + obs.wait_empty(5) + # increased? + self.assertEqual(ticket2.getBanTime(), 160) + self.assertEqual(ticket2.getBanCount(), 5) + + # check prolonged in database also : + restored_tickets = self.db.getCurrentBans(jail=jail, fromtime=stime) + self.assertEqual(len(restored_tickets), 1) + self.assertEqual(restored_tickets[0].getBanTime(), 160) + self.assertEqual(restored_tickets[0].getBanCount(), 5) + + # now using jail/actions: + ticket = FailTicket(ip, stime-60, ['test-expired-ban-time']) + jail.putFailTicket(ticket) + self.assertFalse(jail.actions.checkBan()) + + ticket = FailTicket(ip, MyTime.time(), ['test-actions']) + jail.putFailTicket(ticket) + self.assertTrue(jail.actions.checkBan()) + + obs.wait_empty(5) + restored_tickets = self.db.getCurrentBans(jail=jail, fromtime=stime) + self.assertEqual(len(restored_tickets), 1) + self.assertEqual(restored_tickets[0].getBanTime(), 320) + self.assertEqual(restored_tickets[0].getBanCount(), 6) + + # and permanent: + ticket = FailTicket(ip+'1', MyTime.time(), ['test-permanent']) + ticket.setBanTime(-1) + jail.putFailTicket(ticket) + self.assertTrue(jail.actions.checkBan()) + + obs.wait_empty(5) + ticket = FailTicket(ip+'1', MyTime.time(), ['test-permanent']) + ticket.setBanTime(600) + jail.putFailTicket(ticket) + self.assertFalse(jail.actions.checkBan()) + + obs.wait_empty(5) + restored_tickets = self.db.getCurrentBans(jail=jail, fromtime=stime) + self.assertEqual(len(restored_tickets), 2) + self.assertEqual(restored_tickets[1].getBanTime(), -1) + self.assertEqual(restored_tickets[1].getBanCount(), 1) + + # stop observer + obs.stop() + +class ObserverTest(LogCaptureTestCase): + + def setUp(self): + """Call before every test case.""" + super(ObserverTest, self).setUp() + + def tearDown(self): + """Call after every test case.""" + super(ObserverTest, self).tearDown() + + def testObserverBanTimeIncr(self): + obs = ObserverThread() + obs.start() + # wait for idle + obs.wait_idle(1) + # observer will replace test set: + o = set(['test']) + obs.add('call', o.clear) + obs.add('call', o.add, 'test2') + # wait for observer ready: + obs.wait_empty(1) + self.assertFalse(obs.is_full) + self.assertEqual(o, set(['test2'])) + # observer makes pause + obs.paused = True + # observer will replace test set, but first after pause ends: + obs.add('call', o.clear) + obs.add('call', o.add, 'test3') + obs.wait_empty(10 * Utils.DEFAULT_SLEEP_TIME) + self.assertTrue(obs.is_full) + self.assertEqual(o, set(['test2'])) + obs.paused = False + # wait running: + obs.wait_empty(1) + self.assertEqual(o, set(['test3'])) + + self.assertTrue(obs.isActive()) + self.assertTrue(obs.isAlive()) + obs.stop() + obs = None + + class _BadObserver(ObserverThread): + def run(self): + raise RuntimeError('run bad thread exception') + + def testObserverBadRun(self): + obs = ObserverTest._BadObserver() + # don't wait for empty by stop + obs.wait_empty = lambda v:() + # save previous hook, prevent write stderr and check hereafter __excepthook__ was executed + prev_exchook = sys.__excepthook__ + x = [] + sys.__excepthook__ = lambda *args: x.append(args) + try: + obs.start() + obs.stop() + obs.join() + self.assertTrue( Utils.wait_for( lambda: len(x) and self._is_logged("Unhandled exception"), 3) ) + finally: + sys.__excepthook__ = prev_exchook + self.assertLogged("Unhandled exception") + self.assertEqual(len(x), 1) + self.assertEqual(x[0][0], RuntimeError) + self.assertEqual(str(x[0][1]), 'run bad thread exception') diff --git a/fail2ban/tests/servertestcase.py b/fail2ban/tests/servertestcase.py index 68b9951c..26096797 100644 --- a/fail2ban/tests/servertestcase.py +++ b/fail2ban/tests/servertestcase.py @@ -921,6 +921,15 @@ class TransmitterLogging(TransmitterBase): self.assertEqual(self.transm.proceed(["set", "logtarget", "STDERR"]), (0, "STDERR")) self.assertEqual(self.transm.proceed(["flushlogs"]), (0, "flushed")) + def testBanTimeIncr(self): + self.setGetTest("bantime.increment", "true", True, jail=self.jailName) + self.setGetTest("bantime.rndtime", "30min", 30*60, jail=self.jailName) + self.setGetTest("bantime.maxtime", "1000 days", 1000*24*60*60, jail=self.jailName) + self.setGetTest("bantime.factor", "2", "2", jail=self.jailName) + self.setGetTest("bantime.formula", "ban.Time * math.exp(float(ban.Count+1)*banFactor)/math.exp(1*banFactor)", jail=self.jailName) + self.setGetTest("bantime.multipliers", "1 5 30 60 300 720 1440 2880", "1 5 30 60 300 720 1440 2880", jail=self.jailName) + self.setGetTest("bantime.overalljails", "true", "true", jail=self.jailName) + class JailTests(unittest.TestCase): @@ -1060,8 +1069,20 @@ class ServerConfigReaderTests(LogCaptureTestCase): logSys.debug(l) return True + def _testActionInfos(self): + if not hasattr(self, '__aInfos'): + dmyjail = DummyJail() + self.__aInfos = {} + for t, ip in (('ipv4', '192.0.2.1'), ('ipv6', '2001:DB8::')): + ticket = BanTicket(ip) + ticket.setBanTime(600) + self.__aInfos[t] = _actions.Actions.ActionInfo(ticket, dmyjail) + return self.__aInfos + def _testExecActions(self, server): jails = server._Server__jails + + aInfos = self._testActionInfos() for jail in jails: # print(jail, jails[jail]) for a in jails[jail].actions: @@ -1078,16 +1099,16 @@ class ServerConfigReaderTests(LogCaptureTestCase): action.start() # test ban ip4 : logSys.debug('# === ban-ipv4 ==='); self.pruneLog() - action.ban({'ip': IPAddr('192.0.2.1'), 'family': 'inet4'}) + action.ban(aInfos['ipv4']) # test unban ip4 : logSys.debug('# === unban ipv4 ==='); self.pruneLog() - action.unban({'ip': IPAddr('192.0.2.1'), 'family': 'inet4'}) + action.unban(aInfos['ipv4']) # test ban ip6 : logSys.debug('# === ban ipv6 ==='); self.pruneLog() - action.ban({'ip': IPAddr('2001:DB8::'), 'family': 'inet6'}) + action.ban(aInfos['ipv6']) # test unban ip6 : logSys.debug('# === unban ipv6 ==='); self.pruneLog() - action.unban({'ip': IPAddr('2001:DB8::'), 'family': 'inet6'}) + action.unban(aInfos['ipv6']) # test stop : logSys.debug('# === stop ==='); self.pruneLog() action.stop() @@ -1305,11 +1326,11 @@ class ServerConfigReaderTests(LogCaptureTestCase): ('j-w-iptables-ipset', 'iptables-ipset-proto6[name=%(__name__)s, bantime="10m", port="http", protocol="tcp", chain="INPUT"]', { 'ip4': (' f2b-j-w-iptables-ipset ',), 'ip6': (' f2b-j-w-iptables-ipset6 ',), 'ip4-start': ( - "`ipset create f2b-j-w-iptables-ipset hash:ip timeout 600`", + "`ipset create f2b-j-w-iptables-ipset hash:ip`", "`iptables -w -I INPUT -p tcp -m multiport --dports http -m set --match-set f2b-j-w-iptables-ipset src -j REJECT --reject-with icmp-port-unreachable`", ), 'ip6-start': ( - "`ipset create f2b-j-w-iptables-ipset6 hash:ip timeout 600 family inet6`", + "`ipset create f2b-j-w-iptables-ipset6 hash:ip family inet6`", "`ip6tables -w -I INPUT -p tcp -m multiport --dports http -m set --match-set f2b-j-w-iptables-ipset6 src -j REJECT --reject-with icmp6-port-unreachable`", ), 'flush': ( @@ -1343,11 +1364,11 @@ class ServerConfigReaderTests(LogCaptureTestCase): ('j-w-iptables-ipset-ap', 'iptables-ipset-proto6-allports[name=%(__name__)s, bantime="10m", chain="INPUT"]', { 'ip4': (' f2b-j-w-iptables-ipset-ap ',), 'ip6': (' f2b-j-w-iptables-ipset-ap6 ',), 'ip4-start': ( - "`ipset create f2b-j-w-iptables-ipset-ap hash:ip timeout 600`", + "`ipset create f2b-j-w-iptables-ipset-ap hash:ip`", "`iptables -w -I INPUT -m set --match-set f2b-j-w-iptables-ipset-ap src -j REJECT --reject-with icmp-port-unreachable`", ), 'ip6-start': ( - "`ipset create f2b-j-w-iptables-ipset-ap6 hash:ip timeout 600 family inet6`", + "`ipset create f2b-j-w-iptables-ipset-ap6 hash:ip family inet6`", "`ip6tables -w -I INPUT -m set --match-set f2b-j-w-iptables-ipset-ap6 src -j REJECT --reject-with icmp6-port-unreachable`", ), 'flush': ( @@ -1641,11 +1662,11 @@ class ServerConfigReaderTests(LogCaptureTestCase): ('j-w-fwcmd-ipset', 'firewallcmd-ipset[name=%(__name__)s, bantime="10m", port="http", protocol="tcp", chain="INPUT"]', { 'ip4': (' f2b-j-w-fwcmd-ipset ',), 'ip6': (' f2b-j-w-fwcmd-ipset6 ',), 'ip4-start': ( - "`ipset create f2b-j-w-fwcmd-ipset hash:ip timeout 600`", + "`ipset create f2b-j-w-fwcmd-ipset hash:ip`", "`firewall-cmd --direct --add-rule ipv4 filter INPUT 0 -p tcp -m multiport --dports http -m set --match-set f2b-j-w-fwcmd-ipset src -j REJECT --reject-with icmp-port-unreachable`", ), 'ip6-start': ( - "`ipset create f2b-j-w-fwcmd-ipset6 hash:ip timeout 600`", + "`ipset create f2b-j-w-fwcmd-ipset6 hash:ip`", "`firewall-cmd --direct --add-rule ipv6 filter INPUT 0 -p tcp -m multiport --dports http -m set --match-set f2b-j-w-fwcmd-ipset6 src -j REJECT --reject-with icmp6-port-unreachable`", ), 'stop': ( @@ -1690,10 +1711,7 @@ class ServerConfigReaderTests(LogCaptureTestCase): jails = server._Server__jails - tickets = { - 'ip4': BanTicket('192.0.2.1'), - 'ip6': BanTicket('2001:DB8::'), - } + aInfos = self._testActionInfos() for jail, act, tests in testJailsActions: # print(jail, jails[jail]) for a in jails[jail].actions: @@ -1711,32 +1729,28 @@ class ServerConfigReaderTests(LogCaptureTestCase): self.assertLogged(*tests['start'], all=True) else: self.assertNotLogged(*tests['ip4-start']+tests['ip6-start'], all=True) - ainfo = { - 'ip4': _actions.Actions.ActionInfo(tickets['ip4'], jails[jail]), - 'ip6': _actions.Actions.ActionInfo(tickets['ip6'], jails[jail]), - } # test ban ip4 : self.pruneLog('# === ban-ipv4 ===') - action.ban(ainfo['ip4']) + action.ban(aInfos['ipv4']) if tests.get('ip4-start'): self.assertLogged(*tests['ip4-start'], all=True) if tests.get('ip6-start'): self.assertNotLogged(*tests['ip6-start'], all=True) self.assertLogged(*tests['ip4-check']+tests['ip4-ban'], all=True) self.assertNotLogged(*tests['ip6'], all=True) # test unban ip4 : self.pruneLog('# === unban ipv4 ===') - action.unban(ainfo['ip4']) + action.unban(aInfos['ipv4']) self.assertLogged(*tests['ip4-check']+tests['ip4-unban'], all=True) self.assertNotLogged(*tests['ip6'], all=True) # test ban ip6 : self.pruneLog('# === ban ipv6 ===') - action.ban(ainfo['ip6']) + action.ban(aInfos['ipv6']) if tests.get('ip6-start'): self.assertLogged(*tests['ip6-start'], all=True) if tests.get('ip4-start'): self.assertNotLogged(*tests['ip4-start'], all=True) self.assertLogged(*tests['ip6-check']+tests['ip6-ban'], all=True) self.assertNotLogged(*tests['ip4'], all=True) # test unban ip6 : self.pruneLog('# === unban ipv6 ===') - action.unban(ainfo['ip6']) + action.unban(aInfos['ipv6']) self.assertLogged(*tests['ip6-check']+tests['ip6-unban'], all=True) self.assertNotLogged(*tests['ip4'], all=True) # test flush for actions should supported this: diff --git a/fail2ban/tests/sockettestcase.py b/fail2ban/tests/sockettestcase.py index 1a94a952..9d9cba15 100644 --- a/fail2ban/tests/sockettestcase.py +++ b/fail2ban/tests/sockettestcase.py @@ -32,42 +32,70 @@ import time import unittest from .. import protocol -from ..server.asyncserver import AsyncServer, AsyncServerException +from ..server.asyncserver import asyncore, RequestHandler, AsyncServer, AsyncServerException from ..server.utils import Utils from ..client.csocket import CSocket +from .utils import LogCaptureTestCase -class Socket(unittest.TestCase): + +def TestMsgError(*args): + raise Exception('test unpickle error') +class TestMsg(object): + def __init__(self, unpickle=(TestMsgError, ())): + self.unpickle = unpickle + def __reduce__(self): + return self.unpickle + + +class Socket(LogCaptureTestCase): def setUp(self): """Call before every test case.""" + LogCaptureTestCase.setUp(self) super(Socket, self).setUp() self.server = AsyncServer(self) sock_fd, sock_name = tempfile.mkstemp('fail2ban.sock', 'socket') os.close(sock_fd) os.remove(sock_name) self.sock_name = sock_name + self.serverThread = None def tearDown(self): """Call after every test case.""" + if self.serverThread: + self.server.stop(); # stop if not already stopped + self._stopServerThread() + LogCaptureTestCase.tearDown(self) @staticmethod def proceed(message): """Test transmitter proceed method which just returns first arg""" return message - def testStopPerCloseUnexpected(self): + def _createServerThread(self, force=False): # start in separate thread : - serverThread = threading.Thread( - target=self.server.start, args=(self.sock_name, False)) + self.serverThread = serverThread = threading.Thread( + target=self.server.start, args=(self.sock_name, force)) serverThread.daemon = True serverThread.start() self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) + return serverThread + + def _stopServerThread(self): + serverThread = self.serverThread + # wait for end of thread : + Utils.wait_for(lambda: not serverThread.isAlive() + or serverThread.join(Utils.DEFAULT_SLEEP_TIME), unittest.F2B.maxWaitTime(10)) + self.serverThread = None + + def testStopPerCloseUnexpected(self): + # start in separate thread : + serverThread = self._createServerThread() # unexpected stop directly after start: self.server.close() # wait for end of thread : - Utils.wait_for(lambda: not serverThread.isAlive() - or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self._stopServerThread() self.assertFalse(serverThread.isAlive()) # clean : self.server.stop() @@ -81,30 +109,99 @@ class Socket(unittest.TestCase): return None def testSocket(self): - serverThread = threading.Thread( - target=self.server.start, args=(self.sock_name, False)) - serverThread.daemon = True - serverThread.start() - self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) - time.sleep(Utils.DEFAULT_SLEEP_TIME) - + # start in separate thread : + serverThread = self._createServerThread() client = Utils.wait_for(self._serverSocket, 2) + testMessage = ["A", "test", "message"] self.assertEqual(client.send(testMessage), testMessage) + # test wrong message: + self.assertEqual(client.send([[TestMsg()]]), 'ERROR: test unpickle error') + self.assertLogged("Caught unhandled exception", "test unpickle error", all=True) + + # test good message again: + self.assertEqual(client.send(testMessage), testMessage) + # test close message client.close() # 2nd close does nothing client.close() + # force shutdown: + self.server.stop_communication() + # test send again (should get in shutdown message): + client = Utils.wait_for(self._serverSocket, 2) + self.assertEqual(client.send(testMessage), ['SHUTDOWN']) + self.server.stop() # wait for end of thread : - Utils.wait_for(lambda: not serverThread.isAlive() - or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self._stopServerThread() self.assertFalse(serverThread.isAlive()) self.assertFalse(self.server.isActive()) self.assertFalse(os.path.exists(self.sock_name)) + def testSocketConnectBroken(self): + # start in separate thread : + serverThread = self._createServerThread() + client = Utils.wait_for(self._serverSocket, 2) + # unexpected stop during message body: + testMessage = ["A", "test", "message", [protocol.CSPROTO.END]] + + org_handler = RequestHandler.found_terminator + try: + RequestHandler.found_terminator = lambda self: self.close() + self.assertRaisesRegexp(RuntimeError, r"socket connection broken", + lambda: client.send(testMessage, timeout=unittest.F2B.maxWaitTime(10))) + finally: + RequestHandler.found_terminator = org_handler + + def testStopByCommunicate(self): + # start in separate thread : + serverThread = self._createServerThread() + client = Utils.wait_for(self._serverSocket, 2) + + testMessage = ["A", "test", "message"] + self.assertEqual(client.send(testMessage), testMessage) + + org_handler = RequestHandler.found_terminator + try: + RequestHandler.found_terminator = lambda self: TestMsgError() + #self.assertRaisesRegexp(RuntimeError, r"socket connection broken", client.send, testMessage) + self.assertEqual(client.send(testMessage), 'ERROR: test unpickle error') + finally: + RequestHandler.found_terminator = org_handler + + # check errors were logged: + self.assertLogged("Unexpected communication error", "test unpickle error", all=True) + + self.server.stop() + # wait for end of thread : + self._stopServerThread() + self.assertFalse(serverThread.isAlive()) + + def testLoopErrors(self): + # replace poll handler to produce error in loop-cycle: + org_poll = asyncore.poll + err = {'cntr': 0} + def _produce_error(*args): + err['cntr'] += 1 + if err['cntr'] < 50: + raise RuntimeError('test errors in poll') + return org_poll(*args) + + try: + asyncore.poll = _produce_error + serverThread = self._createServerThread() + # wait all-cases processed: + self.assertTrue(Utils.wait_for(lambda: err['cntr'] > 50, unittest.F2B.maxWaitTime(10))) + finally: + # restore: + asyncore.poll = org_poll + # check errors were logged: + self.assertLogged("Server connection was closed: test errors in poll", + "Too many errors - stop logging connection errors", all=True) + def testSocketForce(self): open(self.sock_name, 'w').close() # Create sock file # Try to start without force @@ -112,16 +209,12 @@ class Socket(unittest.TestCase): AsyncServerException, self.server.start, self.sock_name, False) # Try again with force set - serverThread = threading.Thread( - target=self.server.start, args=(self.sock_name, True)) - serverThread.daemon = True - serverThread.start() - self.assertTrue(Utils.wait_for(self.server.isActive, unittest.F2B.maxWaitTime(10))) + serverThread = self._createServerThread(True) self.server.stop() # wait for end of thread : - Utils.wait_for(lambda: not serverThread.isAlive() - or serverThread.join(Utils.DEFAULT_SLEEP_INTERVAL), unittest.F2B.maxWaitTime(10)) + self._stopServerThread() + self.assertFalse(serverThread.isAlive()) self.assertFalse(self.server.isActive()) self.assertFalse(os.path.exists(self.sock_name)) diff --git a/fail2ban/tests/utils.py b/fail2ban/tests/utils.py index 9f92c26b..7a5f0b14 100644 --- a/fail2ban/tests/utils.py +++ b/fail2ban/tests/utils.py @@ -328,6 +328,7 @@ def gatherTests(regexps=None, opts=None): from . import sockettestcase from . import misctestcase from . import databasetestcase + from . import observertestcase from . import samplestestcase from . import fail2banclienttestcase from . import fail2banregextestcase @@ -358,7 +359,6 @@ def gatherTests(regexps=None, opts=None): tests = FilteredTestSuite() # Server - #tests.addTest(unittest.makeSuite(servertestcase.StartStop)) tests.addTest(unittest.makeSuite(servertestcase.Transmitter)) tests.addTest(unittest.makeSuite(servertestcase.JailTests)) tests.addTest(unittest.makeSuite(servertestcase.RegexTests)) @@ -398,6 +398,10 @@ def gatherTests(regexps=None, opts=None): tests.addTest(unittest.makeSuite(misctestcase.MyTimeTest)) # Database tests.addTest(unittest.makeSuite(databasetestcase.DatabaseTest)) + # Observer + tests.addTest(unittest.makeSuite(observertestcase.ObserverTest)) + tests.addTest(unittest.makeSuite(observertestcase.BanTimeIncr)) + tests.addTest(unittest.makeSuite(observertestcase.BanTimeIncrDB)) # Filter tests.addTest(unittest.makeSuite(filtertestcase.IgnoreIP))