|
|
|
@ -23,16 +23,20 @@ __copyright__ = "Copyright (c) 2013 Steven Hiscocks"
|
|
|
|
|
__license__ = "GPL" |
|
|
|
|
|
|
|
|
|
import os |
|
|
|
|
import sys |
|
|
|
|
import unittest |
|
|
|
|
import tempfile |
|
|
|
|
import sqlite3 |
|
|
|
|
import shutil |
|
|
|
|
|
|
|
|
|
from ..server.database import Fail2BanDb |
|
|
|
|
from ..server.filter import FileContainer |
|
|
|
|
from ..server.mytime import MyTime |
|
|
|
|
from ..server.ticket import FailTicket |
|
|
|
|
from .dummyjail import DummyJail |
|
|
|
|
try: |
|
|
|
|
from ..server.database import Fail2BanDb |
|
|
|
|
except ImportError: |
|
|
|
|
Fail2BanDb = None |
|
|
|
|
|
|
|
|
|
TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "files") |
|
|
|
|
|
|
|
|
@ -40,24 +44,38 @@ class DatabaseTest(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
|
"""Call before every test case.""" |
|
|
|
|
if Fail2BanDb is None and sys.version_info >= (2,7): # pragma: no cover |
|
|
|
|
raise unittest.SkipTest( |
|
|
|
|
"Unable to import fail2ban database module as sqlite is not " |
|
|
|
|
"available.") |
|
|
|
|
elif Fail2BanDb is None: |
|
|
|
|
return |
|
|
|
|
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_") |
|
|
|
|
self.db = Fail2BanDb(self.dbFilename) |
|
|
|
|
|
|
|
|
|
def tearDown(self): |
|
|
|
|
"""Call after every test case.""" |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
# Cleanup |
|
|
|
|
os.remove(self.dbFilename) |
|
|
|
|
|
|
|
|
|
def testGetFilename(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.assertEqual(self.dbFilename, self.db.filename) |
|
|
|
|
|
|
|
|
|
def testCreateInvalidPath(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.assertRaises( |
|
|
|
|
sqlite3.OperationalError, |
|
|
|
|
Fail2BanDb, |
|
|
|
|
"/this/path/should/not/exist") |
|
|
|
|
|
|
|
|
|
def testCreateAndReconnect(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.testAddJail() |
|
|
|
|
# Reconnect... |
|
|
|
|
self.db = Fail2BanDb(self.dbFilename) |
|
|
|
@ -67,6 +85,8 @@ class DatabaseTest(unittest.TestCase):
|
|
|
|
|
"Jail not retained in Db after disconnect reconnect.") |
|
|
|
|
|
|
|
|
|
def testUpdateDb(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
shutil.copyfile( |
|
|
|
|
os.path.join(TEST_FILES_DIR, 'database_v1.db'), self.dbFilename) |
|
|
|
|
self.db = Fail2BanDb(self.dbFilename) |
|
|
|
@ -80,6 +100,8 @@ class DatabaseTest(unittest.TestCase):
|
|
|
|
|
os.remove(self.db._dbBackupFilename) |
|
|
|
|
|
|
|
|
|
def testAddJail(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.jail = DummyJail() |
|
|
|
|
self.db.addJail(self.jail) |
|
|
|
|
self.assertTrue( |
|
|
|
@ -87,6 +109,8 @@ class DatabaseTest(unittest.TestCase):
|
|
|
|
|
"Jail not added to database") |
|
|
|
|
|
|
|
|
|
def testAddLog(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.testAddJail() # Jail required |
|
|
|
|
|
|
|
|
|
_, filename = tempfile.mkstemp(".log", "Fail2BanDb_") |
|
|
|
@ -98,6 +122,8 @@ class DatabaseTest(unittest.TestCase):
|
|
|
|
|
os.remove(filename) |
|
|
|
|
|
|
|
|
|
def testUpdateLog(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.testAddLog() # Add log file |
|
|
|
|
|
|
|
|
|
# Write some text |
|
|
|
@ -137,6 +163,8 @@ class DatabaseTest(unittest.TestCase):
|
|
|
|
|
os.remove(filename) |
|
|
|
|
|
|
|
|
|
def testAddBan(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.testAddJail() |
|
|
|
|
ticket = FailTicket("127.0.0.1", 0, ["abc\n"]) |
|
|
|
|
self.db.addBan(self.jail, ticket) |
|
|
|
@ -146,6 +174,8 @@ class DatabaseTest(unittest.TestCase):
|
|
|
|
|
isinstance(self.db.getBans(jail=self.jail)[0], FailTicket)) |
|
|
|
|
|
|
|
|
|
def testGetBansWithTime(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.testAddJail() |
|
|
|
|
ticket = FailTicket("127.0.0.1", MyTime.time() - 40, ["abc\n"]) |
|
|
|
|
self.db.addBan(self.jail, ticket) |
|
|
|
@ -153,6 +183,8 @@ class DatabaseTest(unittest.TestCase):
|
|
|
|
|
self.assertEqual(len(self.db.getBans(jail=self.jail,bantime=20)), 0) |
|
|
|
|
|
|
|
|
|
def testGetBansMerged(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.testAddJail() |
|
|
|
|
|
|
|
|
|
jail2 = DummyJail() |
|
|
|
@ -197,6 +229,8 @@ class DatabaseTest(unittest.TestCase):
|
|
|
|
|
id(self.db.getBansMerged("127.0.0.1", jail=self.jail))) |
|
|
|
|
|
|
|
|
|
def testPurge(self): |
|
|
|
|
if Fail2BanDb is None: # pragma: no cover |
|
|
|
|
return |
|
|
|
|
self.testAddJail() # Add jail |
|
|
|
|
|
|
|
|
|
self.db.purge() # Jail enabled by default so shouldn't be purged |
|
|
|
|