mirror of https://github.com/fail2ban/fail2ban
commit
94dac78afe
|
@ -175,9 +175,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
||||||
|
|
||||||
return [["server-stream", stream], ['server-status']]
|
return [["server-stream", stream], ['server-status']]
|
||||||
|
|
||||||
|
def _set_server(self, s):
|
||||||
|
self._server = s
|
||||||
|
|
||||||
##
|
##
|
||||||
def __startServer(self, background=True):
|
def __startServer(self, background=True):
|
||||||
from .fail2banserver import Fail2banServer
|
from .fail2banserver import Fail2banServer
|
||||||
|
# read configuration here (in client only, in server we do that in the config-thread):
|
||||||
stream = self.__prepareStartServer()
|
stream = self.__prepareStartServer()
|
||||||
self._alive = True
|
self._alive = True
|
||||||
if not stream:
|
if not stream:
|
||||||
|
@ -192,16 +196,19 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
# In foreground mode we should make server/client communication in different threads:
|
# In foreground mode we should make server/client communication in different threads:
|
||||||
th = Thread(target=Fail2banClient.__processStartStreamAfterWait, args=(self, stream, False))
|
phase = dict()
|
||||||
th.daemon = True
|
self.configureServer(phase=phase, stream=stream)
|
||||||
th.start()
|
|
||||||
# Mark current (main) thread as daemon:
|
# Mark current (main) thread as daemon:
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
# Start server direct here in main thread (not fork):
|
# Start server direct here in main thread (not fork):
|
||||||
self._server = Fail2banServer.startServerDirect(self._conf, False)
|
self._server = Fail2banServer.startServerDirect(self._conf, False, self._set_server)
|
||||||
|
if not phase.get('done', False):
|
||||||
|
if self._server: # pragma: no cover
|
||||||
|
self._server.quit()
|
||||||
|
self._server = None
|
||||||
|
exit(255)
|
||||||
except ExitException: # pragma: no cover
|
except ExitException: # pragma: no cover
|
||||||
pass
|
raise
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e: # pragma: no cover
|
||||||
output("")
|
output("")
|
||||||
logSys.error("Exception while starting server " + ("background" if background else "foreground"))
|
logSys.error("Exception while starting server " + ("background" if background else "foreground"))
|
||||||
|
@ -214,23 +221,39 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
##
|
##
|
||||||
def configureServer(self, nonsync=True, phase=None):
|
def configureServer(self, nonsync=True, phase=None, stream=None):
|
||||||
# if asynchronous start this operation in the new thread:
|
# if asynchronous start this operation in the new thread:
|
||||||
if nonsync:
|
if nonsync:
|
||||||
th = Thread(target=Fail2banClient.configureServer, args=(self, False, phase))
|
if phase is not None:
|
||||||
|
# event for server ready flag:
|
||||||
|
def _server_ready():
|
||||||
|
phase['start-ready'] = True
|
||||||
|
logSys.log(5, ' server phase %s', phase)
|
||||||
|
# notify waiting thread if server really ready
|
||||||
|
self._conf['onstart'] = _server_ready
|
||||||
|
th = Thread(target=Fail2banClient.configureServer, args=(self, False, phase, stream))
|
||||||
th.daemon = True
|
th.daemon = True
|
||||||
return th.start()
|
th.start()
|
||||||
|
# if we need to read configuration stream:
|
||||||
|
if stream is None and phase is not None:
|
||||||
|
# wait, do not continue if configuration is not 100% valid:
|
||||||
|
Utils.wait_for(lambda: phase.get('ready', None) is not None, self._conf["timeout"], 0.001)
|
||||||
|
logSys.log(5, ' server phase %s', phase)
|
||||||
|
if not phase.get('start', False):
|
||||||
|
raise ServerExecutionException('Async configuration of server failed')
|
||||||
|
return True
|
||||||
# prepare: read config, check configuration is valid, etc.:
|
# prepare: read config, check configuration is valid, etc.:
|
||||||
if phase is not None:
|
if phase is not None:
|
||||||
phase['start'] = True
|
phase['start'] = True
|
||||||
logSys.log(5, ' client phase %s', phase)
|
logSys.log(5, ' client phase %s', phase)
|
||||||
|
if stream is None:
|
||||||
stream = self.__prepareStartServer()
|
stream = self.__prepareStartServer()
|
||||||
if phase is not None:
|
if phase is not None:
|
||||||
phase['ready'] = phase['start'] = (True if stream else False)
|
phase['ready'] = phase['start'] = (True if stream else False)
|
||||||
logSys.log(5, ' client phase %s', phase)
|
logSys.log(5, ' client phase %s', phase)
|
||||||
if not stream:
|
if not stream:
|
||||||
return False
|
return False
|
||||||
# wait a litle bit for phase "start-ready" before enter active waiting:
|
# wait a little bit for phase "start-ready" before enter active waiting:
|
||||||
if phase is not None:
|
if phase is not None:
|
||||||
Utils.wait_for(lambda: phase.get('start-ready', None) is not None, 0.5, 0.001)
|
Utils.wait_for(lambda: phase.get('start-ready', None) is not None, 0.5, 0.001)
|
||||||
phase['configure'] = (True if stream else False)
|
phase['configure'] = (True if stream else False)
|
||||||
|
@ -321,13 +344,14 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
||||||
|
|
||||||
|
|
||||||
def __processStartStreamAfterWait(self, *args):
|
def __processStartStreamAfterWait(self, *args):
|
||||||
|
ret = False
|
||||||
try:
|
try:
|
||||||
# Wait for the server to start
|
# Wait for the server to start
|
||||||
if not self.__waitOnServer(): # pragma: no cover
|
if not self.__waitOnServer(): # pragma: no cover
|
||||||
logSys.error("Could not find server, waiting failed")
|
logSys.error("Could not find server, waiting failed")
|
||||||
return False
|
return False
|
||||||
# Configure the server
|
# Configure the server
|
||||||
self.__processCmd(*args)
|
ret = self.__processCmd(*args)
|
||||||
except ServerExecutionException as e: # pragma: no cover
|
except ServerExecutionException as e: # pragma: no cover
|
||||||
if self._conf["verbose"] > 1:
|
if self._conf["verbose"] > 1:
|
||||||
logSys.exception(e)
|
logSys.exception(e)
|
||||||
|
@ -336,10 +360,11 @@ class Fail2banClient(Fail2banCmdLine, Thread):
|
||||||
"remove " + self._conf["socket"] + ". If "
|
"remove " + self._conf["socket"] + ". If "
|
||||||
"you used fail2ban-client to start the "
|
"you used fail2ban-client to start the "
|
||||||
"server, adding the -x option will do it")
|
"server, adding the -x option will do it")
|
||||||
if self._server:
|
|
||||||
|
if not ret and self._server: # stop on error (foreground, config read in another thread):
|
||||||
self._server.quit()
|
self._server.quit()
|
||||||
return False
|
self._server = None
|
||||||
return True
|
return ret
|
||||||
|
|
||||||
def __waitOnServer(self, alive=True, maxtime=None):
|
def __waitOnServer(self, alive=True, maxtime=None):
|
||||||
if maxtime is None:
|
if maxtime is None:
|
||||||
|
|
|
@ -44,7 +44,7 @@ class Fail2banServer(Fail2banCmdLine):
|
||||||
# Start the Fail2ban server in background/foreground (daemon mode or not).
|
# Start the Fail2ban server in background/foreground (daemon mode or not).
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def startServerDirect(conf, daemon=True):
|
def startServerDirect(conf, daemon=True, setServer=None):
|
||||||
logSys.debug(" direct starting of server in %s, deamon: %s", os.getpid(), daemon)
|
logSys.debug(" direct starting of server in %s, deamon: %s", os.getpid(), daemon)
|
||||||
from ..server.server import Server
|
from ..server.server import Server
|
||||||
server = None
|
server = None
|
||||||
|
@ -52,6 +52,10 @@ class Fail2banServer(Fail2banCmdLine):
|
||||||
# Start it in foreground (current thread, not new process),
|
# Start it in foreground (current thread, not new process),
|
||||||
# server object will internally fork self if daemon is True
|
# server object will internally fork self if daemon is True
|
||||||
server = Server(daemon)
|
server = Server(daemon)
|
||||||
|
# notify caller - set server handle:
|
||||||
|
if setServer:
|
||||||
|
setServer(server)
|
||||||
|
# run:
|
||||||
server.start(conf["socket"],
|
server.start(conf["socket"],
|
||||||
conf["pidfile"], conf["force"],
|
conf["pidfile"], conf["force"],
|
||||||
conf=conf)
|
conf=conf)
|
||||||
|
@ -63,6 +67,10 @@ class Fail2banServer(Fail2banCmdLine):
|
||||||
if conf["verbose"] > 1:
|
if conf["verbose"] > 1:
|
||||||
logSys.exception(e2)
|
logSys.exception(e2)
|
||||||
raise
|
raise
|
||||||
|
finally:
|
||||||
|
# notify waiting thread server ready resp. done (background execution, error case, etc):
|
||||||
|
if conf.get('onstart'):
|
||||||
|
conf['onstart']()
|
||||||
|
|
||||||
return server
|
return server
|
||||||
|
|
||||||
|
@ -179,27 +187,15 @@ class Fail2banServer(Fail2banCmdLine):
|
||||||
# Start new thread with client to read configuration and
|
# Start new thread with client to read configuration and
|
||||||
# transfer it to the server:
|
# transfer it to the server:
|
||||||
cli = self._Fail2banClient()
|
cli = self._Fail2banClient()
|
||||||
|
cli._conf = self._conf
|
||||||
phase = dict()
|
phase = dict()
|
||||||
logSys.debug('Configure via async client thread')
|
logSys.debug('Configure via async client thread')
|
||||||
cli.configureServer(phase=phase)
|
cli.configureServer(phase=phase)
|
||||||
# wait, do not continue if configuration is not 100% valid:
|
|
||||||
Utils.wait_for(lambda: phase.get('ready', None) is not None, self._conf["timeout"], 0.001)
|
|
||||||
logSys.log(5, ' server phase %s', phase)
|
|
||||||
if not phase.get('start', False):
|
|
||||||
raise ServerExecutionException('Async configuration of server failed')
|
|
||||||
# event for server ready flag:
|
|
||||||
def _server_ready():
|
|
||||||
phase['start-ready'] = True
|
|
||||||
logSys.log(5, ' server phase %s', phase)
|
|
||||||
# notify waiting thread if server really ready
|
|
||||||
self._conf['onstart'] = _server_ready
|
|
||||||
|
|
||||||
# Start server, daemonize it, etc.
|
# Start server, daemonize it, etc.
|
||||||
pid = os.getpid()
|
pid = os.getpid()
|
||||||
server = Fail2banServer.startServerDirect(self._conf, background)
|
server = Fail2banServer.startServerDirect(self._conf, background,
|
||||||
# notify waiting thread server ready resp. done (background execution, error case, etc):
|
cli._set_server if cli else None)
|
||||||
if not nonsync:
|
|
||||||
_server_ready()
|
|
||||||
# If forked - just exit other processes
|
# If forked - just exit other processes
|
||||||
if pid != os.getpid(): # pragma: no cover
|
if pid != os.getpid(): # pragma: no cover
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|
|
@ -104,7 +104,11 @@ def commitandrollback(f):
|
||||||
def wrapper(self, *args, **kwargs):
|
def wrapper(self, *args, **kwargs):
|
||||||
with self._lock: # Threading lock
|
with self._lock: # Threading lock
|
||||||
with self._db: # Auto commit and rollback on exception
|
with self._db: # Auto commit and rollback on exception
|
||||||
return f(self, self._db.cursor(), *args, **kwargs)
|
cur = self._db.cursor()
|
||||||
|
try:
|
||||||
|
return f(self, cur, *args, **kwargs)
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@ -253,7 +257,7 @@ class Fail2BanDb(object):
|
||||||
self.repairDB()
|
self.repairDB()
|
||||||
else:
|
else:
|
||||||
version = cur.fetchone()[0]
|
version = cur.fetchone()[0]
|
||||||
if version < Fail2BanDb.__version__:
|
if version != Fail2BanDb.__version__:
|
||||||
newversion = self.updateDb(version)
|
newversion = self.updateDb(version)
|
||||||
if newversion == Fail2BanDb.__version__:
|
if newversion == Fail2BanDb.__version__:
|
||||||
logSys.warning( "Database updated from '%r' to '%r'",
|
logSys.warning( "Database updated from '%r' to '%r'",
|
||||||
|
@ -301,9 +305,11 @@ class Fail2BanDb(object):
|
||||||
try:
|
try:
|
||||||
# backup
|
# backup
|
||||||
logSys.info("Trying to repair database %s", self._dbFilename)
|
logSys.info("Trying to repair database %s", self._dbFilename)
|
||||||
|
if not os.path.isfile(self._dbBackupFilename):
|
||||||
shutil.move(self._dbFilename, self._dbBackupFilename)
|
shutil.move(self._dbFilename, self._dbBackupFilename)
|
||||||
logSys.info(" Database backup created: %s", self._dbBackupFilename)
|
logSys.info(" Database backup created: %s", self._dbBackupFilename)
|
||||||
|
elif os.path.isfile(self._dbFilename):
|
||||||
|
os.remove(self._dbFilename)
|
||||||
# first try to repair using dump/restore in order
|
# first try to repair using dump/restore in order
|
||||||
Utils.executeCmd((r"""f2b_db=$0; f2b_dbbk=$1; sqlite3 "$f2b_dbbk" ".dump" | sqlite3 "$f2b_db" """,
|
Utils.executeCmd((r"""f2b_db=$0; f2b_dbbk=$1; sqlite3 "$f2b_dbbk" ".dump" | sqlite3 "$f2b_db" """,
|
||||||
self._dbFilename, self._dbBackupFilename))
|
self._dbFilename, self._dbBackupFilename))
|
||||||
|
@ -415,7 +421,7 @@ class Fail2BanDb(object):
|
||||||
logSys.error("Failed to upgrade database '%s': %s",
|
logSys.error("Failed to upgrade database '%s': %s",
|
||||||
self._dbFilename, e.args[0],
|
self._dbFilename, e.args[0],
|
||||||
exc_info=logSys.getEffectiveLevel() <= 10)
|
exc_info=logSys.getEffectiveLevel() <= 10)
|
||||||
raise
|
self.repairDB()
|
||||||
|
|
||||||
@commitandrollback
|
@commitandrollback
|
||||||
def addJail(self, cur, jail):
|
def addJail(self, cur, jail):
|
||||||
|
@ -789,7 +795,6 @@ class Fail2BanDb(object):
|
||||||
queryArgs.append(fromtime)
|
queryArgs.append(fromtime)
|
||||||
if overalljails or jail is None:
|
if overalljails or jail is None:
|
||||||
query += " GROUP BY ip ORDER BY timeofban DESC LIMIT 1"
|
query += " GROUP BY ip ORDER BY timeofban DESC LIMIT 1"
|
||||||
cur = self._db.cursor()
|
|
||||||
# repack iterator as long as in lock:
|
# repack iterator as long as in lock:
|
||||||
return list(cur.execute(query, queryArgs))
|
return list(cur.execute(query, queryArgs))
|
||||||
|
|
||||||
|
@ -812,11 +817,9 @@ class Fail2BanDb(object):
|
||||||
query += " GROUP BY ip ORDER BY ip, timeofban DESC"
|
query += " GROUP BY ip ORDER BY ip, timeofban DESC"
|
||||||
else:
|
else:
|
||||||
query += " ORDER BY timeofban DESC LIMIT 1"
|
query += " ORDER BY timeofban DESC LIMIT 1"
|
||||||
cur = self._db.cursor()
|
|
||||||
return cur.execute(query, queryArgs)
|
return cur.execute(query, queryArgs)
|
||||||
|
|
||||||
@commitandrollback
|
def getCurrentBans(self, jail=None, ip=None, forbantime=None, fromtime=None,
|
||||||
def getCurrentBans(self, cur, jail=None, ip=None, forbantime=None, fromtime=None,
|
|
||||||
correctBanTime=True, maxmatches=None
|
correctBanTime=True, maxmatches=None
|
||||||
):
|
):
|
||||||
"""Reads tickets (with merged info) currently affected from ban from the database.
|
"""Reads tickets (with merged info) currently affected from ban from the database.
|
||||||
|
@ -828,6 +831,8 @@ class Fail2BanDb(object):
|
||||||
(and therefore endOfBan) of the ticket (normally it is ban-time of jail as maximum)
|
(and therefore endOfBan) of the ticket (normally it is ban-time of jail as maximum)
|
||||||
for all tickets with ban-time greater (or persistent).
|
for all tickets with ban-time greater (or persistent).
|
||||||
"""
|
"""
|
||||||
|
cur = self._db.cursor()
|
||||||
|
try:
|
||||||
if fromtime is None:
|
if fromtime is None:
|
||||||
fromtime = MyTime.time()
|
fromtime = MyTime.time()
|
||||||
tickets = []
|
tickets = []
|
||||||
|
@ -837,9 +842,11 @@ class Fail2BanDb(object):
|
||||||
# don't change if persistent allowed:
|
# don't change if persistent allowed:
|
||||||
if correctBanTime == -1: correctBanTime = None
|
if correctBanTime == -1: correctBanTime = None
|
||||||
|
|
||||||
for ticket in self._getCurrentBans(cur, jail=jail, ip=ip,
|
with self._lock:
|
||||||
|
bans = self._getCurrentBans(cur, jail=jail, ip=ip,
|
||||||
forbantime=forbantime, fromtime=fromtime
|
forbantime=forbantime, fromtime=fromtime
|
||||||
):
|
)
|
||||||
|
for ticket in bans:
|
||||||
# can produce unpack error (database may return sporadical wrong-empty row):
|
# can produce unpack error (database may return sporadical wrong-empty row):
|
||||||
try:
|
try:
|
||||||
banip, timeofban, bantime, bancount, data = ticket
|
banip, timeofban, bantime, bancount, data = ticket
|
||||||
|
@ -879,6 +886,8 @@ class Fail2BanDb(object):
|
||||||
ticket.setBanCount(bancount)
|
ticket.setBanCount(bancount)
|
||||||
if ip is not None: return ticket
|
if ip is not None: return ticket
|
||||||
tickets.append(ticket)
|
tickets.append(ticket)
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
|
||||||
return tickets
|
return tickets
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,7 @@ class Transmitter:
|
||||||
ret = self.__commandHandler(command)
|
ret = self.__commandHandler(command)
|
||||||
ack = 0, ret
|
ack = 0, ret
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logSys.warning("Command %r has failed. Received %r",
|
logSys.error("Command %r has failed. Received %r",
|
||||||
command, e,
|
command, e,
|
||||||
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
|
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
|
||||||
ack = 1, e
|
ack = 1, e
|
||||||
|
|
|
@ -491,6 +491,39 @@ class Fail2banClientServerBase(LogCaptureTestCase):
|
||||||
self.execCmd(FAILED, startparams, "~~unknown~cmd~failed~~")
|
self.execCmd(FAILED, startparams, "~~unknown~cmd~failed~~")
|
||||||
self.execCmd(SUCCESS, startparams, "echo", "TEST-ECHO")
|
self.execCmd(SUCCESS, startparams, "echo", "TEST-ECHO")
|
||||||
|
|
||||||
|
@with_tmpdir
|
||||||
|
@with_kill_srv
|
||||||
|
def testStartFailsInForeground(self, tmp):
|
||||||
|
if not server.Fail2BanDb: # pragma: no cover
|
||||||
|
raise unittest.SkipTest('Skip test because no database')
|
||||||
|
dbname = pjoin(tmp,"tmp.db")
|
||||||
|
db = server.Fail2BanDb(dbname)
|
||||||
|
# set inappropriate DB version to simulate an irreparable error by start:
|
||||||
|
cur = db._db.cursor()
|
||||||
|
cur.executescript("UPDATE fail2banDb SET version = 555")
|
||||||
|
cur.close()
|
||||||
|
# timeout (thread will stop foreground server):
|
||||||
|
startparams = _start_params(tmp, db=dbname, logtarget='INHERITED')
|
||||||
|
phase = {'stop': True}
|
||||||
|
def _stopTimeout(startparams, phase):
|
||||||
|
if not Utils.wait_for(lambda: not phase['stop'], MAX_WAITTIME):
|
||||||
|
# print('==== STOP ====')
|
||||||
|
self.execCmdDirect(startparams, 'stop')
|
||||||
|
th = Thread(
|
||||||
|
name="_TestCaseWorker",
|
||||||
|
target=_stopTimeout,
|
||||||
|
args=(startparams, phase)
|
||||||
|
)
|
||||||
|
th.start()
|
||||||
|
# test:
|
||||||
|
try:
|
||||||
|
self.execCmd(FAILED, ("-f",) + startparams, "start")
|
||||||
|
finally:
|
||||||
|
phase['stop'] = False
|
||||||
|
th.join()
|
||||||
|
self.assertLogged("Attempt to travel to future version of database",
|
||||||
|
"Exit with code 255", all=True)
|
||||||
|
|
||||||
|
|
||||||
class Fail2banClientTest(Fail2banClientServerBase):
|
class Fail2banClientTest(Fail2banClientServerBase):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue