ENH: Remove thread locks from Fail2BanDb

pull/480/head
Steven Hiscocks 2013-12-08 22:03:57 +00:00
parent 7f063b46f9
commit 174f9a243a
1 changed files with 15 additions and 18 deletions

View File

@ -23,7 +23,6 @@ __license__ = "GPL"
import logging import logging
import sys import sys
from threading import Lock
import sqlite3 import sqlite3
import json import json
import locale import locale
@ -47,19 +46,17 @@ else:
sqlite3.register_adapter(dict, json.dumps) sqlite3.register_adapter(dict, json.dumps)
sqlite3.register_converter("JSON", json.loads) sqlite3.register_converter("JSON", json.loads)
def lockandcommit(): def commitandrollback():
def wrap(f): def wrap(f):
def func(self, *args, **kw): def func(self, *args, **kw):
with self._lock: # Threading lock with self._db: # Auto commit and rollback on exception
with self._db: # Auto commit and rollback on exception return f(self, self._db.cursor(), *args, **kw)
return f(self, self._db.cursor(), *args, **kw)
return func return func
return wrap return wrap
class Fail2BanDb(object): class Fail2BanDb(object):
__version__ = 1 __version__ = 1
def __init__(self, filename, purgeAge=24*60*60): def __init__(self, filename, purgeAge=24*60*60):
self._lock = Lock()
try: try:
self._db = sqlite3.connect( self._db = sqlite3.connect(
filename, check_same_thread=False, filename, check_same_thread=False,
@ -100,7 +97,7 @@ class Fail2BanDb(object):
def setPurgeAge(self, value): def setPurgeAge(self, value):
self._purgeAge = int(value) self._purgeAge = int(value)
@lockandcommit() @commitandrollback()
def createDb(self, cur): def createDb(self, cur):
# Version info # Version info
cur.execute("CREATE TABLE fail2banDb(version INTEGER)") cur.execute("CREATE TABLE fail2banDb(version INTEGER)")
@ -146,14 +143,14 @@ class Fail2BanDb(object):
cur.execute("SELECT version FROM fail2banDb LIMIT 1") cur.execute("SELECT version FROM fail2banDb LIMIT 1")
return cur.fetchone()[0] return cur.fetchone()[0]
@lockandcommit() @commitandrollback()
def updateDb(self, cur, version): def updateDb(self, cur, version):
raise NotImplementedError( raise NotImplementedError(
"Only single version of database exists...how did you get here??") "Only single version of database exists...how did you get here??")
cur.execute("SELECT version FROM fail2banDb LIMIT 1") cur.execute("SELECT version FROM fail2banDb LIMIT 1")
return cur.fetchone()[0] return cur.fetchone()[0]
@lockandcommit() @commitandrollback()
def addJail(self, cur, jail): def addJail(self, cur, jail):
cur.execute( cur.execute(
"INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)", "INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)",
@ -162,23 +159,23 @@ class Fail2BanDb(object):
def delJail(self, jail): def delJail(self, jail):
return self.delJailName(jail.getName()) return self.delJailName(jail.getName())
@lockandcommit() @commitandrollback()
def delJailName(self, cur, name): def delJailName(self, cur, name):
# Will be deleted by purge as appropriate # Will be deleted by purge as appropriate
cur.execute( cur.execute(
"UPDATE jails SET enabled=0 WHERE name=?", (name, )) "UPDATE jails SET enabled=0 WHERE name=?", (name, ))
@lockandcommit() @commitandrollback()
def delAllJails(self, cur): def delAllJails(self, cur):
# Will be deleted by purge as appropriate # Will be deleted by purge as appropriate
cur.execute("UPDATE jails SET enabled=0") cur.execute("UPDATE jails SET enabled=0")
@lockandcommit() @commitandrollback()
def getJailNames(self, cur): def getJailNames(self, cur):
cur.execute("SELECT name FROM jails") cur.execute("SELECT name FROM jails")
return set(row[0] for row in cur.fetchmany()) return set(row[0] for row in cur.fetchmany())
@lockandcommit() @commitandrollback()
def addLog(self, cur, jail, container): def addLog(self, cur, jail, container):
lastLinePos = None lastLinePos = None
cur.execute( cur.execute(
@ -199,7 +196,7 @@ class Fail2BanDb(object):
lastLinePos = None lastLinePos = None
return lastLinePos return lastLinePos
@lockandcommit() @commitandrollback()
def getLogPaths(self, cur, jail=None): def getLogPaths(self, cur, jail=None):
query = "SELECT path FROM logs" query = "SELECT path FROM logs"
queryArgs = [] queryArgs = []
@ -209,7 +206,7 @@ class Fail2BanDb(object):
cur.execute(query, queryArgs) cur.execute(query, queryArgs)
return set(row[0] for row in cur.fetchmany()) return set(row[0] for row in cur.fetchmany())
@lockandcommit() @commitandrollback()
def updateLog(self, cur, *args, **kwargs): def updateLog(self, cur, *args, **kwargs):
self._updateLog(cur, *args, **kwargs) self._updateLog(cur, *args, **kwargs)
@ -220,7 +217,7 @@ class Fail2BanDb(object):
(container.getHash(), container.getPos(), (container.getHash(), container.getPos(),
jail.getName(), container.getFileName())) jail.getName(), container.getFileName()))
@lockandcommit() @commitandrollback()
def addBan(self, cur, jail, ticket): def addBan(self, cur, jail, ticket):
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
cur.execute( cur.execute(
@ -229,7 +226,7 @@ class Fail2BanDb(object):
{"matches": ticket.getMatches(), {"matches": ticket.getMatches(),
"failures": ticket.getAttempt()})) "failures": ticket.getAttempt()}))
@lockandcommit() @commitandrollback()
def getBans(self, cur, jail=None, bantime=None): def getBans(self, cur, jail=None, bantime=None):
query = "SELECT ip, timeofban, data FROM bans" query = "SELECT ip, timeofban, data FROM bans"
queryArgs = [] queryArgs = []
@ -248,7 +245,7 @@ class Fail2BanDb(object):
tickets[-1].setAttempt(data['failures']) tickets[-1].setAttempt(data['failures'])
return tickets return tickets
@lockandcommit() @commitandrollback()
def purge(self, cur): def purge(self, cur):
cur.execute( cur.execute(
"DELETE FROM bans WHERE timeofban < ?", "DELETE FROM bans WHERE timeofban < ?",