ENH+DOC: Update Fail2Ban database doc strings and use properties

pull/628/head
Steven Hiscocks 2014-02-23 18:33:58 +00:00
parent df8d700d17
commit edd0bf7d46
4 changed files with 169 additions and 21 deletions

View File

@ -58,6 +58,11 @@ def commitandrollback(f):
return wrapper return wrapper
class Fail2BanDb(object): class Fail2BanDb(object):
"""Fail2Ban database for storing persistent data.
This allows after Fail2Ban is restarted to reinstated bans and
to continue monitoring logs from the same point.
"""
__version__ = 2 __version__ = 2
# Note all _TABLE_* strings must end in ';' for py26 compatibility # Note all _TABLE_* strings must end in ';' for py26 compatibility
_TABLE_fail2banDb = "CREATE TABLE fail2banDb(version INTEGER);" _TABLE_fail2banDb = "CREATE TABLE fail2banDb(version INTEGER);"
@ -93,6 +98,27 @@ class Fail2BanDb(object):
"CREATE INDEX bans_ip ON bans(ip);" \ "CREATE INDEX bans_ip ON bans(ip);" \
def __init__(self, filename, purgeAge=24*60*60): def __init__(self, filename, purgeAge=24*60*60):
"""Initialise the database by connecting/creating SQLite3 file.
This will either create a new Fail2Ban database, connect to an
existing, and if applicable upgrade the schema in the process.
Parameters
----------
filename : str
File name for SQLite3 database, which will be created if
doesn't already exist.
purgeAge : int
Purge age in seconds, used to remove old bans from
database during purge.
Raises
------
sqlite3.OperationalError
Error connecting/creating a SQLite3 database.
RuntimeError
If exisiting database fails to update to new schema.
"""
try: try:
self._lock = Lock() self._lock = Lock()
self._db = sqlite3.connect( self._db = sqlite3.connect(
@ -130,21 +156,30 @@ class Fail2BanDb(object):
logSys.error( "Database update failed to acheive version '%i'" logSys.error( "Database update failed to acheive version '%i'"
": updated from '%i' to '%i'", ": updated from '%i' to '%i'",
Fail2BanDb.__version__, version, newversion) Fail2BanDb.__version__, version, newversion)
raise Exception('Failed to fully update') raise RuntimeError('Failed to fully update')
finally: finally:
cur.close() cur.close()
def getFilename(self): @property
def filename(self):
"""File name of SQLite3 database file.
"""
return self._dbFilename return self._dbFilename
def getPurgeAge(self): @property
def purgeage(self):
"""Purge age in seconds.
"""
return self._purgeAge return self._purgeAge
def setPurgeAge(self, value): @purgeage.setter
def purgeage(self, value):
self._purgeAge = int(value) self._purgeAge = int(value)
@commitandrollback @commitandrollback
def createDb(self, cur): def createDb(self, cur):
"""Creates a new database, called during initialisation.
"""
# Version info # Version info
cur.executescript(Fail2BanDb._TABLE_fail2banDb) cur.executescript(Fail2BanDb._TABLE_fail2banDb)
cur.execute("INSERT INTO fail2banDb(version) VALUES(?)", cur.execute("INSERT INTO fail2banDb(version) VALUES(?)",
@ -161,8 +196,13 @@ class Fail2BanDb(object):
@commitandrollback @commitandrollback
def updateDb(self, cur, version): def updateDb(self, cur, version):
self.dbBackupFilename = self._dbFilename + '.' + time.strftime('%Y%m%d-%H%M%S', MyTime.gmtime()) """Update an existing database, called during initialisation.
shutil.copyfile(self._dbFilename, self.dbBackupFilename)
A timestamped backup is also created prior to attempting the update.
"""
self._dbBackupFilename = self.filename + '.' + time.strftime('%Y%m%d-%H%M%S', MyTime.gmtime())
shutil.copyfile(self.filename, self._dbBackupFilename)
logSys.info("Database backup created: %s", self._dbBackupFilename)
if version > Fail2BanDb.__version__: if version > Fail2BanDb.__version__:
raise NotImplementedError( raise NotImplementedError(
"Attempt to travel to future version of database ...how did you get here??") "Attempt to travel to future version of database ...how did you get here??")
@ -182,31 +222,68 @@ class Fail2BanDb(object):
@commitandrollback @commitandrollback
def addJail(self, cur, jail): def addJail(self, cur, jail):
"""Adds a jail to the database.
Parameters
----------
jail : Jail
Jail to be added to the database.
"""
cur.execute( cur.execute(
"INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)", "INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)",
(jail.name,)) (jail.name,))
def delJail(self, jail):
return self.delJailName(jail.name)
@commitandrollback @commitandrollback
def delJailName(self, cur, name): def delJail(self, cur, jail):
"""Deletes a jail from the database.
Parameters
----------
jail : Jail
Jail to be removed from the database.
"""
# Will be deleted by purge as appropriate # Will be deleted by purge as appropriate
cur.execute( cur.execute(
"UPDATE jails SET enabled=0 WHERE name=?", (name, )) "UPDATE jails SET enabled=0 WHERE name=?", (jail.name, ))
@commitandrollback @commitandrollback
def delAllJails(self, cur): def delAllJails(self, cur):
"""Deletes all jails from the database.
"""
# Will be deleted by purge as appropriate # Will be deleted by purge as appropriate
cur.execute("UPDATE jails SET enabled=0") cur.execute("UPDATE jails SET enabled=0")
@commitandrollback @commitandrollback
def getJailNames(self, cur): def getJailNames(self, cur):
"""Get name of jails in database.
Currently only used for testing purposes.
Returns
-------
set
Set of jail names.
"""
cur.execute("SELECT name FROM jails") cur.execute("SELECT name FROM jails")
return set(row[0] for row in cur.fetchmany()) return set(row[0] for row in cur.fetchmany())
@commitandrollback @commitandrollback
def addLog(self, cur, jail, container): def addLog(self, cur, jail, container):
"""Adds a log to the database.
Parameters
----------
jail : Jail
Jail that log is being monitored by.
container : FileContainer
File container of the log file being added.
Returns
-------
int
If log was already present in database, value of last position
in the log file; else `None`
"""
lastLinePos = None lastLinePos = None
cur.execute( cur.execute(
"SELECT firstlinemd5, lastfilepos FROM logs " "SELECT firstlinemd5, lastfilepos FROM logs "
@ -228,6 +305,20 @@ class Fail2BanDb(object):
@commitandrollback @commitandrollback
def getLogPaths(self, cur, jail=None): def getLogPaths(self, cur, jail=None):
"""Gets all the log paths from the database.
Currently only for testing purposes.
Parameters
----------
jail : Jail
If specified, will only reutrn logs belonging to the jail.
Returns
-------
set
Set of log paths.
"""
query = "SELECT path FROM logs" query = "SELECT path FROM logs"
queryArgs = [] queryArgs = []
if jail is not None: if jail is not None:
@ -238,6 +329,15 @@ class Fail2BanDb(object):
@commitandrollback @commitandrollback
def updateLog(self, cur, *args, **kwargs): def updateLog(self, cur, *args, **kwargs):
"""Updates hash and last position in log file.
Parameters
----------
jail : Jail
Jail of which the log file belongs to.
container : FileContainer
File container of the log file being updated.
"""
self._updateLog(cur, *args, **kwargs) self._updateLog(cur, *args, **kwargs)
def _updateLog(self, cur, jail, container): def _updateLog(self, cur, jail, container):
@ -249,6 +349,15 @@ class Fail2BanDb(object):
@commitandrollback @commitandrollback
def addBan(self, cur, jail, ticket): def addBan(self, cur, jail, ticket):
"""Add a ban to the database.
Parameters
----------
jail : Jail
Jail in which the ban has occured.
ticket : BanTicket
Ticket of the ban to be added.
"""
self._bansMergedCache = {} self._bansMergedCache = {}
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
cur.execute( cur.execute(
@ -276,6 +385,23 @@ class Fail2BanDb(object):
return cur.execute(query, queryArgs) return cur.execute(query, queryArgs)
def getBans(self, **kwargs): def getBans(self, **kwargs):
"""Get bans from the database.
Parameters
----------
jail : Jail
Jail that the ban belongs to. Default `None`; all jails.
bantime : int
Ban time in seconds, such that bans returned would still be
valid now. Default `None`; no limit.
ip : str
IP Address to filter bans by. Default `None`; all IPs.
Returns
-------
list
List of `Ticket`s for bans stored in database.
"""
tickets = [] tickets = []
for ip, timeofban, data in self._getBans(**kwargs): for ip, timeofban, data in self._getBans(**kwargs):
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
@ -284,6 +410,26 @@ class Fail2BanDb(object):
return tickets return tickets
def getBansMerged(self, ip, jail=None, **kwargs): def getBansMerged(self, ip, jail=None, **kwargs):
"""Get bans from the database, merged into single ticket.
This is the same as `getBans`, but bans merged into single
ticket.
Parameters
----------
jail : Jail
Jail that the ban belongs to. Default `None`; all jails.
bantime : int
Ban time in seconds, such that bans returned would still be
valid now. Default `None`; no limit.
ip : str
IP Address to filter bans by. Default `None`; all IPs.
Returns
-------
Ticket
Single ticket representing bans stored in database.
"""
cacheKey = ip if jail is None else "%s|%s" % (ip, jail.name) cacheKey = ip if jail is None else "%s|%s" % (ip, jail.name)
if cacheKey in self._bansMergedCache: if cacheKey in self._bansMergedCache:
return self._bansMergedCache[cacheKey] return self._bansMergedCache[cacheKey]
@ -300,6 +446,8 @@ class Fail2BanDb(object):
@commitandrollback @commitandrollback
def purge(self, cur): def purge(self, cur):
"""Purge old bans, jails and log files from database.
"""
self._bansMergedCache = {} self._bansMergedCache = {}
cur.execute( cur.execute(
"DELETE FROM bans WHERE timeofban < ?", "DELETE FROM bans WHERE timeofban < ?",

View File

@ -125,10 +125,10 @@ class Server:
self.__db.addJail(self.__jails[name]) self.__db.addJail(self.__jails[name])
def delJail(self, name): def delJail(self, name):
del self.__jails[name]
if self.__db is not None: if self.__db is not None:
self.__db.delJailName(name) self.__db.delJail(self.__jails[name])
del self.__jails[name]
def startJail(self, name): def startJail(self, name):
try: try:
self.__lock.acquire() self.__lock.acquire()

View File

@ -123,14 +123,14 @@ class Transmitter:
if db is None: if db is None:
return None return None
else: else:
return db.getFilename() return db.filename
elif name == "dbpurgeage": elif name == "dbpurgeage":
db = self.__server.getDatabase() db = self.__server.getDatabase()
if db is None: if db is None:
return None return None
else: else:
db.setPurgeAge(command[1]) db.purgeage = command[1]
return db.getPurgeAge() return db.purgeage
# Jail # Jail
elif command[1] == "idle": elif command[1] == "idle":
if command[2] == "on": if command[2] == "on":
@ -265,13 +265,13 @@ class Transmitter:
if db is None: if db is None:
return None return None
else: else:
return db.getFilename() return db.filename
elif name == "dbpurgeage": elif name == "dbpurgeage":
db = self.__server.getDatabase() db = self.__server.getDatabase()
if db is None: if db is None:
return None return None
else: else:
return db.getPurgeAge() return db.purgeage
# Filter # Filter
elif command[1] == "logpath": elif command[1] == "logpath":
return self.__server.getLogPath(name) return self.__server.getLogPath(name)

View File

@ -47,7 +47,7 @@ class DatabaseTest(unittest.TestCase):
os.remove(self.dbFilename) os.remove(self.dbFilename)
def testGetFilename(self): def testGetFilename(self):
self.assertEqual(self.dbFilename, self.db.getFilename()) self.assertEqual(self.dbFilename, self.db.filename)
def testCreateInvalidPath(self): def testCreateInvalidPath(self):
self.assertRaises( self.assertRaises(
@ -74,7 +74,7 @@ class DatabaseTest(unittest.TestCase):
self.assertEqual(self.db.updateDb(Fail2BanDb.__version__), Fail2BanDb.__version__) self.assertEqual(self.db.updateDb(Fail2BanDb.__version__), Fail2BanDb.__version__)
self.assertRaises(NotImplementedError, self.db.updateDb, Fail2BanDb.__version__ + 1) self.assertRaises(NotImplementedError, self.db.updateDb, Fail2BanDb.__version__ + 1)
os.remove(self.db.dbBackupFilename) os.remove(self.db._dbBackupFilename)
def testAddJail(self): def testAddJail(self):
self.jail = DummyJail() self.jail = DummyJail()