mirror of https://github.com/fail2ban/fail2ban
ENH: Make use of functools.wraps for server.database decorators
parent
0bcff771b8
commit
d6cbc05e35
|
@ -26,6 +26,7 @@ import sys
|
|||
import sqlite3
|
||||
import json
|
||||
import locale
|
||||
from functools import wraps
|
||||
|
||||
from fail2ban.server.mytime import MyTime
|
||||
from fail2ban.server.ticket import FailTicket
|
||||
|
@ -46,13 +47,12 @@ else:
|
|||
sqlite3.register_adapter(dict, json.dumps)
|
||||
sqlite3.register_converter("JSON", json.loads)
|
||||
|
||||
def commitandrollback():
|
||||
def wrap(f):
|
||||
def func(self, *args, **kw):
|
||||
with self._db: # Auto commit and rollback on exception
|
||||
return f(self, self._db.cursor(), *args, **kw)
|
||||
return func
|
||||
return wrap
|
||||
def commitandrollback(f):
|
||||
@wraps(f)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with self._db: # Auto commit and rollback on exception
|
||||
return f(self, self._db.cursor(), *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
class Fail2BanDb(object):
|
||||
__version__ = 1
|
||||
|
@ -97,7 +97,7 @@ class Fail2BanDb(object):
|
|||
def setPurgeAge(self, value):
|
||||
self._purgeAge = int(value)
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def createDb(self, cur):
|
||||
# Version info
|
||||
cur.execute("CREATE TABLE fail2banDb(version INTEGER)")
|
||||
|
@ -143,14 +143,14 @@ class Fail2BanDb(object):
|
|||
cur.execute("SELECT version FROM fail2banDb LIMIT 1")
|
||||
return cur.fetchone()[0]
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def updateDb(self, cur, version):
|
||||
raise NotImplementedError(
|
||||
"Only single version of database exists...how did you get here??")
|
||||
cur.execute("SELECT version FROM fail2banDb LIMIT 1")
|
||||
return cur.fetchone()[0]
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def addJail(self, cur, jail):
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)",
|
||||
|
@ -159,23 +159,23 @@ class Fail2BanDb(object):
|
|||
def delJail(self, jail):
|
||||
return self.delJailName(jail.getName())
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def delJailName(self, cur, name):
|
||||
# Will be deleted by purge as appropriate
|
||||
cur.execute(
|
||||
"UPDATE jails SET enabled=0 WHERE name=?", (name, ))
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def delAllJails(self, cur):
|
||||
# Will be deleted by purge as appropriate
|
||||
cur.execute("UPDATE jails SET enabled=0")
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def getJailNames(self, cur):
|
||||
cur.execute("SELECT name FROM jails")
|
||||
return set(row[0] for row in cur.fetchmany())
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def addLog(self, cur, jail, container):
|
||||
lastLinePos = None
|
||||
cur.execute(
|
||||
|
@ -196,7 +196,7 @@ class Fail2BanDb(object):
|
|||
lastLinePos = None
|
||||
return lastLinePos
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def getLogPaths(self, cur, jail=None):
|
||||
query = "SELECT path FROM logs"
|
||||
queryArgs = []
|
||||
|
@ -206,7 +206,7 @@ class Fail2BanDb(object):
|
|||
cur.execute(query, queryArgs)
|
||||
return set(row[0] for row in cur.fetchmany())
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def updateLog(self, cur, *args, **kwargs):
|
||||
self._updateLog(cur, *args, **kwargs)
|
||||
|
||||
|
@ -217,7 +217,7 @@ class Fail2BanDb(object):
|
|||
(container.getHash(), container.getPos(),
|
||||
jail.getName(), container.getFileName()))
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def addBan(self, cur, jail, ticket):
|
||||
#TODO: Implement data parts once arbitrary match keys completed
|
||||
cur.execute(
|
||||
|
@ -226,7 +226,7 @@ class Fail2BanDb(object):
|
|||
{"matches": ticket.getMatches(),
|
||||
"failures": ticket.getAttempt()}))
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def _getBans(self, cur, jail=None, bantime=None, ip=None):
|
||||
query = "SELECT ip, timeofban, data FROM bans WHERE 1"
|
||||
queryArgs = []
|
||||
|
@ -263,7 +263,7 @@ class Fail2BanDb(object):
|
|||
ticket.setAttempt(failures)
|
||||
return ticket
|
||||
|
||||
@commitandrollback()
|
||||
@commitandrollback
|
||||
def purge(self, cur):
|
||||
cur.execute(
|
||||
"DELETE FROM bans WHERE timeofban < ?",
|
||||
|
|
Loading…
Reference in New Issue