ENH: Make use of functools.wraps for server.database decorators

pull/491/head
Steven Hiscocks 2013-12-15 21:10:11 +00:00
parent 0bcff771b8
commit d6cbc05e35
1 changed files with 19 additions and 19 deletions

View File

@ -26,6 +26,7 @@ import sys
import sqlite3 import sqlite3
import json import json
import locale import locale
from functools import wraps
from fail2ban.server.mytime import MyTime from fail2ban.server.mytime import MyTime
from fail2ban.server.ticket import FailTicket from fail2ban.server.ticket import FailTicket
@ -46,13 +47,12 @@ else:
sqlite3.register_adapter(dict, json.dumps) sqlite3.register_adapter(dict, json.dumps)
sqlite3.register_converter("JSON", json.loads) sqlite3.register_converter("JSON", json.loads)
def commitandrollback(): def commitandrollback(f):
def wrap(f): @wraps(f)
def func(self, *args, **kw): def wrapper(self, *args, **kwargs):
with self._db: # Auto commit and rollback on exception with self._db: # Auto commit and rollback on exception
return f(self, self._db.cursor(), *args, **kw) return f(self, self._db.cursor(), *args, **kwargs)
return func return wrapper
return wrap
class Fail2BanDb(object): class Fail2BanDb(object):
__version__ = 1 __version__ = 1
@ -97,7 +97,7 @@ class Fail2BanDb(object):
def setPurgeAge(self, value): def setPurgeAge(self, value):
self._purgeAge = int(value) self._purgeAge = int(value)
@commitandrollback() @commitandrollback
def createDb(self, cur): def createDb(self, cur):
# Version info # Version info
cur.execute("CREATE TABLE fail2banDb(version INTEGER)") cur.execute("CREATE TABLE fail2banDb(version INTEGER)")
@ -143,14 +143,14 @@ class Fail2BanDb(object):
cur.execute("SELECT version FROM fail2banDb LIMIT 1") cur.execute("SELECT version FROM fail2banDb LIMIT 1")
return cur.fetchone()[0] return cur.fetchone()[0]
@commitandrollback() @commitandrollback
def updateDb(self, cur, version): def updateDb(self, cur, version):
raise NotImplementedError( raise NotImplementedError(
"Only single version of database exists...how did you get here??") "Only single version of database exists...how did you get here??")
cur.execute("SELECT version FROM fail2banDb LIMIT 1") cur.execute("SELECT version FROM fail2banDb LIMIT 1")
return cur.fetchone()[0] return cur.fetchone()[0]
@commitandrollback() @commitandrollback
def addJail(self, cur, jail): def addJail(self, cur, jail):
cur.execute( cur.execute(
"INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)", "INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)",
@ -159,23 +159,23 @@ class Fail2BanDb(object):
def delJail(self, jail): def delJail(self, jail):
return self.delJailName(jail.getName()) return self.delJailName(jail.getName())
@commitandrollback() @commitandrollback
def delJailName(self, cur, name): def delJailName(self, cur, name):
# Will be deleted by purge as appropriate # Will be deleted by purge as appropriate
cur.execute( cur.execute(
"UPDATE jails SET enabled=0 WHERE name=?", (name, )) "UPDATE jails SET enabled=0 WHERE name=?", (name, ))
@commitandrollback() @commitandrollback
def delAllJails(self, cur): def delAllJails(self, cur):
# Will be deleted by purge as appropriate # Will be deleted by purge as appropriate
cur.execute("UPDATE jails SET enabled=0") cur.execute("UPDATE jails SET enabled=0")
@commitandrollback() @commitandrollback
def getJailNames(self, cur): def getJailNames(self, cur):
cur.execute("SELECT name FROM jails") cur.execute("SELECT name FROM jails")
return set(row[0] for row in cur.fetchmany()) return set(row[0] for row in cur.fetchmany())
@commitandrollback() @commitandrollback
def addLog(self, cur, jail, container): def addLog(self, cur, jail, container):
lastLinePos = None lastLinePos = None
cur.execute( cur.execute(
@ -196,7 +196,7 @@ class Fail2BanDb(object):
lastLinePos = None lastLinePos = None
return lastLinePos return lastLinePos
@commitandrollback() @commitandrollback
def getLogPaths(self, cur, jail=None): def getLogPaths(self, cur, jail=None):
query = "SELECT path FROM logs" query = "SELECT path FROM logs"
queryArgs = [] queryArgs = []
@ -206,7 +206,7 @@ class Fail2BanDb(object):
cur.execute(query, queryArgs) cur.execute(query, queryArgs)
return set(row[0] for row in cur.fetchmany()) return set(row[0] for row in cur.fetchmany())
@commitandrollback() @commitandrollback
def updateLog(self, cur, *args, **kwargs): def updateLog(self, cur, *args, **kwargs):
self._updateLog(cur, *args, **kwargs) self._updateLog(cur, *args, **kwargs)
@ -217,7 +217,7 @@ class Fail2BanDb(object):
(container.getHash(), container.getPos(), (container.getHash(), container.getPos(),
jail.getName(), container.getFileName())) jail.getName(), container.getFileName()))
@commitandrollback() @commitandrollback
def addBan(self, cur, jail, ticket): def addBan(self, cur, jail, ticket):
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
cur.execute( cur.execute(
@ -226,7 +226,7 @@ class Fail2BanDb(object):
{"matches": ticket.getMatches(), {"matches": ticket.getMatches(),
"failures": ticket.getAttempt()})) "failures": ticket.getAttempt()}))
@commitandrollback() @commitandrollback
def _getBans(self, cur, jail=None, bantime=None, ip=None): def _getBans(self, cur, jail=None, bantime=None, ip=None):
query = "SELECT ip, timeofban, data FROM bans WHERE 1" query = "SELECT ip, timeofban, data FROM bans WHERE 1"
queryArgs = [] queryArgs = []
@ -263,7 +263,7 @@ class Fail2BanDb(object):
ticket.setAttempt(failures) ticket.setAttempt(failures)
return ticket return ticket
@commitandrollback() @commitandrollback
def purge(self, cur): def purge(self, cur):
cur.execute( cur.execute(
"DELETE FROM bans WHERE timeofban < ?", "DELETE FROM bans WHERE timeofban < ?",