mirror of https://github.com/fail2ban/fail2ban
Automatically recover or recreate corrupt persistent database (e. g. if failed to open with 'database disk image is malformed').
Closes #1465pull/2004/head
parent
61109d5c4f
commit
9374de59f3
|
@ -22,6 +22,7 @@ __copyright__ = "Copyright (c) 2013 Steven Hiscocks"
|
||||||
__license__ = "GPL"
|
__license__ = "GPL"
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import sys
|
import sys
|
||||||
|
@ -31,6 +32,7 @@ from threading import RLock
|
||||||
|
|
||||||
from .mytime import MyTime
|
from .mytime import MyTime
|
||||||
from .ticket import FailTicket
|
from .ticket import FailTicket
|
||||||
|
from .utils import Utils
|
||||||
from ..helpers import getLogger, PREFER_ENC
|
from ..helpers import getLogger, PREFER_ENC
|
||||||
|
|
||||||
# Gets the instance of the logger.
|
# Gets the instance of the logger.
|
||||||
|
@ -163,13 +165,17 @@ class Fail2BanDb(object):
|
||||||
|
|
||||||
def __init__(self, filename, purgeAge=24*60*60):
|
def __init__(self, filename, purgeAge=24*60*60):
|
||||||
self.maxEntries = 50
|
self.maxEntries = 50
|
||||||
|
self._lock = RLock()
|
||||||
|
self._dbFilename = filename
|
||||||
|
self._purgeAge = purgeAge
|
||||||
|
self._connectDB()
|
||||||
|
|
||||||
|
def _connectDB(self, checkIntegrity=False):
|
||||||
|
filename = self._dbFilename
|
||||||
try:
|
try:
|
||||||
self._lock = RLock()
|
|
||||||
self._db = sqlite3.connect(
|
self._db = sqlite3.connect(
|
||||||
filename, check_same_thread=False,
|
filename, check_same_thread=False,
|
||||||
detect_types=sqlite3.PARSE_DECLTYPES)
|
detect_types=sqlite3.PARSE_DECLTYPES)
|
||||||
self._dbFilename = filename
|
|
||||||
self._purgeAge = purgeAge
|
|
||||||
|
|
||||||
self._bansMergedCache = {}
|
self._bansMergedCache = {}
|
||||||
|
|
||||||
|
@ -190,20 +196,38 @@ class Fail2BanDb(object):
|
||||||
pypy = False
|
pypy = False
|
||||||
|
|
||||||
cur = self._db.cursor()
|
cur = self._db.cursor()
|
||||||
cur.execute("PRAGMA foreign_keys = ON")
|
|
||||||
# speedup: write data through OS without syncing (no wait):
|
|
||||||
cur.execute("PRAGMA synchronous = OFF")
|
|
||||||
# speedup: transaction log in memory, alternate using OFF (disable, rollback will be impossible):
|
|
||||||
if not pypy:
|
|
||||||
cur.execute("PRAGMA journal_mode = MEMORY")
|
|
||||||
# speedup: temporary tables and indices are kept in memory:
|
|
||||||
cur.execute("PRAGMA temp_store = MEMORY")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
cur.execute("PRAGMA foreign_keys = ON")
|
||||||
|
# speedup: write data through OS without syncing (no wait):
|
||||||
|
cur.execute("PRAGMA synchronous = OFF")
|
||||||
|
# speedup: transaction log in memory, alternate using OFF (disable, rollback will be impossible):
|
||||||
|
if not pypy:
|
||||||
|
cur.execute("PRAGMA journal_mode = MEMORY")
|
||||||
|
# speedup: temporary tables and indices are kept in memory:
|
||||||
|
cur.execute("PRAGMA temp_store = MEMORY")
|
||||||
|
|
||||||
|
if checkIntegrity:
|
||||||
|
logSys.debug(" Check integrity ...")
|
||||||
|
cur.execute("PRAGMA integrity_check")
|
||||||
|
for s in cur.fetchall():
|
||||||
|
logSys.debug(" %s", s)
|
||||||
|
self._db.commit()
|
||||||
|
|
||||||
cur.execute("SELECT version FROM fail2banDb LIMIT 1")
|
cur.execute("SELECT version FROM fail2banDb LIMIT 1")
|
||||||
except sqlite3.OperationalError:
|
except sqlite3.OperationalError:
|
||||||
logSys.warning("New database created. Version '%i'",
|
logSys.warning("New database created. Version '%i'",
|
||||||
self.createDb())
|
self.createDb())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
logSys.error(
|
||||||
|
"Error opening fail2ban persistent database '%s': %s",
|
||||||
|
filename, e.args[0])
|
||||||
|
# if not a file - raise an error:
|
||||||
|
if not os.path.isfile(filename):
|
||||||
|
raise
|
||||||
|
# try to repair it:
|
||||||
|
cur.close()
|
||||||
|
cur = None
|
||||||
|
self.repairDB()
|
||||||
else:
|
else:
|
||||||
version = cur.fetchone()[0]
|
version = cur.fetchone()[0]
|
||||||
if version < Fail2BanDb.__version__:
|
if version < Fail2BanDb.__version__:
|
||||||
|
@ -217,16 +241,55 @@ class Fail2BanDb(object):
|
||||||
Fail2BanDb.__version__, version, newversion)
|
Fail2BanDb.__version__, version, newversion)
|
||||||
raise RuntimeError('Failed to fully update')
|
raise RuntimeError('Failed to fully update')
|
||||||
finally:
|
finally:
|
||||||
# pypy: set journal mode after possible upgrade db:
|
if cur:
|
||||||
if pypy:
|
# pypy: set journal mode after possible upgrade db:
|
||||||
cur.execute("PRAGMA journal_mode = MEMORY")
|
if pypy:
|
||||||
cur.close()
|
cur.execute("PRAGMA journal_mode = MEMORY")
|
||||||
|
cur.close()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
logSys.debug("Close connection to database ...")
|
logSys.debug("Close connection to database ...")
|
||||||
self._db.close()
|
self._db.close()
|
||||||
logSys.info("Connection to database closed.")
|
logSys.info("Connection to database closed.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _dbBackupFilename(self):
|
||||||
|
try:
|
||||||
|
return self.__dbBackupFilename
|
||||||
|
except AttributeError:
|
||||||
|
self.__dbBackupFilename = self._dbFilename + '.' + time.strftime('%Y%m%d-%H%M%S', MyTime.gmtime())
|
||||||
|
return self.__dbBackupFilename
|
||||||
|
|
||||||
|
def repairDB(self):
|
||||||
|
# avoid endless recursion if reconnect failed again for some reasons:
|
||||||
|
_repairDB = self.repairDB
|
||||||
|
self.repairDB = None
|
||||||
|
try:
|
||||||
|
# backup
|
||||||
|
logSys.info("Trying to repair database %s", self._dbFilename)
|
||||||
|
shutil.move(self._dbFilename, self._dbBackupFilename)
|
||||||
|
logSys.info(" Database backup created: %s", self._dbBackupFilename)
|
||||||
|
|
||||||
|
# first try to repair using dump/restore in order
|
||||||
|
Utils.executeCmd((r"""f2b_db=$0; f2b_dbbk=$1; sqlite3 "$f2b_dbbk" ".dump" | sqlite3 "$f2b_db" """,
|
||||||
|
self._dbFilename, self._dbBackupFilename))
|
||||||
|
dbFileSize = os.stat(self._dbFilename).st_size
|
||||||
|
if dbFileSize:
|
||||||
|
logSys.info(" Repair seems to be successful, restored %d byte(s).", dbFileSize)
|
||||||
|
# succeeded - try to reconnect:
|
||||||
|
self._connectDB(checkIntegrity=True)
|
||||||
|
else:
|
||||||
|
logSys.info(" Repair seems to be failed, restored %d byte(s).", dbFileSize)
|
||||||
|
raise Exception('Recreate ...')
|
||||||
|
except Exception as e:
|
||||||
|
# if still failed, just recreate database as fallback:
|
||||||
|
logSys.error(" Error repairing of fail2ban database '%s': %s",
|
||||||
|
self._dbFilename, e.args[0])
|
||||||
|
os.remove(self._dbFilename)
|
||||||
|
self._connectDB()
|
||||||
|
finally:
|
||||||
|
self.repairDB = _repairDB
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def filename(self):
|
def filename(self):
|
||||||
"""File name of SQLite3 database file.
|
"""File name of SQLite3 database file.
|
||||||
|
@ -271,9 +334,10 @@ class Fail2BanDb(object):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Attempt to travel to future version of database ...how did you get here??")
|
"Attempt to travel to future version of database ...how did you get here??")
|
||||||
|
|
||||||
self._dbBackupFilename = self.filename + '.' + time.strftime('%Y%m%d-%H%M%S', MyTime.gmtime())
|
logSys.info("Uprade database: %s", self._dbBackupFilename)
|
||||||
shutil.copyfile(self.filename, self._dbBackupFilename)
|
if not os.path.isfile(self._dbBackupFilename):
|
||||||
logSys.info("Database backup created: %s", self._dbBackupFilename)
|
shutil.copyfile(self.filename, self._dbBackupFilename)
|
||||||
|
logSys.info(" Database backup created: %s", self._dbBackupFilename)
|
||||||
|
|
||||||
if version < 2:
|
if version < 2:
|
||||||
cur.executescript("BEGIN TRANSACTION;"
|
cur.executescript("BEGIN TRANSACTION;"
|
||||||
|
|
|
@ -62,7 +62,18 @@ class DatabaseTest(LogCaptureTestCase):
|
||||||
self.dbFilename = None
|
self.dbFilename = None
|
||||||
if not unittest.F2B.memory_db:
|
if not unittest.F2B.memory_db:
|
||||||
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
|
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
|
||||||
self.db = getFail2BanDb(self.dbFilename)
|
self._db = ':auto-create-in-memory:'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db(self):
|
||||||
|
if isinstance(self._db, basestring) and self._db == ':auto-create-in-memory:':
|
||||||
|
self._db = getFail2BanDb(self.dbFilename)
|
||||||
|
return self._db
|
||||||
|
@db.setter
|
||||||
|
def db(self, value):
|
||||||
|
if isinstance(self._db, Fail2BanDb): # pragma: no cover
|
||||||
|
self._db.close()
|
||||||
|
self._db = value
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Call after every test case."""
|
"""Call after every test case."""
|
||||||
|
@ -106,23 +117,61 @@ class DatabaseTest(LogCaptureTestCase):
|
||||||
self.jail.name in self.db.getJailNames(),
|
self.jail.name in self.db.getJailNames(),
|
||||||
"Jail not retained in Db after disconnect reconnect.")
|
"Jail not retained in Db after disconnect reconnect.")
|
||||||
|
|
||||||
def testUpdateDb(self):
|
def testRepairDb(self):
|
||||||
if Fail2BanDb is None: # pragma: no cover
|
if Fail2BanDb is None: # pragma: no cover
|
||||||
return
|
return
|
||||||
self.db = None
|
self.db = None
|
||||||
if self.dbFilename is None: # pragma: no cover
|
if self.dbFilename is None: # pragma: no cover
|
||||||
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
|
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
|
||||||
shutil.copyfile(
|
# test truncated database with different sizes:
|
||||||
os.path.join(TEST_FILES_DIR, 'database_v1.db'), self.dbFilename)
|
# - 14000 bytes - seems to be reparable,
|
||||||
self.db = Fail2BanDb(self.dbFilename)
|
# - 4000 bytes - is totally broken.
|
||||||
self.assertEqual(self.db.getJailNames(), set(['DummyJail #29162448 with 0 tickets']))
|
for truncSize in (14000, 4000):
|
||||||
self.assertEqual(self.db.getLogPaths(), set(['/tmp/Fail2BanDb_pUlZJh.log']))
|
self.pruneLog("[test-repair], next phase - file-size: %d" % truncSize)
|
||||||
ticket = FailTicket("127.0.0.1", 1388009242.26, [u"abc\n"])
|
shutil.copyfile(
|
||||||
self.assertEqual(self.db.getBans()[0], ticket)
|
os.path.join(TEST_FILES_DIR, 'database_v1.db'), self.dbFilename)
|
||||||
|
# produce currupt database:
|
||||||
|
f = os.open(self.dbFilename, os.O_RDWR)
|
||||||
|
os.ftruncate(f, truncSize)
|
||||||
|
os.close(f)
|
||||||
|
# test repair:
|
||||||
|
try:
|
||||||
|
self.db = Fail2BanDb(self.dbFilename)
|
||||||
|
if truncSize == 14000: # restored:
|
||||||
|
self.assertLogged("Repair seems to be successful",
|
||||||
|
"Check integrity", "Database updated", all=True)
|
||||||
|
self.assertEqual(self.db.getLogPaths(), set(['/tmp/Fail2BanDb_pUlZJh.log']))
|
||||||
|
self.assertEqual(len(self.db.getJailNames()), 1)
|
||||||
|
else: # recreated:
|
||||||
|
self.assertLogged("Repair seems to be failed",
|
||||||
|
"New database created.", all=True)
|
||||||
|
self.assertEqual(len(self.db.getLogPaths()), 0)
|
||||||
|
self.assertEqual(len(self.db.getJailNames()), 0)
|
||||||
|
finally:
|
||||||
|
if self.db and self.db._dbFilename != ":memory:":
|
||||||
|
os.remove(self.db._dbBackupFilename)
|
||||||
|
self.db = None
|
||||||
|
|
||||||
self.assertEqual(self.db.updateDb(Fail2BanDb.__version__), Fail2BanDb.__version__)
|
def testUpdateDb(self):
|
||||||
self.assertRaises(NotImplementedError, self.db.updateDb, Fail2BanDb.__version__ + 1)
|
if Fail2BanDb is None: # pragma: no cover
|
||||||
os.remove(self.db._dbBackupFilename)
|
return
|
||||||
|
self.db = None
|
||||||
|
try:
|
||||||
|
if self.dbFilename is None: # pragma: no cover
|
||||||
|
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
|
||||||
|
shutil.copyfile(
|
||||||
|
os.path.join(TEST_FILES_DIR, 'database_v1.db'), self.dbFilename)
|
||||||
|
self.db = Fail2BanDb(self.dbFilename)
|
||||||
|
self.assertEqual(self.db.getJailNames(), set(['DummyJail #29162448 with 0 tickets']))
|
||||||
|
self.assertEqual(self.db.getLogPaths(), set(['/tmp/Fail2BanDb_pUlZJh.log']))
|
||||||
|
ticket = FailTicket("127.0.0.1", 1388009242.26, [u"abc\n"])
|
||||||
|
self.assertEqual(self.db.getBans()[0], ticket)
|
||||||
|
|
||||||
|
self.assertEqual(self.db.updateDb(Fail2BanDb.__version__), Fail2BanDb.__version__)
|
||||||
|
self.assertRaises(NotImplementedError, self.db.updateDb, Fail2BanDb.__version__ + 1)
|
||||||
|
finally:
|
||||||
|
if self.db and self.db._dbFilename != ":memory:":
|
||||||
|
os.remove(self.db._dbBackupFilename)
|
||||||
|
|
||||||
def testAddJail(self):
|
def testAddJail(self):
|
||||||
if Fail2BanDb is None: # pragma: no cover
|
if Fail2BanDb is None: # pragma: no cover
|
||||||
|
|
Loading…
Reference in New Issue