Browse Source

BF: Handle case when no sqlite library is available for the database

pull/659/head
Steven Hiscocks 11 years ago
parent
commit
1c65b94617
  1. 16
      fail2ban/server/server.py
  2. 36
      fail2ban/tests/databasetestcase.py

16
fail2ban/server/server.py

@ -31,12 +31,17 @@ from .jails import Jails
from .filter import FileFilter, JournalFilter
from .transmitter import Transmitter
from .asyncserver import AsyncServer, AsyncServerException
from .database import Fail2BanDb
from .. import version
# Gets the instance of the logger.
logSys = logging.getLogger(__name__)
try:
from .database import Fail2BanDb
except ImportError:
# Dont print error here, as database may not even be used
Fail2BanDb = None
class Server:
def __init__(self, daemon = False):
@ -439,8 +444,13 @@ class Server:
if filename.lower() == "none":
self.__db = None
else:
self.__db = Fail2BanDb(filename)
self.__db.delAllJails()
if Fail2BanDb is not None:
self.__db = Fail2BanDb(filename)
self.__db.delAllJails()
else:
logSys.error(
"Unable to import fail2ban database module as sqlite "
"is not available.")
else:
raise RuntimeError(
"Cannot change database when there are jails present")

36
fail2ban/tests/databasetestcase.py

@ -23,16 +23,20 @@ __copyright__ = "Copyright (c) 2013 Steven Hiscocks"
__license__ = "GPL"
import os
import sys
import unittest
import tempfile
import sqlite3
import shutil
from ..server.database import Fail2BanDb
from ..server.filter import FileContainer
from ..server.mytime import MyTime
from ..server.ticket import FailTicket
from .dummyjail import DummyJail
try:
from ..server.database import Fail2BanDb
except ImportError:
Fail2BanDb = None
TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files")
@ -40,24 +44,38 @@ class DatabaseTest(unittest.TestCase):
def setUp(self):
"""Call before every test case."""
if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover
raise unittest.SkipTest(
"Unable to import fail2ban database module as sqlite is not "
"available.")
elif Fail2BanDb is None:
return
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
self.db = Fail2BanDb(self.dbFilename)
def tearDown(self):
"""Call after every test case."""
if Fail2BanDb is None: # pragma: no cover
return
# Cleanup
os.remove(self.dbFilename)
def testGetFilename(self):
if Fail2BanDb is None: # pragma: no cover
return
self.assertEqual(self.dbFilename, self.db.filename)
def testCreateInvalidPath(self):
if Fail2BanDb is None: # pragma: no cover
return
self.assertRaises(
sqlite3.OperationalError,
Fail2BanDb,
"/this/path/should/not/exist")
def testCreateAndReconnect(self):
if Fail2BanDb is None: # pragma: no cover
return
self.testAddJail()
# Reconnect...
self.db = Fail2BanDb(self.dbFilename)
@ -67,6 +85,8 @@ class DatabaseTest(unittest.TestCase):
"Jail not retained in Db after disconnect reconnect.")
def testUpdateDb(self):
if Fail2BanDb is None: # pragma: no cover
return
shutil.copyfile(
os.path.join(TEST_FILES_DIR, 'database_v1.db'), self.dbFilename)
self.db = Fail2BanDb(self.dbFilename)
@ -80,6 +100,8 @@ class DatabaseTest(unittest.TestCase):
os.remove(self.db._dbBackupFilename)
def testAddJail(self):
if Fail2BanDb is None: # pragma: no cover
return
self.jail = DummyJail()
self.db.addJail(self.jail)
self.assertTrue(
@ -87,6 +109,8 @@ class DatabaseTest(unittest.TestCase):
"Jail not added to database")
def testAddLog(self):
if Fail2BanDb is None: # pragma: no cover
return
self.testAddJail() # Jail required
_, filename = tempfile.mkstemp(".log", "Fail2BanDb_")
@ -98,6 +122,8 @@ class DatabaseTest(unittest.TestCase):
os.remove(filename)
def testUpdateLog(self):
if Fail2BanDb is None: # pragma: no cover
return
self.testAddLog() # Add log file
# Write some text
@ -137,6 +163,8 @@ class DatabaseTest(unittest.TestCase):
os.remove(filename)
def testAddBan(self):
if Fail2BanDb is None: # pragma: no cover
return
self.testAddJail()
ticket = FailTicket("127.0.0.1", 0, ["abc\n"])
self.db.addBan(self.jail, ticket)
@ -146,6 +174,8 @@ class DatabaseTest(unittest.TestCase):
isinstance(self.db.getBans(jail=self.jail)[0], FailTicket))
def testGetBansWithTime(self):
if Fail2BanDb is None: # pragma: no cover
return
self.testAddJail()
ticket = FailTicket("127.0.0.1", MyTime.time() - 40, ["abc\n"])
self.db.addBan(self.jail, ticket)
@ -153,6 +183,8 @@ class DatabaseTest(unittest.TestCase):
self.assertEqual(len(self.db.getBans(jail=self.jail,bantime=20)), 0)
def testGetBansMerged(self):
if Fail2BanDb is None: # pragma: no cover
return
self.testAddJail()
jail2 = DummyJail()
@ -197,6 +229,8 @@ class DatabaseTest(unittest.TestCase):
id(self.db.getBansMerged("127.0.0.1", jail=self.jail)))
def testPurge(self):
if Fail2BanDb is None: # pragma: no cover
return
self.testAddJail() # Add jail
self.db.purge() # Jail enabled by default so shouldn't be purged

Loading…
Cancel
Save