Merge branch '0.10' into 0.11

(conflicts resolved)
pull/2023/merge
sebres 2022-09-16 19:11:53 +02:00
commit 94dac78afe
5 changed files with 155 additions and 92 deletions

View File

@ -175,9 +175,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
return [["server-stream", stream], ['server-status']] return [["server-stream", stream], ['server-status']]
def _set_server(self, s):
self._server = s
## ##
def __startServer(self, background=True): def __startServer(self, background=True):
from .fail2banserver import Fail2banServer from .fail2banserver import Fail2banServer
# read configuration here (in client only, in server we do that in the config-thread):
stream = self.__prepareStartServer() stream = self.__prepareStartServer()
self._alive = True self._alive = True
if not stream: if not stream:
@ -192,16 +196,19 @@ class Fail2banClient(Fail2banCmdLine, Thread):
return False return False
else: else:
# In foreground mode we should make server/client communication in different threads: # In foreground mode we should make server/client communication in different threads:
th = Thread(target=Fail2banClient.__processStartStreamAfterWait, args=(self, stream, False)) phase = dict()
th.daemon = True self.configureServer(phase=phase, stream=stream)
th.start()
# Mark current (main) thread as daemon: # Mark current (main) thread as daemon:
self.daemon = True self.daemon = True
# Start server direct here in main thread (not fork): # Start server direct here in main thread (not fork):
self._server = Fail2banServer.startServerDirect(self._conf, False) self._server = Fail2banServer.startServerDirect(self._conf, False, self._set_server)
if not phase.get('done', False):
if self._server: # pragma: no cover
self._server.quit()
self._server = None
exit(255)
except ExitException: # pragma: no cover except ExitException: # pragma: no cover
pass raise
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
output("") output("")
logSys.error("Exception while starting server " + ("background" if background else "foreground")) logSys.error("Exception while starting server " + ("background" if background else "foreground"))
@ -214,23 +221,39 @@ class Fail2banClient(Fail2banCmdLine, Thread):
return True return True
## ##
def configureServer(self, nonsync=True, phase=None): def configureServer(self, nonsync=True, phase=None, stream=None):
# if asynchronous start this operation in the new thread: # if asynchronous start this operation in the new thread:
if nonsync: if nonsync:
th = Thread(target=Fail2banClient.configureServer, args=(self, False, phase)) if phase is not None:
# event for server ready flag:
def _server_ready():
phase['start-ready'] = True
logSys.log(5, ' server phase %s', phase)
# notify waiting thread if server really ready
self._conf['onstart'] = _server_ready
th = Thread(target=Fail2banClient.configureServer, args=(self, False, phase, stream))
th.daemon = True th.daemon = True
return th.start() th.start()
# if we need to read configuration stream:
if stream is None and phase is not None:
# wait, do not continue if configuration is not 100% valid:
Utils.wait_for(lambda: phase.get('ready', None) is not None, self._conf["timeout"], 0.001)
logSys.log(5, ' server phase %s', phase)
if not phase.get('start', False):
raise ServerExecutionException('Async configuration of server failed')
return True
# prepare: read config, check configuration is valid, etc.: # prepare: read config, check configuration is valid, etc.:
if phase is not None: if phase is not None:
phase['start'] = True phase['start'] = True
logSys.log(5, ' client phase %s', phase) logSys.log(5, ' client phase %s', phase)
if stream is None:
stream = self.__prepareStartServer() stream = self.__prepareStartServer()
if phase is not None: if phase is not None:
phase['ready'] = phase['start'] = (True if stream else False) phase['ready'] = phase['start'] = (True if stream else False)
logSys.log(5, ' client phase %s', phase) logSys.log(5, ' client phase %s', phase)
if not stream: if not stream:
return False return False
# wait a litle bit for phase "start-ready" before enter active waiting: # wait a little bit for phase "start-ready" before enter active waiting:
if phase is not None: if phase is not None:
Utils.wait_for(lambda: phase.get('start-ready', None) is not None, 0.5, 0.001) Utils.wait_for(lambda: phase.get('start-ready', None) is not None, 0.5, 0.001)
phase['configure'] = (True if stream else False) phase['configure'] = (True if stream else False)
@ -321,13 +344,14 @@ class Fail2banClient(Fail2banCmdLine, Thread):
def __processStartStreamAfterWait(self, *args): def __processStartStreamAfterWait(self, *args):
ret = False
try: try:
# Wait for the server to start # Wait for the server to start
if not self.__waitOnServer(): # pragma: no cover if not self.__waitOnServer(): # pragma: no cover
logSys.error("Could not find server, waiting failed") logSys.error("Could not find server, waiting failed")
return False return False
# Configure the server # Configure the server
self.__processCmd(*args) ret = self.__processCmd(*args)
except ServerExecutionException as e: # pragma: no cover except ServerExecutionException as e: # pragma: no cover
if self._conf["verbose"] > 1: if self._conf["verbose"] > 1:
logSys.exception(e) logSys.exception(e)
@ -336,10 +360,11 @@ class Fail2banClient(Fail2banCmdLine, Thread):
"remove " + self._conf["socket"] + ". If " "remove " + self._conf["socket"] + ". If "
"you used fail2ban-client to start the " "you used fail2ban-client to start the "
"server, adding the -x option will do it") "server, adding the -x option will do it")
if self._server:
if not ret and self._server: # stop on error (foreground, config read in another thread):
self._server.quit() self._server.quit()
return False self._server = None
return True return ret
def __waitOnServer(self, alive=True, maxtime=None): def __waitOnServer(self, alive=True, maxtime=None):
if maxtime is None: if maxtime is None:

