diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index 5474311b..f301fdf8 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -129,14 +129,15 @@ class Fail2BanDb(object): purgeage """ __version__ = 2 - # Note all _TABLE_* strings must end in ';' for py26 compatibility - _TABLE_fail2banDb = "CREATE TABLE fail2banDb(version INTEGER);" - _TABLE_jails = "CREATE TABLE jails(" \ + # Note all SCRIPTS strings must end in ';' for py26 compatibility + _CREATE_SCRIPTS = ( + ('fail2banDb', "CREATE TABLE IF NOT EXISTS fail2banDb(version INTEGER);") + ,('jails', "CREATE TABLE IF NOT EXISTS jails(" \ "name TEXT NOT NULL UNIQUE, " \ "enabled INTEGER NOT NULL DEFAULT 1" \ ");" \ - "CREATE INDEX jails_name ON jails(name);" - _TABLE_logs = "CREATE TABLE logs(" \ + "CREATE INDEX IF NOT EXISTS jails_name ON jails(name);") + ,('logs', "CREATE TABLE IF NOT EXISTS logs(" \ "jail TEXT NOT NULL, " \ "path TEXT, " \ "firstlinemd5 TEXT, " \ @@ -145,22 +146,24 @@ class Fail2BanDb(object): "UNIQUE(jail, path)," \ "UNIQUE(jail, path, firstlinemd5)" \ ");" \ - "CREATE INDEX logs_path ON logs(path);" \ - "CREATE INDEX logs_jail_path ON logs(jail, path);" + "CREATE INDEX IF NOT EXISTS logs_path ON logs(path);" \ + "CREATE INDEX IF NOT EXISTS logs_jail_path ON logs(jail, path);") #TODO: systemd journal features \ #"journalmatch TEXT, " \ #"journlcursor TEXT, " \ - #"lastfiletime INTEGER DEFAULT 0, " # is this easily available \ - _TABLE_bans = "CREATE TABLE bans(" \ + #"lastfiletime INTEGER DEFAULT 0, " # is this easily available + ,('bans', "CREATE TABLE IF NOT EXISTS bans(" \ "jail TEXT NOT NULL, " \ "ip TEXT, " \ "timeofban INTEGER NOT NULL, " \ "data JSON, " \ "FOREIGN KEY(jail) REFERENCES jails(name) " \ ");" \ - "CREATE INDEX bans_jail_timeofban_ip ON bans(jail, timeofban);" \ - "CREATE INDEX bans_jail_ip ON bans(jail, ip);" \ - "CREATE INDEX bans_ip ON bans(ip);" \ + "CREATE INDEX IF NOT EXISTS bans_jail_timeofban_ip ON bans(jail, timeofban);" \ + "CREATE INDEX IF NOT EXISTS bans_jail_ip ON bans(jail, ip);" \ + "CREATE INDEX IF NOT EXISTS bans_ip ON bans(ip);") + ) + _CREATE_TABS = dict(_CREATE_SCRIPTS) def __init__(self, filename, purgeAge=24*60*60): @@ -206,16 +209,9 @@ class Fail2BanDb(object): # speedup: temporary tables and indices are kept in memory: cur.execute("PRAGMA temp_store = MEMORY") - if checkIntegrity: - logSys.debug(" Check integrity ...") - cur.execute("PRAGMA integrity_check") - for s in cur.fetchall(): - logSys.debug(" %s", s) - self._db.commit() - cur.execute("SELECT version FROM fail2banDb LIMIT 1") except sqlite3.OperationalError: - logSys.warning("New database created. Version '%i'", + logSys.warning("New database created. Version '%r'", self.createDb()) except sqlite3.Error as e: logSys.error( @@ -233,14 +229,23 @@ class Fail2BanDb(object): if version < Fail2BanDb.__version__: newversion = self.updateDb(version) if newversion == Fail2BanDb.__version__: - logSys.warning( "Database updated from '%i' to '%i'", + logSys.warning( "Database updated from '%r' to '%r'", version, newversion) else: # pragma: no cover - logSys.error( "Database update failed to achieve version '%i'" - ": updated from '%i' to '%i'", + logSys.error( "Database update failed to achieve version '%r'" + ": updated from '%r' to '%r'", Fail2BanDb.__version__, version, newversion) raise RuntimeError('Failed to fully update') finally: + if checkIntegrity: + logSys.debug(" Create missing tables/indices ...") + self._createDb(cur, incremental=True) + logSys.debug(" -> ok") + logSys.debug(" Check integrity ...") + cur.execute("PRAGMA integrity_check") + for s in cur.fetchall(): + logSys.debug(" -> %s", ' '.join(s)) + self._db.commit() if cur: # pypy: set journal mode after possible upgrade db: if pypy: @@ -261,6 +266,8 @@ class Fail2BanDb(object): return self.__dbBackupFilename def repairDB(self): + class RepairException(Exception): + pass # avoid endless recursion if reconnect failed again for some reasons: _repairDB = self.repairDB self.repairDB = None @@ -280,13 +287,14 @@ class Fail2BanDb(object): self._connectDB(checkIntegrity=True) else: logSys.info(" Repair seems to be failed, restored %d byte(s).", dbFileSize) - raise Exception('Recreate ...') + raise RepairException('Recreate ...') except Exception as e: # if still failed, just recreate database as fallback: logSys.error(" Error repairing of fail2ban database '%s': %s", - self._dbFilename, e.args[0]) + self._dbFilename, e.args[0], + exc_info=(not isinstance(e, RepairException) and logSys.getEffectiveLevel() <= 10)) os.remove(self._dbFilename) - self._connectDB() + self._connectDB(checkIntegrity=True) finally: self.repairDB = _repairDB @@ -306,24 +314,23 @@ class Fail2BanDb(object): def purgeage(self, value): self._purgeAge = MyTime.str2seconds(value) - @commitandrollback - def createDb(self, cur): + def _createDb(self, cur, incremental=False): """Creates a new database, called during initialisation. """ - # Version info - cur.executescript(Fail2BanDb._TABLE_fail2banDb) - cur.execute("INSERT INTO fail2banDb(version) VALUES(?)", + # create all (if not exists): + for (n, s) in Fail2BanDb._CREATE_SCRIPTS: + cur.executescript(s) + # save current database version (if not already set): + cur.execute("INSERT INTO fail2banDb(version)" + " SELECT ? WHERE NOT EXISTS (SELECT 1 FROM fail2banDb LIMIT 1)", (Fail2BanDb.__version__, )) - # Jails - cur.executescript(Fail2BanDb._TABLE_jails) - # Logs - cur.executescript(Fail2BanDb._TABLE_logs) - # Bans - cur.executescript(Fail2BanDb._TABLE_bans) - cur.execute("SELECT version FROM fail2banDb LIMIT 1") return cur.fetchone()[0] + @commitandrollback + def createDb(self, cur, incremental=False): + return self._createDb(cur, incremental); + @commitandrollback def updateDb(self, cur, version): """Update an existing database, called during initialisation. @@ -347,7 +354,7 @@ class Fail2BanDb(object): "INSERT INTO logs SELECT * from logs_temp;" "DROP TABLE logs_temp;" "UPDATE fail2banDb SET version = 2;" - "COMMIT;" % Fail2BanDb._TABLE_logs) + "COMMIT;" % Fail2BanDb._CREATE_TABS['logs']) cur.execute("SELECT version FROM fail2banDb LIMIT 1") return cur.fetchone()[0]