diff --git a/fail2ban/client/configparserinc.py b/fail2ban/client/configparserinc.py index 80b99517..ad0a255c 100644 --- a/fail2ban/client/configparserinc.py +++ b/fail2ban/client/configparserinc.py @@ -65,7 +65,136 @@ logSys = getLogger(__name__) __all__ = ['SafeConfigParserWithIncludes'] -class SafeConfigParserWithIncludes(SafeConfigParser): +class SafeConfigParserWithIncludes(object): + + SECTION_NAME = "INCLUDES" + CFG_CACHE = {} + CFG_INC_CACHE = {} + CFG_EMPY_CFG = None + + def __init__(self): + self.__cr = None + + def __check_read(self, attr): + if self.__cr is None: + # raise RuntimeError("Access to wrapped attribute \"%s\" before read call" % attr) + if SafeConfigParserWithIncludes.CFG_EMPY_CFG is None: + SafeConfigParserWithIncludes.CFG_EMPY_CFG = _SafeConfigParserWithIncludes() + self.__cr = SafeConfigParserWithIncludes.CFG_EMPY_CFG + + def __getattr__(self,attr): + # check we access local implementation + try: + orig_attr = self.__getattribute__(attr) + except AttributeError: + self.__check_read(attr) + orig_attr = self.__cr.__getattribute__(attr) + return orig_attr + + @staticmethod + def _resource_mtime(resource): + mt = [] + dirnames = [] + for filename in resource: + if os.path.exists(filename): + s = os.stat(filename) + mt.append(s.st_mtime) + mt.append(s.st_mode) + mt.append(s.st_size) + dirname = os.path.dirname(filename) + if dirname not in dirnames: + dirnames.append(dirname) + for dirname in dirnames: + if os.path.exists(dirname): + s = os.stat(dirname) + mt.append(s.st_mtime) + mt.append(s.st_mode) + mt.append(s.st_size) + return mt + + def read(self, resource, get_includes=True, log_info=None): + SCPWI = SafeConfigParserWithIncludes + # check includes : + fileNamesFull = [] + if not isinstance(resource, list): + resource = [ resource ] + if get_includes: + for filename in resource: + fileNamesFull += SCPWI.getIncludes(filename) + else: + fileNamesFull = resource + # check cache + hashv = '///'.join(fileNamesFull) + cr, ret, mtime = SCPWI.CFG_CACHE.get(hashv, (None, False, 0)) + curmt = SCPWI._resource_mtime(fileNamesFull) + if cr is not None and mtime == curmt: + self.__cr = cr + logSys.debug("Cached config files: %s", resource) + #logSys.debug("Cached config files: %s", fileNamesFull) + return ret + # not yet in cache - create/read and add to cache: + if log_info is not None: + logSys.info(*log_info) + cr = _SafeConfigParserWithIncludes() + ret = cr.read(fileNamesFull) + SCPWI.CFG_CACHE[hashv] = (cr, ret, curmt) + self.__cr = cr + return ret + + def getOptions(self, *args, **kwargs): + self.__check_read('getOptions') + return self.__cr.getOptions(*args, **kwargs) + + @staticmethod + def getIncludes(resource, seen = []): + """ + Given 1 config resource returns list of included files + (recursively) with the original one as well + Simple loops are taken care about + """ + + # Use a short class name ;) + SCPWI = SafeConfigParserWithIncludes + + resources = seen + [resource] + # check cache + hashv = '///'.join(resources) + cinc, mtime = SCPWI.CFG_INC_CACHE.get(hashv, (None, 0)) + curmt = SCPWI._resource_mtime(resources) + if cinc is not None and mtime == curmt: + return cinc + + parser = SCPWI() + try: + # read without includes + parser.read(resource, get_includes = False) + except UnicodeDecodeError, e: + logSys.error("Error decoding config file '%s': %s" % (resource, e)) + return [] + + resourceDir = os.path.dirname(resource) + + newFiles = [ ('before', []), ('after', []) ] + if SCPWI.SECTION_NAME in parser.sections(): + for option_name, option_list in newFiles: + if option_name in parser.options(SCPWI.SECTION_NAME): + newResources = parser.get(SCPWI.SECTION_NAME, option_name) + for newResource in newResources.split('\n'): + if os.path.isabs(newResource): + r = newResource + else: + r = os.path.join(resourceDir, newResource) + if r in seen: + continue + option_list += SCPWI.getIncludes(r, resources) + # combine lists + cinc = newFiles[0][1] + [resource] + newFiles[1][1] + # cache and return : + SCPWI.CFG_INC_CACHE[hashv] = (cinc, curmt) + return cinc + #print "Includes list for " + resource + " is " + `resources` + +class _SafeConfigParserWithIncludes(SafeConfigParser, object): """ Class adds functionality to SafeConfigParser to handle included other configuration files (or may be urls, whatever in the future) @@ -94,69 +223,22 @@ after = 1.conf """ - SECTION_NAME = "INCLUDES" - if sys.version_info >= (3,2): # overload constructor only for fancy new Python3's def __init__(self, *args, **kwargs): kwargs = kwargs.copy() kwargs['interpolation'] = BasicInterpolationWithName() kwargs['inline_comment_prefixes'] = ";" - super(SafeConfigParserWithIncludes, self).__init__( + super(_SafeConfigParserWithIncludes, self).__init__( *args, **kwargs) - #@staticmethod - def getIncludes(resource, seen = []): - """ - Given 1 config resource returns list of included files - (recursively) with the original one as well - Simple loops are taken care about - """ - - # Use a short class name ;) - SCPWI = SafeConfigParserWithIncludes - - parser = SafeConfigParser() - try: - if sys.version_info >= (3,2): # pragma: no cover - parser.read(resource, encoding='utf-8') - else: - parser.read(resource) - except UnicodeDecodeError, e: - logSys.error("Error decoding config file '%s': %s" % (resource, e)) - return [] - - resourceDir = os.path.dirname(resource) - - newFiles = [ ('before', []), ('after', []) ] - if SCPWI.SECTION_NAME in parser.sections(): - for option_name, option_list in newFiles: - if option_name in parser.options(SCPWI.SECTION_NAME): - newResources = parser.get(SCPWI.SECTION_NAME, option_name) - for newResource in newResources.split('\n'): - if os.path.isabs(newResource): - r = newResource - else: - r = os.path.join(resourceDir, newResource) - if r in seen: - continue - s = seen + [resource] - option_list += SCPWI.getIncludes(r, s) - # combine lists - return newFiles[0][1] + [resource] + newFiles[1][1] - #print "Includes list for " + resource + " is " + `resources` - getIncludes = staticmethod(getIncludes) - def read(self, filenames): - fileNamesFull = [] if not isinstance(filenames, list): filenames = [ filenames ] - for filename in filenames: - fileNamesFull += SafeConfigParserWithIncludes.getIncludes(filename) - logSys.debug("Reading files: %s" % fileNamesFull) + logSys.debug("Reading files: %s", filenames) if sys.version_info >= (3,2): # pragma: no cover - return SafeConfigParser.read(self, fileNamesFull, encoding='utf-8') + return SafeConfigParser.read(self, filenames, encoding='utf-8') else: - return SafeConfigParser.read(self, fileNamesFull) + return SafeConfigParser.read(self, filenames) diff --git a/fail2ban/client/configreader.py b/fail2ban/client/configreader.py index 22115d3a..ada48803 100644 --- a/fail2ban/client/configreader.py +++ b/fail2ban/client/configreader.py @@ -55,7 +55,7 @@ class ConfigReader(SafeConfigParserWithIncludes): raise ValueError("Base configuration directory %s does not exist " % self._basedir) basename = os.path.join(self._basedir, filename) - logSys.info("Reading configs for %s under %s " % (basename, self._basedir)) + logSys.debug("Reading configs for %s under %s " , filename, self._basedir) config_files = [ basename + ".conf" ] # possible further customizations under a .conf.d directory @@ -71,14 +71,15 @@ class ConfigReader(SafeConfigParserWithIncludes): if len(config_files): # at least one config exists and accessible - logSys.debug("Reading config files: " + ', '.join(config_files)) - config_files_read = SafeConfigParserWithIncludes.read(self, config_files) + logSys.debug("Reading config files: %s", ', '.join(config_files)) + config_files_read = SafeConfigParserWithIncludes.read(self, config_files, + log_info=("Cache configs for %s under %s " , filename, self._basedir)) missed = [ cf for cf in config_files if cf not in config_files_read ] if missed: - logSys.error("Could not read config files: " + ', '.join(missed)) + logSys.error("Could not read config files: %s", ', '.join(missed)) if config_files_read: return True - logSys.error("Found no accessible config files for %r under %s" % + logSys.error("Found no accessible config files for %r under %s", ( filename, self.getBaseDir() )) return False else: @@ -133,12 +134,12 @@ class ConfigReader(SafeConfigParserWithIncludes): class DefinitionInitConfigReader(ConfigReader): """Config reader for files with options grouped in [Definition] and - [Init] sections. + [Init] sections. - Is a base class for readers of filters and actions, where definitions - in jails might provide custom values for options defined in [Init] - section. - """ + Is a base class for readers of filters and actions, where definitions + in jails might provide custom values for options defined in [Init] + section. + """ _configOpts = [] diff --git a/fail2ban/tests/clientreadertestcase.py b/fail2ban/tests/clientreadertestcase.py index ce19a50e..f505f297 100644 --- a/fail2ban/tests/clientreadertestcase.py +++ b/fail2ban/tests/clientreadertestcase.py @@ -21,7 +21,7 @@ __author__ = "Cyril Jaquier, Yaroslav Halchenko" __copyright__ = "Copyright (c) 2004 Cyril Jaquier, 2011-2013 Yaroslav Halchenko" __license__ = "GPL" -import os, glob, shutil, tempfile, unittest +import os, glob, shutil, tempfile, unittest, time from ..client.configreader import ConfigReader from ..client.jailreader import JailReader @@ -37,6 +37,8 @@ CONFIG_DIR='config' if STOCK else '/etc/fail2ban' IMPERFECT_CONFIG = os.path.join(os.path.dirname(__file__), 'config') +LAST_WRITE_TIME = 0 + class ConfigReaderTest(unittest.TestCase): def setUp(self): @@ -55,7 +57,8 @@ class ConfigReaderTest(unittest.TestCase): d_ = os.path.join(self.d, d) if not os.path.exists(d_): os.makedirs(d_) - f = open("%s/%s" % (self.d, fname), "w") + fname = "%s/%s" % (self.d, fname) + f = open(fname, "w") if value is not None: f.write(""" [section] @@ -64,6 +67,14 @@ option = %s if content is not None: f.write(content) f.close() + # set modification time to another second to revalidate cache (if milliseconds not supported) : + global LAST_WRITE_TIME + mtime = os.path.getmtime(fname) + if LAST_WRITE_TIME == mtime: + mtime += 1 + os.utime(fname, (mtime, mtime)) + LAST_WRITE_TIME = mtime + def _remove(self, fname): os.unlink("%s/%s" % (self.d, fname)) @@ -89,7 +100,6 @@ option = %s # raise unittest.SkipTest("Skipping on %s -- access rights are not enforced" % platform) pass - def testOptionalDotDDir(self): self.assertFalse(self.c.read('c')) # nothing is there yet self._write("c.conf", "1")