diff --git a/client/actionreader.py b/client/actionreader.py index 581a1b3c..9ad1ef28 100644 --- a/client/actionreader.py +++ b/client/actionreader.py @@ -35,8 +35,8 @@ logSys = logging.getLogger("fail2ban.client.config") class ActionReader(ConfigReader): - def __init__(self, action, name): - ConfigReader.__init__(self) + def __init__(self, action, name, **kwargs): + ConfigReader.__init__(self, **kwargs) self.__file = action[0] self.__cInfo = action[1] self.__name = name diff --git a/client/configreader.py b/client/configreader.py index 063484e8..79c5b47b 100644 --- a/client/configreader.py +++ b/client/configreader.py @@ -27,7 +27,7 @@ __date__ = "$Date$" __copyright__ = "Copyright (c) 2004 Cyril Jaquier" __license__ = "GPL" -import logging, os +import glob, logging, os from configparserinc import SafeConfigParserWithIncludes from ConfigParser import NoOptionError, NoSectionError @@ -35,36 +35,64 @@ from ConfigParser import NoOptionError, NoSectionError logSys = logging.getLogger("fail2ban.client.config") class ConfigReader(SafeConfigParserWithIncludes): + + DEFAULT_BASEDIR = '/etc/fail2ban' - BASE_DIRECTORY = "/etc/fail2ban/" - - def __init__(self): + def __init__(self, basedir=None): SafeConfigParserWithIncludes.__init__(self) + self.setBaseDir(basedir) self.__opts = None - #@staticmethod - def setBaseDir(folderName): - path = folderName.rstrip('/') - ConfigReader.BASE_DIRECTORY = path + '/' - setBaseDir = staticmethod(setBaseDir) - - #@staticmethod - def getBaseDir(): - return ConfigReader.BASE_DIRECTORY - getBaseDir = staticmethod(getBaseDir) + def setBaseDir(self, basedir): + if basedir is None: + basedir = ConfigReader.DEFAULT_BASEDIR # stock system location + if not (os.path.exists(basedir) and os.access(basedir, os.R_OK | os.X_OK)): + raise ValueError("Base configuration directory %s either does not exist " + "or is not accessible" % basedir) + self._basedir = basedir.rstrip('/') + + def getBaseDir(self): + return self._basedir def read(self, filename): - basename = ConfigReader.BASE_DIRECTORY + filename + basename = os.path.join(self._basedir, filename) logSys.debug("Reading " + basename) - bConf = basename + ".conf" - bLocal = basename + ".local" - if os.path.exists(bConf) or os.path.exists(bLocal): - SafeConfigParserWithIncludes.read(self, [bConf, bLocal]) + config_files = [ basename + ".conf", + basename + ".local" ] + + # choose only existing ones + config_files = filter(os.path.exists, config_files) + + # possible further customizations under a .conf.d directory + config_dir = basename + '.d' + if os.path.exists(config_dir): + if os.path.isdir(config_dir) and os.access(config_dir, os.X_OK | os.R_OK): + # files must carry .conf suffix as well + config_files += sorted(glob.glob('%s/*.conf' % config_dir)) + else: + logSys.warn("%s exists but not a directory or not accessible" + % config_dir) + + # check if files are accessible, warn if any is not accessible + # and remove it from the list + config_files_accessible = [] + for f in config_files: + if os.access(f, os.R_OK): + config_files_accessible.append(f) + else: + logSys.warn("%s exists but not accessible - skipping" % f) + + if len(config_files_accessible): + # at least one config exists and accessible + SafeConfigParserWithIncludes.read(self, config_files_accessible) return True else: - logSys.error(bConf + " and " + bLocal + " do not exist") + logSys.error("Found no accessible config files for %r " % filename + + (["", + "among existing ones: " + ', '.join(config_files)][bool(len(config_files))])) + return False - + ## # Read the options. # diff --git a/client/fail2banreader.py b/client/fail2banreader.py index ee097bd6..7115ec79 100644 --- a/client/fail2banreader.py +++ b/client/fail2banreader.py @@ -35,8 +35,8 @@ logSys = logging.getLogger("fail2ban.client.config") class Fail2banReader(ConfigReader): - def __init__(self): - ConfigReader.__init__(self) + def __init__(self, **kwargs): + ConfigReader.__init__(self, **kwargs) def read(self): ConfigReader.read(self, "fail2ban") diff --git a/client/filterreader.py b/client/filterreader.py index b7a72f9c..7dba3579 100644 --- a/client/filterreader.py +++ b/client/filterreader.py @@ -35,8 +35,8 @@ logSys = logging.getLogger("fail2ban.client.config") class FilterReader(ConfigReader): - def __init__(self, fileName, name): - ConfigReader.__init__(self) + def __init__(self, fileName, name, **kwargs): + ConfigReader.__init__(self, **kwargs) self.__file = fileName self.__name = name diff --git a/client/jailreader.py b/client/jailreader.py index f66dc010..ec73ce46 100644 --- a/client/jailreader.py +++ b/client/jailreader.py @@ -40,8 +40,8 @@ class JailReader(ConfigReader): actionCRE = re.compile("^((?:\w|-|_|\.)+)(?:\[(.*)\])?$") - def __init__(self, name): - ConfigReader.__init__(self) + def __init__(self, name, **kwargs): + ConfigReader.__init__(self, **kwargs) self.__name = name self.__filter = None self.__actions = list() @@ -53,7 +53,7 @@ class JailReader(ConfigReader): return self.__name def read(self): - ConfigReader.read(self, "jail") + return ConfigReader.read(self, "jail") def isEnabled(self): return self.__opts["enabled"] @@ -75,7 +75,8 @@ class JailReader(ConfigReader): if self.isEnabled(): # Read filter - self.__filter = FilterReader(self.__opts["filter"], self.__name) + self.__filter = FilterReader(self.__opts["filter"], self.__name, + basedir=self.getBaseDir()) ret = self.__filter.read() if ret: self.__filter.getOptions(self.__opts) @@ -87,7 +88,7 @@ class JailReader(ConfigReader): for act in self.__opts["action"].split('\n'): try: splitAct = JailReader.splitAction(act) - action = ActionReader(splitAct, self.__name) + action = ActionReader(splitAct, self.__name, basedir=self.getBaseDir()) ret = action.read() if ret: action.getOptions(self.__opts) diff --git a/client/jailsreader.py b/client/jailsreader.py index bedc5a3c..e1b8efa3 100644 --- a/client/jailsreader.py +++ b/client/jailsreader.py @@ -36,12 +36,12 @@ logSys = logging.getLogger("fail2ban.client.config") class JailsReader(ConfigReader): - def __init__(self): - ConfigReader.__init__(self) + def __init__(self, **kwargs): + ConfigReader.__init__(self, **kwargs) self.__jails = list() def read(self): - ConfigReader.read(self, "jail") + return ConfigReader.read(self, "jail") def getOptions(self, section = None): opts = [] @@ -49,7 +49,7 @@ class JailsReader(ConfigReader): if section: # Get the options of a specific jail. - jail = JailReader(section) + jail = JailReader(section, basedir=self.getBaseDir()) jail.read() ret = jail.getOptions() if ret: @@ -62,7 +62,7 @@ class JailsReader(ConfigReader): else: # Get the options of all jails. for sec in self.sections(): - jail = JailReader(sec) + jail = JailReader(sec, basedir=self.getBaseDir()) jail.read() ret = jail.getOptions() if ret: diff --git a/fail2ban-testcases b/fail2ban-testcases index 99fefd57..ff94cfad 100755 --- a/fail2ban-testcases +++ b/fail2ban-testcases @@ -130,8 +130,10 @@ tests.addTest(unittest.makeSuite(actiontestcase.ExecuteAction)) tests.addTest(unittest.makeSuite(failmanagertestcase.AddFailure)) # BanManager tests.addTest(unittest.makeSuite(banmanagertestcase.AddFailure)) -# ClientReader +# ClientReaders +tests.addTest(unittest.makeSuite(clientreadertestcase.ConfigReaderTest)) tests.addTest(unittest.makeSuite(clientreadertestcase.JailReaderTest)) +tests.addTest(unittest.makeSuite(clientreadertestcase.JailsReaderTest)) # Filter if not opts.no_network: diff --git a/testcases/clientreadertestcase.py b/testcases/clientreadertestcase.py index 83121345..55fb010d 100644 --- a/testcases/clientreadertestcase.py +++ b/testcases/clientreadertestcase.py @@ -27,20 +27,93 @@ __date__ = "$Date$" __copyright__ = "Copyright (c) 2004 Cyril Jaquier" __license__ = "GPL" -import unittest +import os, shutil, tempfile, unittest +from client.configreader import ConfigReader from client.jailreader import JailReader +from client.jailsreader import JailsReader -class JailReaderTest(unittest.TestCase): +class ConfigReaderTest(unittest.TestCase): def setUp(self): """Call before every test case.""" + self.d = tempfile.mkdtemp(prefix="f2b-temp") + self.c = ConfigReader(basedir=self.d) + def tearDown(self): """Call after every test case.""" + shutil.rmtree(self.d) + + def _write(self, fname, value): + # verify if we don't need to create .d directory + if os.path.sep in fname: + d = os.path.dirname(fname) + d_ = os.path.join(self.d, d) + if not os.path.exists(d_): + os.makedirs(d_) + open("%s/%s" % (self.d, fname), "w").write(""" +[section] +option = %s +""" % value) + + def _remove(self, fname): + os.unlink("%s/%s" % (self.d, fname)) + self.assertTrue(self.c.read('c')) # we still should have some + + + def _getoption(self): + self.assertTrue(self.c.read('c')) # we got some now + return self.c.getOptions('section', [("int", 'option')])['option'] + + def testOptionalDotDDir(self): + self.assertFalse(self.c.read('c')) # nothing is there yet + self._write("c.conf", "1") + self.assertEqual(self._getoption(), 1) + self._write("c.conf", "2") # overwrite + self.assertEqual(self._getoption(), 2) + self._write("c.local", "3") # add override in .local + self.assertEqual(self._getoption(), 3) + self._write("c.d/98.conf", "998") # add 1st override in .d/ + self.assertEqual(self._getoption(), 998) + self._write("c.d/90.conf", "990") # add previously sorted override in .d/ + self.assertEqual(self._getoption(), 998) # should stay the same + self._write("c.d/99.conf", "999") # now override in a way without sorting we possibly get a failure + self.assertEqual(self._getoption(), 999) + self._remove("c.d/99.conf") + self.assertEqual(self._getoption(), 998) + self._remove("c.d/98.conf") + self.assertEqual(self._getoption(), 990) + self._remove("c.d/90.conf") + self.assertEqual(self._getoption(), 3) + self._remove("c.conf") # we allow to stay without .conf + self.assertEqual(self._getoption(), 3) + self._write("c.conf", "1") + self._remove("c.local") + self.assertEqual(self._getoption(), 1) + + +class JailReaderTest(unittest.TestCase): + + def testStockSSHJail(self): + jail = JailReader('ssh-iptables', basedir='config') # we are running tests from root project dir atm + self.assertTrue(jail.read()) + self.assertTrue(jail.getOptions()) + self.assertFalse(jail.isEnabled()) + self.assertEqual(jail.getName(), 'ssh-iptables') + + +class JailsReaderTest(unittest.TestCase): + + def testProvidingBadBasedir(self): + if not os.path.exists('/XXX'): + self.assertRaises(ValueError, JailsReader, basedir='/XXX') + + def testReadStockJailConf(self): + jails = JailsReader(basedir='config') # we are running tests from root project dir atm + self.assertTrue(jails.read()) # opens fine + self.assertTrue(jails.getOptions()) # reads fine + comm_commands = jails.convert() + # by default None of the jails is enabled and we get no + # commands to communicate to the server + self.assertEqual(comm_commands, []) - def testSplitAction(self): - action = "mail-whois[name=SSH]" - expected = ['mail-whois', {'name': 'SSH'}] - result = JailReader.splitAction(action) - self.assertEquals(expected, result) -