diff --git a/fail2ban/client/fail2banclient.py b/fail2ban/client/fail2banclient.py index c72208cd3..f3b0f7b27 100755 --- a/fail2ban/client/fail2banclient.py +++ b/fail2ban/client/fail2banclient.py @@ -175,9 +175,13 @@ class Fail2banClient(Fail2banCmdLine, Thread): return [["server-stream", stream], ['server-status']] + def _set_server(self, s): + self._server = s + ## def __startServer(self, background=True): from .fail2banserver import Fail2banServer + # read configuration here (in client only, in server we do that in the config-thread): stream = self.__prepareStartServer() self._alive = True if not stream: @@ -192,16 +196,19 @@ class Fail2banClient(Fail2banCmdLine, Thread): return False else: # In foreground mode we should make server/client communication in different threads: - th = Thread(target=Fail2banClient.__processStartStreamAfterWait, args=(self, stream, False)) - th.daemon = True - th.start() + phase = dict() + self.configureServer(phase=phase, stream=stream) # Mark current (main) thread as daemon: self.daemon = True # Start server direct here in main thread (not fork): - self._server = Fail2banServer.startServerDirect(self._conf, False) - + self._server = Fail2banServer.startServerDirect(self._conf, False, self._set_server) + if not phase.get('done', False): + if self._server: # pragma: no cover + self._server.quit() + self._server = None + exit(255) except ExitException: # pragma: no cover - pass + raise except Exception as e: # pragma: no cover output("") logSys.error("Exception while starting server " + ("background" if background else "foreground")) @@ -214,23 +221,39 @@ class Fail2banClient(Fail2banCmdLine, Thread): return True ## - def configureServer(self, nonsync=True, phase=None): + def configureServer(self, nonsync=True, phase=None, stream=None): # if asynchronous start this operation in the new thread: if nonsync: - th = Thread(target=Fail2banClient.configureServer, args=(self, False, phase)) + if phase is not None: + # event for server ready flag: + def _server_ready(): + phase['start-ready'] = True + logSys.log(5, ' server phase %s', phase) + # notify waiting thread if server really ready + self._conf['onstart'] = _server_ready + th = Thread(target=Fail2banClient.configureServer, args=(self, False, phase, stream)) th.daemon = True - return th.start() + th.start() + # if we need to read configuration stream: + if stream is None and phase is not None: + # wait, do not continue if configuration is not 100% valid: + Utils.wait_for(lambda: phase.get('ready', None) is not None, self._conf["timeout"], 0.001) + logSys.log(5, ' server phase %s', phase) + if not phase.get('start', False): + raise ServerExecutionException('Async configuration of server failed') + return True # prepare: read config, check configuration is valid, etc.: if phase is not None: phase['start'] = True logSys.log(5, ' client phase %s', phase) - stream = self.__prepareStartServer() + if stream is None: + stream = self.__prepareStartServer() if phase is not None: phase['ready'] = phase['start'] = (True if stream else False) logSys.log(5, ' client phase %s', phase) if not stream: return False - # wait a litle bit for phase "start-ready" before enter active waiting: + # wait a little bit for phase "start-ready" before enter active waiting: if phase is not None: Utils.wait_for(lambda: phase.get('start-ready', None) is not None, 0.5, 0.001) phase['configure'] = (True if stream else False) @@ -321,13 +344,14 @@ class Fail2banClient(Fail2banCmdLine, Thread): def __processStartStreamAfterWait(self, *args): + ret = False try: # Wait for the server to start if not self.__waitOnServer(): # pragma: no cover logSys.error("Could not find server, waiting failed") return False # Configure the server - self.__processCmd(*args) + ret = self.__processCmd(*args) except ServerExecutionException as e: # pragma: no cover if self._conf["verbose"] > 1: logSys.exception(e) @@ -336,10 +360,11 @@ class Fail2banClient(Fail2banCmdLine, Thread): "remove " + self._conf["socket"] + ". If " "you used fail2ban-client to start the " "server, adding the -x option will do it") - if self._server: - self._server.quit() - return False - return True + + if not ret and self._server: # stop on error (foreground, config read in another thread): + self._server.quit() + self._server = None + return ret def __waitOnServer(self, alive=True, maxtime=None): if maxtime is None: diff --git a/fail2ban/client/fail2banserver.py b/fail2ban/client/fail2banserver.py index d94d13ff7..eee78d5f6 100644 --- a/fail2ban/client/fail2banserver.py +++ b/fail2ban/client/fail2banserver.py @@ -44,7 +44,7 @@ class Fail2banServer(Fail2banCmdLine): # Start the Fail2ban server in background/foreground (daemon mode or not). @staticmethod - def startServerDirect(conf, daemon=True): + def startServerDirect(conf, daemon=True, setServer=None): logSys.debug(" direct starting of server in %s, deamon: %s", os.getpid(), daemon) from ..server.server import Server server = None @@ -52,6 +52,10 @@ class Fail2banServer(Fail2banCmdLine): # Start it in foreground (current thread, not new process), # server object will internally fork self if daemon is True server = Server(daemon) + # notify caller - set server handle: + if setServer: + setServer(server) + # run: server.start(conf["socket"], conf["pidfile"], conf["force"], conf=conf) @@ -63,6 +67,10 @@ class Fail2banServer(Fail2banCmdLine): if conf["verbose"] > 1: logSys.exception(e2) raise + finally: + # notify waiting thread server ready resp. done (background execution, error case, etc): + if conf.get('onstart'): + conf['onstart']() return server @@ -179,27 +187,15 @@ class Fail2banServer(Fail2banCmdLine): # Start new thread with client to read configuration and # transfer it to the server: cli = self._Fail2banClient() + cli._conf = self._conf phase = dict() logSys.debug('Configure via async client thread') cli.configureServer(phase=phase) - # wait, do not continue if configuration is not 100% valid: - Utils.wait_for(lambda: phase.get('ready', None) is not None, self._conf["timeout"], 0.001) - logSys.log(5, ' server phase %s', phase) - if not phase.get('start', False): - raise ServerExecutionException('Async configuration of server failed') - # event for server ready flag: - def _server_ready(): - phase['start-ready'] = True - logSys.log(5, ' server phase %s', phase) - # notify waiting thread if server really ready - self._conf['onstart'] = _server_ready # Start server, daemonize it, etc. pid = os.getpid() - server = Fail2banServer.startServerDirect(self._conf, background) - # notify waiting thread server ready resp. done (background execution, error case, etc): - if not nonsync: - _server_ready() + server = Fail2banServer.startServerDirect(self._conf, background, + cli._set_server if cli else None) # If forked - just exit other processes if pid != os.getpid(): # pragma: no cover os._exit(0) diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index ed736a7a1..59eeb8fd1 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -104,7 +104,11 @@ def commitandrollback(f): def wrapper(self, *args, **kwargs): with self._lock: # Threading lock with self._db: # Auto commit and rollback on exception - return f(self, self._db.cursor(), *args, **kwargs) + cur = self._db.cursor() + try: + return f(self, cur, *args, **kwargs) + finally: + cur.close() return wrapper @@ -253,7 +257,7 @@ class Fail2BanDb(object): self.repairDB() else: version = cur.fetchone()[0] - if version < Fail2BanDb.__version__: + if version != Fail2BanDb.__version__: newversion = self.updateDb(version) if newversion == Fail2BanDb.__version__: logSys.warning( "Database updated from '%r' to '%r'", @@ -301,9 +305,11 @@ class Fail2BanDb(object): try: # backup logSys.info("Trying to repair database %s", self._dbFilename) - shutil.move(self._dbFilename, self._dbBackupFilename) - logSys.info(" Database backup created: %s", self._dbBackupFilename) - + if not os.path.isfile(self._dbBackupFilename): + shutil.move(self._dbFilename, self._dbBackupFilename) + logSys.info(" Database backup created: %s", self._dbBackupFilename) + elif os.path.isfile(self._dbFilename): + os.remove(self._dbFilename) # first try to repair using dump/restore in order Utils.executeCmd((r"""f2b_db=$0; f2b_dbbk=$1; sqlite3 "$f2b_dbbk" ".dump" | sqlite3 "$f2b_db" """, self._dbFilename, self._dbBackupFilename)) @@ -415,7 +421,7 @@ class Fail2BanDb(object): logSys.error("Failed to upgrade database '%s': %s", self._dbFilename, e.args[0], exc_info=logSys.getEffectiveLevel() <= 10) - raise + self.repairDB() @commitandrollback def addJail(self, cur, jail): @@ -789,7 +795,6 @@ class Fail2BanDb(object): queryArgs.append(fromtime) if overalljails or jail is None: query += " GROUP BY ip ORDER BY timeofban DESC LIMIT 1" - cur = self._db.cursor() # repack iterator as long as in lock: return list(cur.execute(query, queryArgs)) @@ -812,11 +817,9 @@ class Fail2BanDb(object): query += " GROUP BY ip ORDER BY ip, timeofban DESC" else: query += " ORDER BY timeofban DESC LIMIT 1" - cur = self._db.cursor() return cur.execute(query, queryArgs) - @commitandrollback - def getCurrentBans(self, cur, jail=None, ip=None, forbantime=None, fromtime=None, + def getCurrentBans(self, jail=None, ip=None, forbantime=None, fromtime=None, correctBanTime=True, maxmatches=None ): """Reads tickets (with merged info) currently affected from ban from the database. @@ -828,57 +831,63 @@ class Fail2BanDb(object): (and therefore endOfBan) of the ticket (normally it is ban-time of jail as maximum) for all tickets with ban-time greater (or persistent). """ - if fromtime is None: - fromtime = MyTime.time() - tickets = [] - ticket = None - if correctBanTime is True: - correctBanTime = jail.getMaxBanTime() if jail is not None else None - # don't change if persistent allowed: - if correctBanTime == -1: correctBanTime = None - - for ticket in self._getCurrentBans(cur, jail=jail, ip=ip, - forbantime=forbantime, fromtime=fromtime - ): - # can produce unpack error (database may return sporadical wrong-empty row): - try: - banip, timeofban, bantime, bancount, data = ticket - # additionally check for empty values: - if banip is None or banip == "": # pragma: no cover - raise ValueError('unexpected value %r' % (banip,)) - # if bantime unknown (after upgrade-db from earlier version), just use min known ban-time: - if bantime == -2: # todo: remove it in future version - bantime = jail.actions.getBanTime() if jail is not None else ( - correctBanTime if correctBanTime else 600) - elif correctBanTime and correctBanTime >= 0: - # if persistent ban (or greater as max), use current max-bantime of the jail: - if bantime == -1 or bantime > correctBanTime: - bantime = correctBanTime - # after correction check the end of ban again: - if bantime != -1 and timeofban + bantime <= fromtime: - # not persistent and too old - ignore it: - logSys.debug("ignore ticket (with new max ban-time %r): too old %r <= %r, ticket: %r", - bantime, timeofban + bantime, fromtime, ticket) + cur = self._db.cursor() + try: + if fromtime is None: + fromtime = MyTime.time() + tickets = [] + ticket = None + if correctBanTime is True: + correctBanTime = jail.getMaxBanTime() if jail is not None else None + # don't change if persistent allowed: + if correctBanTime == -1: correctBanTime = None + + with self._lock: + bans = self._getCurrentBans(cur, jail=jail, ip=ip, + forbantime=forbantime, fromtime=fromtime + ) + for ticket in bans: + # can produce unpack error (database may return sporadical wrong-empty row): + try: + banip, timeofban, bantime, bancount, data = ticket + # additionally check for empty values: + if banip is None or banip == "": # pragma: no cover + raise ValueError('unexpected value %r' % (banip,)) + # if bantime unknown (after upgrade-db from earlier version), just use min known ban-time: + if bantime == -2: # todo: remove it in future version + bantime = jail.actions.getBanTime() if jail is not None else ( + correctBanTime if correctBanTime else 600) + elif correctBanTime and correctBanTime >= 0: + # if persistent ban (or greater as max), use current max-bantime of the jail: + if bantime == -1 or bantime > correctBanTime: + bantime = correctBanTime + # after correction check the end of ban again: + if bantime != -1 and timeofban + bantime <= fromtime: + # not persistent and too old - ignore it: + logSys.debug("ignore ticket (with new max ban-time %r): too old %r <= %r, ticket: %r", + bantime, timeofban + bantime, fromtime, ticket) + continue + except ValueError as e: # pragma: no cover + logSys.debug("get current bans: ignore row %r - %s", ticket, e) continue - except ValueError as e: # pragma: no cover - logSys.debug("get current bans: ignore row %r - %s", ticket, e) - continue - # logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data) - ticket = FailTicket(banip, timeofban, data=data) - # filter matches if expected (current count > as maxmatches specified): - if maxmatches is None: - maxmatches = self.maxMatches - if maxmatches: - matches = ticket.getMatches() - if matches and len(matches) > maxmatches: - ticket.setMatches(matches[-maxmatches:]) - else: - ticket.setMatches(None) - # logSys.debug('restored ticket: %r', ticket) - ticket.setBanTime(bantime) - ticket.setBanCount(bancount) - if ip is not None: return ticket - tickets.append(ticket) + # logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data) + ticket = FailTicket(banip, timeofban, data=data) + # filter matches if expected (current count > as maxmatches specified): + if maxmatches is None: + maxmatches = self.maxMatches + if maxmatches: + matches = ticket.getMatches() + if matches and len(matches) > maxmatches: + ticket.setMatches(matches[-maxmatches:]) + else: + ticket.setMatches(None) + # logSys.debug('restored ticket: %r', ticket) + ticket.setBanTime(bantime) + ticket.setBanCount(bancount) + if ip is not None: return ticket + tickets.append(ticket) + finally: + cur.close() return tickets diff --git a/fail2ban/server/transmitter.py b/fail2ban/server/transmitter.py index 8e17d8629..6de60f94e 100644 --- a/fail2ban/server/transmitter.py +++ b/fail2ban/server/transmitter.py @@ -58,7 +58,7 @@ class Transmitter: ret = self.__commandHandler(command) ack = 0, ret except Exception as e: - logSys.warning("Command %r has failed. Received %r", + logSys.error("Command %r has failed. Received %r", command, e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) ack = 1, e diff --git a/fail2ban/tests/fail2banclienttestcase.py b/fail2ban/tests/fail2banclienttestcase.py index 86480f02e..d72130106 100644 --- a/fail2ban/tests/fail2banclienttestcase.py +++ b/fail2ban/tests/fail2banclienttestcase.py @@ -491,6 +491,39 @@ class Fail2banClientServerBase(LogCaptureTestCase): self.execCmd(FAILED, startparams, "~~unknown~cmd~failed~~") self.execCmd(SUCCESS, startparams, "echo", "TEST-ECHO") + @with_tmpdir + @with_kill_srv + def testStartFailsInForeground(self, tmp): + if not server.Fail2BanDb: # pragma: no cover + raise unittest.SkipTest('Skip test because no database') + dbname = pjoin(tmp,"tmp.db") + db = server.Fail2BanDb(dbname) + # set inappropriate DB version to simulate an irreparable error by start: + cur = db._db.cursor() + cur.executescript("UPDATE fail2banDb SET version = 555") + cur.close() + # timeout (thread will stop foreground server): + startparams = _start_params(tmp, db=dbname, logtarget='INHERITED') + phase = {'stop': True} + def _stopTimeout(startparams, phase): + if not Utils.wait_for(lambda: not phase['stop'], MAX_WAITTIME): + # print('==== STOP ====') + self.execCmdDirect(startparams, 'stop') + th = Thread( + name="_TestCaseWorker", + target=_stopTimeout, + args=(startparams, phase) + ) + th.start() + # test: + try: + self.execCmd(FAILED, ("-f",) + startparams, "start") + finally: + phase['stop'] = False + th.join() + self.assertLogged("Attempt to travel to future version of database", + "Exit with code 255", all=True) + class Fail2banClientTest(Fail2banClientServerBase):