View File

@ -44,7 +44,7 @@ class Fail2banServer(Fail2banCmdLine):
# Start the Fail2ban server in background/foreground (daemon mode or not). # Start the Fail2ban server in background/foreground (daemon mode or not).
@staticmethod @staticmethod
def startServerDirect(conf, daemon=True): def startServerDirect(conf, daemon=True, setServer=None):
logSys.debug(" direct starting of server in %s, deamon: %s", os.getpid(), daemon) logSys.debug(" direct starting of server in %s, deamon: %s", os.getpid(), daemon)
from ..server.server import Server from ..server.server import Server
server = None server = None
@ -52,6 +52,10 @@ class Fail2banServer(Fail2banCmdLine):
# Start it in foreground (current thread, not new process), # Start it in foreground (current thread, not new process),
# server object will internally fork self if daemon is True # server object will internally fork self if daemon is True
server = Server(daemon) server = Server(daemon)
# notify caller - set server handle:
if setServer:
setServer(server)
# run:
server.start(conf["socket"], server.start(conf["socket"],
conf["pidfile"], conf["force"], conf["pidfile"], conf["force"],
conf=conf) conf=conf)
@ -63,6 +67,10 @@ class Fail2banServer(Fail2banCmdLine):
if conf["verbose"] > 1: if conf["verbose"] > 1:
logSys.exception(e2) logSys.exception(e2)
raise raise
finally:
# notify waiting thread server ready resp. done (background execution, error case, etc):
if conf.get('onstart'):
conf['onstart']()
return server return server
@ -179,27 +187,15 @@ class Fail2banServer(Fail2banCmdLine):
# Start new thread with client to read configuration and # Start new thread with client to read configuration and
# transfer it to the server: # transfer it to the server:
cli = self._Fail2banClient() cli = self._Fail2banClient()
cli._conf = self._conf
phase = dict() phase = dict()
logSys.debug('Configure via async client thread') logSys.debug('Configure via async client thread')
cli.configureServer(phase=phase) cli.configureServer(phase=phase)
# wait, do not continue if configuration is not 100% valid:
Utils.wait_for(lambda: phase.get('ready', None) is not None, self._conf["timeout"], 0.001)
logSys.log(5, ' server phase %s', phase)
if not phase.get('start', False):
raise ServerExecutionException('Async configuration of server failed')
# event for server ready flag:
def _server_ready():
phase['start-ready'] = True
logSys.log(5, ' server phase %s', phase)
# notify waiting thread if server really ready
self._conf['onstart'] = _server_ready
# Start server, daemonize it, etc. # Start server, daemonize it, etc.
pid = os.getpid() pid = os.getpid()
server = Fail2banServer.startServerDirect(self._conf, background) server = Fail2banServer.startServerDirect(self._conf, background,
# notify waiting thread server ready resp. done (background execution, error case, etc): cli._set_server if cli else None)
if not nonsync:
_server_ready()
# If forked - just exit other processes # If forked - just exit other processes
if pid != os.getpid(): # pragma: no cover if pid != os.getpid(): # pragma: no cover
os._exit(0) os._exit(0)

View File

@ -104,7 +104,11 @@ def commitandrollback(f):
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
with self._lock: # Threading lock with self._lock: # Threading lock
with self._db: # Auto commit and rollback on exception with self._db: # Auto commit and rollback on exception
return f(self, self._db.cursor(), *args, **kwargs) cur = self._db.cursor()
try:
return f(self, cur, *args, **kwargs)
finally:
cur.close()
return wrapper return wrapper
@ -253,7 +257,7 @@ class Fail2BanDb(object):
self.repairDB() self.repairDB()
else: else:
version = cur.fetchone()[0] version = cur.fetchone()[0]
if version < Fail2BanDb.__version__: if version != Fail2BanDb.__version__:
newversion = self.updateDb(version) newversion = self.updateDb(version)
if newversion == Fail2BanDb.__version__: if newversion == Fail2BanDb.__version__:
logSys.warning( "Database updated from '%r' to '%r'", logSys.warning( "Database updated from '%r' to '%r'",
@ -301,9 +305,11 @@ class Fail2BanDb(object):
try: try:
# backup # backup
logSys.info("Trying to repair database %s", self._dbFilename) logSys.info("Trying to repair database %s", self._dbFilename)
if not os.path.isfile(self._dbBackupFilename):
shutil.move(self._dbFilename, self._dbBackupFilename) shutil.move(self._dbFilename, self._dbBackupFilename)
logSys.info(" Database backup created: %s", self._dbBackupFilename) logSys.info(" Database backup created: %s", self._dbBackupFilename)
elif os.path.isfile(self._dbFilename):
os.remove(self._dbFilename)
# first try to repair using dump/restore in order # first try to repair using dump/restore in order
Utils.executeCmd((r"""f2b_db=$0; f2b_dbbk=$1; sqlite3 "$f2b_dbbk" ".dump" | sqlite3 "$f2b_db" """, Utils.executeCmd((r"""f2b_db=$0; f2b_dbbk=$1; sqlite3 "$f2b_dbbk" ".dump" | sqlite3 "$f2b_db" """,
self._dbFilename, self._dbBackupFilename)) self._dbFilename, self._dbBackupFilename))
@ -415,7 +421,7 @@ class Fail2BanDb(object):
logSys.error("Failed to upgrade database '%s': %s", logSys.error("Failed to upgrade database '%s': %s",
self._dbFilename, e.args[0], self._dbFilename, e.args[0],
exc_info=logSys.getEffectiveLevel() <= 10) exc_info=logSys.getEffectiveLevel() <= 10)
raise self.repairDB()
@commitandrollback @commitandrollback
def addJail(self, cur, jail): def addJail(self, cur, jail):
@ -789,7 +795,6 @@ class Fail2BanDb(object):
queryArgs.append(fromtime) queryArgs.append(fromtime)
if overalljails or jail is None: if overalljails or jail is None:
query += " GROUP BY ip ORDER BY timeofban DESC LIMIT 1" query += " GROUP BY ip ORDER BY timeofban DESC LIMIT 1"
cur = self._db.cursor()
# repack iterator as long as in lock: # repack iterator as long as in lock:
return list(cur.execute(query, queryArgs)) return list(cur.execute(query, queryArgs))
@ -812,11 +817,9 @@ class Fail2BanDb(object):
query += " GROUP BY ip ORDER BY ip, timeofban DESC" query += " GROUP BY ip ORDER BY ip, timeofban DESC"
else: else:
query += " ORDER BY timeofban DESC LIMIT 1" query += " ORDER BY timeofban DESC LIMIT 1"
cur = self._db.cursor()
return cur.execute(query, queryArgs) return cur.execute(query, queryArgs)
@commitandrollback def getCurrentBans(self, jail=None, ip=None, forbantime=None, fromtime=None,
def getCurrentBans(self, cur, jail=None, ip=None, forbantime=None, fromtime=None,
correctBanTime=True, maxmatches=None correctBanTime=True, maxmatches=None
): ):
"""Reads tickets (with merged info) currently affected from ban from the database. """Reads tickets (with merged info) currently affected from ban from the database.
@ -828,6 +831,8 @@ class Fail2BanDb(object):
(and therefore endOfBan) of the ticket (normally it is ban-time of jail as maximum) (and therefore endOfBan) of the ticket (normally it is ban-time of jail as maximum)
for all tickets with ban-time greater (or persistent). for all tickets with ban-time greater (or persistent).
""" """
cur = self._db.cursor()
try:
if fromtime is None: if fromtime is None:
fromtime = MyTime.time() fromtime = MyTime.time()
tickets = [] tickets = []
@ -837,9 +842,11 @@ class Fail2BanDb(object):
# don't change if persistent allowed: # don't change if persistent allowed:
if correctBanTime == -1: correctBanTime = None if correctBanTime == -1: correctBanTime = None
for ticket in self._getCurrentBans(cur, jail=jail, ip=ip, with self._lock:
bans = self._getCurrentBans(cur, jail=jail, ip=ip,
forbantime=forbantime, fromtime=fromtime forbantime=forbantime, fromtime=fromtime
): )
for ticket in bans:
# can produce unpack error (database may return sporadical wrong-empty row): # can produce unpack error (database may return sporadical wrong-empty row):
try: try:
banip, timeofban, bantime, bancount, data = ticket banip, timeofban, bantime, bancount, data = ticket
@ -879,6 +886,8 @@ class Fail2BanDb(object):
ticket.setBanCount(bancount) ticket.setBanCount(bancount)
if ip is not None: return ticket if ip is not None: return ticket
tickets.append(ticket) tickets.append(ticket)
finally:
cur.close()
return tickets return tickets

View File

@ -58,7 +58,7 @@ class Transmitter:
ret = self.__commandHandler(command) ret = self.__commandHandler(command)
ack = 0, ret ack = 0, ret
except Exception as e: except Exception as e:
logSys.warning("Command %r has failed. Received %r", logSys.error("Command %r has failed. Received %r",
command, e, command, e,
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
ack = 1, e ack = 1, e

View File

@ -491,6 +491,39 @@ class Fail2banClientServerBase(LogCaptureTestCase):
self.execCmd(FAILED, startparams, "~~unknown~cmd~failed~~") self.execCmd(FAILED, startparams, "~~unknown~cmd~failed~~")
self.execCmd(SUCCESS, startparams, "echo", "TEST-ECHO") self.execCmd(SUCCESS, startparams, "echo", "TEST-ECHO")
@with_tmpdir
@with_kill_srv
def testStartFailsInForeground(self, tmp):
if not server.Fail2BanDb: # pragma: no cover
raise unittest.SkipTest('Skip test because no database')
dbname = pjoin(tmp,"tmp.db")
db = server.Fail2BanDb(dbname)
# set inappropriate DB version to simulate an irreparable error by start:
cur = db._db.cursor()
cur.executescript("UPDATE fail2banDb SET version = 555")
cur.close()
# timeout (thread will stop foreground server):
startparams = _start_params(tmp, db=dbname, logtarget='INHERITED')
phase = {'stop': True}
def _stopTimeout(startparams, phase):
if not Utils.wait_for(lambda: not phase['stop'], MAX_WAITTIME):
# print('==== STOP ====')
self.execCmdDirect(startparams, 'stop')
th = Thread(
name="_TestCaseWorker",
target=_stopTimeout,
args=(startparams, phase)
)
th.start()
# test:
try:
self.execCmd(FAILED, ("-f",) + startparams, "start")
finally:
phase['stop'] = False
th.join()
self.assertLogged("Attempt to travel to future version of database",
"Exit with code 255", all=True)
class Fail2banClientTest(Fail2banClientServerBase): class Fail2banClientTest(Fail2banClientServerBase):