cherry-pick from 0.11: changes in updateDb because it can be executed after repair, and some tables can be missing.

pull/2014/head
sebres 2017-12-22 17:21:11 +01:00
parent 277edd5fe5
commit 1e39c2600c
2 changed files with 22 additions and 10 deletions

View File

@ -331,6 +331,12 @@ class Fail2BanDb(object):
def createDb(self, cur, incremental=False): def createDb(self, cur, incremental=False):
return self._createDb(cur, incremental); return self._createDb(cur, incremental);
def _tableExists(self, cur, table):
cur.execute("select 1 where exists ("
"select 1 from sqlite_master WHERE type='table' AND name=?)", (table,))
res = cur.fetchone()
return res is not None and res[0]
@commitandrollback @commitandrollback
def updateDb(self, cur, version): def updateDb(self, cur, version):
"""Update an existing database, called during initialisation. """Update an existing database, called during initialisation.
@ -340,13 +346,13 @@ class Fail2BanDb(object):
if version > Fail2BanDb.__version__: if version > Fail2BanDb.__version__:
raise NotImplementedError( raise NotImplementedError(
"Attempt to travel to future version of database ...how did you get here??") "Attempt to travel to future version of database ...how did you get here??")
try:
logSys.info("Uprade database: %s", self._dbBackupFilename) logSys.info("Upgrade database: %s from version '%r'", self._dbBackupFilename, version)
if not os.path.isfile(self._dbBackupFilename): if not os.path.isfile(self._dbBackupFilename):
shutil.copyfile(self.filename, self._dbBackupFilename) shutil.copyfile(self.filename, self._dbBackupFilename)
logSys.info(" Database backup created: %s", self._dbBackupFilename) logSys.info(" Database backup created: %s", self._dbBackupFilename)
if version < 2: if version < 2 and self._tableExists(cur, "logs"):
cur.executescript("BEGIN TRANSACTION;" cur.executescript("BEGIN TRANSACTION;"
"CREATE TEMPORARY TABLE logs_temp AS SELECT * FROM logs;" "CREATE TEMPORARY TABLE logs_temp AS SELECT * FROM logs;"
"DROP TABLE logs;" "DROP TABLE logs;"
@ -358,6 +364,12 @@ class Fail2BanDb(object):
cur.execute("SELECT version FROM fail2banDb LIMIT 1") cur.execute("SELECT version FROM fail2banDb LIMIT 1")
return cur.fetchone()[0] return cur.fetchone()[0]
except Exception as e:
# if still failed, just recreate database as fallback:
logSys.error("Failed to upgrade database '%s': %s",
self._dbFilename, e.args[0],
exc_info=logSys.getEffectiveLevel() <= 10)
raise
@commitandrollback @commitandrollback
def addJail(self, cur, jail): def addJail(self, cur, jail):

View File

@ -144,7 +144,7 @@ class DatabaseTest(LogCaptureTestCase):
self.assertEqual(len(self.db.getJailNames()), 1) self.assertEqual(len(self.db.getJailNames()), 1)
else: # recreated: else: # recreated:
self.assertLogged("Repair seems to be failed", self.assertLogged("Repair seems to be failed",
"New database created.", all=True) "Check integrity", "New database created.", all=True)
self.assertEqual(len(self.db.getLogPaths()), 0) self.assertEqual(len(self.db.getLogPaths()), 0)
self.assertEqual(len(self.db.getJailNames()), 0) self.assertEqual(len(self.db.getJailNames()), 0)
finally: finally: