Merge branch '0.9/gh-1492'

pull/1486/merge
sebres 2016-08-01 14:45:05 +02:00
commit 0083036b5f
9 changed files with 204 additions and 47 deletions

View File

@ -293,7 +293,11 @@ class Fail2BanDb(object):
Jail to be added to the database. Jail to be added to the database.
""" """
cur.execute( cur.execute(
"INSERT OR REPLACE INTO jails(name, enabled) VALUES(?, 1)", "INSERT OR IGNORE INTO jails(name, enabled) VALUES(?, 1)",
(jail.name,))
if cur.rowcount <= 0:
cur.execute(
"UPDATE jails SET enabled = 1 WHERE name = ? AND enabled != 1",
(jail.name,)) (jail.name,))
@commitandrollback @commitandrollback
@ -317,7 +321,7 @@ class Fail2BanDb(object):
cur.execute("UPDATE jails SET enabled=0") cur.execute("UPDATE jails SET enabled=0")
@commitandrollback @commitandrollback
def getJailNames(self, cur): def getJailNames(self, cur, enabled=None):
"""Get name of jails in database. """Get name of jails in database.
Currently only used for testing purposes. Currently only used for testing purposes.
@ -327,7 +331,11 @@ class Fail2BanDb(object):
set set
Set of jail names. Set of jail names.
""" """
if enabled is None:
cur.execute("SELECT name FROM jails") cur.execute("SELECT name FROM jails")
else:
cur.execute("SELECT name FROM jails WHERE enabled=%s" %
(int(enabled),))
return set(row[0] for row in cur.fetchmany()) return set(row[0] for row in cur.fetchmany())
@commitandrollback @commitandrollback

View File

@ -49,7 +49,7 @@ if sys.version_info >= (2,7):
def testCategory(self): def testCategory(self):
categories = self.action.getCategories() categories = self.action.getCategories()
self.assertTrue("ssh" in categories) self.assertIn("ssh", categories)
self.assertTrue(len(categories) >= 10) self.assertTrue(len(categories) >= 10)
self.assertRaises( self.assertRaises(

View File

@ -101,21 +101,21 @@ class SMTPActionTest(unittest.TestCase):
self.assertEqual(self.smtpd.rcpttos, ["root"]) self.assertEqual(self.smtpd.rcpttos, ["root"])
subject = "Subject: [Fail2Ban] %s: banned %s" % ( subject = "Subject: [Fail2Ban] %s: banned %s" % (
self.jail.name, aInfo['ip']) self.jail.name, aInfo['ip'])
self.assertTrue(subject in self.smtpd.data.replace("\n", "")) self.assertIn(subject, self.smtpd.data.replace("\n", ""))
self.assertTrue( self.assertTrue(
"%i attempts" % aInfo['failures'] in self.smtpd.data) "%i attempts" % aInfo['failures'] in self.smtpd.data)
self.action.matches = "matches" self.action.matches = "matches"
self.action.ban(aInfo) self.action.ban(aInfo)
self.assertTrue(aInfo['matches'] in self.smtpd.data) self.assertIn(aInfo['matches'], self.smtpd.data)
self.action.matches = "ipjailmatches" self.action.matches = "ipjailmatches"
self.action.ban(aInfo) self.action.ban(aInfo)
self.assertTrue(aInfo['ipjailmatches'] in self.smtpd.data) self.assertIn(aInfo['ipjailmatches'], self.smtpd.data)
self.action.matches = "ipmatches" self.action.matches = "ipmatches"
self.action.ban(aInfo) self.action.ban(aInfo)
self.assertTrue(aInfo['ipmatches'] in self.smtpd.data) self.assertIn(aInfo['ipmatches'], self.smtpd.data)
def testOptions(self): def testOptions(self):
self.action.start() self.action.start()

View File

@ -65,12 +65,12 @@ class ExecuteActions(LogCaptureTestCase):
def testActionsManipulation(self): def testActionsManipulation(self):
self.__actions.add('test') self.__actions.add('test')
self.assertTrue(self.__actions['test']) self.assertTrue(self.__actions['test'])
self.assertTrue('test' in self.__actions) self.assertIn('test', self.__actions)
self.assertFalse('nonexistant action' in self.__actions) self.assertNotIn('nonexistant action', self.__actions)
self.__actions.add('test1') self.__actions.add('test1')
del self.__actions['test'] del self.__actions['test']
del self.__actions['test1'] del self.__actions['test1']
self.assertFalse('test' in self.__actions) self.assertNotIn('test', self.__actions)
self.assertEqual(len(self.__actions), 0) self.assertEqual(len(self.__actions), 0)
self.__actions.setBanTime(127) self.__actions.setBanTime(127)

View File

@ -525,12 +525,12 @@ class JailsReaderTest(LogCaptureTestCase):
self.assertTrue(actionReader.read()) self.assertTrue(actionReader.read())
actionReader.getOptions({}) # populate _opts actionReader.getOptions({}) # populate _opts
if not actionName.endswith('-common'): if not actionName.endswith('-common'):
self.assertTrue('Definition' in actionReader.sections(), self.assertIn('Definition', actionReader.sections(),
msg="Action file %r is lacking [Definition] section" % actionConfig) msg="Action file %r is lacking [Definition] section" % actionConfig)
# all must have some actionban defined # all must have some actionban defined
self.assertTrue(actionReader._opts.get('actionban', '').strip(), self.assertTrue(actionReader._opts.get('actionban', '').strip(),
msg="Action file %r is lacking actionban" % actionConfig) msg="Action file %r is lacking actionban" % actionConfig)
self.assertTrue('Init' in actionReader.sections(), self.assertIn('Init', actionReader.sections(),
msg="Action file %r is lacking [Init] section" % actionConfig) msg="Action file %r is lacking [Init] section" % actionConfig)
def testReadStockJailConf(self): def testReadStockJailConf(self):
@ -582,7 +582,7 @@ class JailsReaderTest(LogCaptureTestCase):
self.assertTrue(len(actName)) self.assertTrue(len(actName))
self.assertTrue(isinstance(actOpt, dict)) self.assertTrue(isinstance(actOpt, dict))
if actName == 'iptables-multiport': if actName == 'iptables-multiport':
self.assertTrue('port' in actOpt) self.assertIn('port', actOpt)
actionReader = ActionReader( actionReader = ActionReader(
actName, jail, {}, basedir=CONFIG_DIR) actName, jail, {}, basedir=CONFIG_DIR)
@ -632,11 +632,13 @@ class JailsReaderTest(LogCaptureTestCase):
# and we know even some of them by heart # and we know even some of them by heart
for j in ['sshd', 'recidive']: for j in ['sshd', 'recidive']:
# by default we have 'auto' backend ATM # by default we have 'auto' backend ATM, but some distributions can overwrite it,
self.assertTrue(['add', j, 'auto'] in comm_commands) # (e.g. fedora default is 'systemd') therefore let check it without backend...
self.assertIn(['add', j],
(cmd[:2] for cmd in comm_commands if len(cmd) == 3 and cmd[0] == 'add'))
# and warn on useDNS # and warn on useDNS
self.assertTrue(['set', j, 'usedns', 'warn'] in comm_commands) self.assertIn(['set', j, 'usedns', 'warn'], comm_commands)
self.assertTrue(['start', j] in comm_commands) self.assertIn(['start', j], comm_commands)
# last commands should be the 'start' commands # last commands should be the 'start' commands
self.assertEqual(comm_commands[-1][0], 'start') self.assertEqual(comm_commands[-1][0], 'start')
@ -655,7 +657,7 @@ class JailsReaderTest(LogCaptureTestCase):
action_name = action.getName() action_name = action.getName()
if '<blocktype>' in str(commands): if '<blocktype>' in str(commands):
# Verify that it is among cInfo # Verify that it is among cInfo
self.assertTrue('blocktype' in action._initOpts) self.assertIn('blocktype', action._initOpts)
# Verify that we have a call to set it up # Verify that we have a call to set it up
blocktype_present = False blocktype_present = False
target_command = ['set', jail_name, 'action', action_name, 'blocktype'] target_command = ['set', jail_name, 'action', action_name, 'blocktype']

View File

@ -123,7 +123,7 @@ class DatabaseTest(LogCaptureTestCase):
self.db.addLog(self.jail, self.fileContainer) self.db.addLog(self.jail, self.fileContainer)
self.assertTrue(filename in self.db.getLogPaths(self.jail)) self.assertIn(filename, self.db.getLogPaths(self.jail))
os.remove(filename) os.remove(filename)
def testUpdateLog(self): def testUpdateLog(self):
@ -318,6 +318,25 @@ class DatabaseTest(LogCaptureTestCase):
actions._Actions__checkBan() actions._Actions__checkBan()
self.assertLogged("ban ainfo %s, %s, %s, %s" % (True, True, True, True)) self.assertLogged("ban ainfo %s, %s, %s, %s" % (True, True, True, True))
def testDelAndAddJail(self):
self.testAddJail() # Add jail
# Delete jail (just disabled it):
self.db.delJail(self.jail)
jails = self.db.getJailNames()
self.assertIn(len(jails) == 1 and self.jail.name, jails)
jails = self.db.getJailNames(enabled=False)
self.assertIn(len(jails) == 1 and self.jail.name, jails)
jails = self.db.getJailNames(enabled=True)
self.assertTrue(len(jails) == 0)
# Add it again - should just enable it:
self.db.addJail(self.jail)
jails = self.db.getJailNames()
self.assertIn(len(jails) == 1 and self.jail.name, jails)
jails = self.db.getJailNames(enabled=True)
self.assertIn(len(jails) == 1 and self.jail.name, jails)
jails = self.db.getJailNames(enabled=False)
self.assertTrue(len(jails) == 0)
def testPurge(self): def testPurge(self):
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return

View File

@ -23,6 +23,7 @@ __license__ = "GPL"
import logging import logging
import os import os
import re
import sys import sys
import unittest import unittest
import tempfile import tempfile
@ -32,6 +33,8 @@ import datetime
from glob import glob from glob import glob
from StringIO import StringIO from StringIO import StringIO
from utils import LogCaptureTestCase, logSys as DefLogSys
from ..helpers import formatExceptionInfo, mbasename, TraceBack, FormatterWithTraceBack, getLogger from ..helpers import formatExceptionInfo, mbasename, TraceBack, FormatterWithTraceBack, getLogger
from ..helpers import splitwords from ..helpers import splitwords
from ..server.datetemplate import DatePatternRegex from ..server.datetemplate import DatePatternRegex
@ -130,7 +133,7 @@ class SetupTest(unittest.TestCase):
% (sys.executable, self.setup)) % (sys.executable, self.setup))
class TestsUtilsTest(unittest.TestCase): class TestsUtilsTest(LogCaptureTestCase):
def testmbasename(self): def testmbasename(self):
self.assertEqual(mbasename("sample.py"), 'sample') self.assertEqual(mbasename("sample.py"), 'sample')
@ -165,12 +168,88 @@ class TestsUtilsTest(unittest.TestCase):
if not ('fail2ban-testcases' in s): if not ('fail2ban-testcases' in s):
# we must be calling it from setup or nosetests but using at least # we must be calling it from setup or nosetests but using at least
# nose's core etc # nose's core etc
self.assertTrue('>' in s, msg="no '>' in %r" % s) self.assertIn('>', s)
elif not ('coverage' in s): elif not ('coverage' in s):
# There is only "fail2ban-testcases" in this case, no true traceback # There is only "fail2ban-testcases" in this case, no true traceback
self.assertFalse('>' in s, msg="'>' present in %r" % s) self.assertNotIn('>', s)
self.assertTrue(':' in s, msg="no ':' in %r" % s) self.assertIn(':', s)
def _testAssertionErrorRE(self, regexp, fun, *args, **kwargs):
self.assertRaisesRegexp(AssertionError, regexp, fun, *args, **kwargs)
def testExtendedAssertRaisesRE(self):
## test _testAssertionErrorRE several fail cases:
def _key_err(msg):
raise KeyError(msg)
self.assertRaises(KeyError,
self._testAssertionErrorRE, r"^failed$",
_key_err, 'failed')
self.assertRaises(AssertionError,
self._testAssertionErrorRE, r"^failed$",
self.fail, '__failed__')
self._testAssertionErrorRE(r'failed.* does not match .*__failed__',
lambda: self._testAssertionErrorRE(r"^failed$",
self.fail, '__failed__')
)
## no exception in callable:
self.assertRaises(AssertionError,
self._testAssertionErrorRE, r"", int, 1)
self._testAssertionErrorRE(r'0 AssertionError not raised X.* does not match .*AssertionError not raised',
lambda: self._testAssertionErrorRE(r"^0 AssertionError not raised X$",
lambda: self._testAssertionErrorRE(r"", int, 1))
)
def testExtendedAssertMethods(self):
## assertIn, assertNotIn positive case:
self.assertIn('a', ['a', 'b', 'c', 'd'])
self.assertIn('a', ('a', 'b', 'c', 'd',))
self.assertIn('a', 'cba')
self.assertIn('a', (c for c in 'cba' if c != 'b'))
self.assertNotIn('a', ['b', 'c', 'd'])
self.assertNotIn('a', ('b', 'c', 'd',))
self.assertNotIn('a', 'cbd')
self.assertNotIn('a', (c.upper() for c in 'cba' if c != 'b'))
## assertIn, assertNotIn negative case:
self._testAssertionErrorRE(r"'a' unexpectedly found in 'cba'",
self.assertNotIn, 'a', 'cba')
self._testAssertionErrorRE(r"1 unexpectedly found in \[0, 1, 2\]",
self.assertNotIn, 1, xrange(3))
self._testAssertionErrorRE(r"'A' unexpectedly found in \['C', 'A'\]",
self.assertNotIn, 'A', (c.upper() for c in 'cba' if c != 'b'))
self._testAssertionErrorRE(r"'a' was not found in 'xyz'",
self.assertIn, 'a', 'xyz')
self._testAssertionErrorRE(r"5 was not found in \[0, 1, 2\]",
self.assertIn, 5, xrange(3))
self._testAssertionErrorRE(r"'A' was not found in \['C', 'B'\]",
self.assertIn, 'A', (c.upper() for c in 'cba' if c != 'a'))
## assertLogged, assertNotLogged positive case:
logSys = DefLogSys
self.pruneLog()
logSys.debug('test "xyz"')
self.assertLogged('test "xyz"')
self.assertLogged('test', 'xyz', all=True)
self.assertNotLogged('test', 'zyx', all=False)
self.assertNotLogged('test_zyx', 'zyx', all=True)
self.assertLogged('test', 'zyx', all=False)
self.pruneLog()
logSys.debug('xxxx "xxx"')
self.assertNotLogged('test "xyz"')
self.assertNotLogged('test', 'xyz', all=False)
self.assertNotLogged('test', 'xyz', 'zyx', all=True)
## assertLogged, assertNotLogged negative case:
self.pruneLog()
logSys.debug('test "xyz"')
self._testAssertionErrorRE(r"All of the .* were found present in the log",
self.assertNotLogged, 'test "xyz"')
self._testAssertionErrorRE(r"was found in the log",
self.assertNotLogged, 'test', 'xyz', all=True)
self._testAssertionErrorRE(r"was not found in the log",
self.assertLogged, 'test', 'zyx', all=True)
self._testAssertionErrorRE(r"None among .* was found in the log",
self.assertLogged, 'test_zyx', 'zyx', all=False)
self._testAssertionErrorRE(r"All of the .* were found present in the log",
self.assertNotLogged, 'test', 'xyz', all=False)
def testFormatterWithTraceBack(self): def testFormatterWithTraceBack(self):
strout = StringIO() strout = StringIO()

View File

@ -228,7 +228,7 @@ class Transmitter(TransmitterBase):
time.sleep(1) time.sleep(1)
self.assertEqual( self.assertEqual(
self.transm.proceed(["stop", self.jailName]), (0, None)) self.transm.proceed(["stop", self.jailName]), (0, None))
self.assertTrue(self.jailName not in self.server._Server__jails) self.assertNotIn(self.jailName, self.server._Server__jails)
def testStartStopAllJail(self): def testStartStopAllJail(self):
self.server.addJail("TestJail2", "auto") self.server.addJail("TestJail2", "auto")
@ -242,8 +242,8 @@ class Transmitter(TransmitterBase):
time.sleep(0.1) time.sleep(0.1)
self.assertEqual(self.transm.proceed(["stop", "all"]), (0, None)) self.assertEqual(self.transm.proceed(["stop", "all"]), (0, None))
time.sleep(1) time.sleep(1)
self.assertTrue(self.jailName not in self.server._Server__jails) self.assertNotIn(self.jailName, self.server._Server__jails)
self.assertTrue("TestJail2" not in self.server._Server__jails) self.assertNotIn("TestJail2", self.server._Server__jails)
def testJailIdle(self): def testJailIdle(self):
self.assertEqual( self.assertEqual(

View File

@ -22,6 +22,7 @@ __author__ = "Yaroslav Halchenko"
__copyright__ = "Copyright (c) 2013 Yaroslav Halchenko" __copyright__ = "Copyright (c) 2013 Yaroslav Halchenko"
__license__ = "GPL" __license__ = "GPL"
import itertools
import logging import logging
import os import os
import re import re
@ -208,16 +209,45 @@ def gatherTests(regexps=None, no_network=False):
return tests return tests
# forwards compatibility of unittest.TestCase for some early python versions #
if not hasattr(unittest.TestCase, 'assertIn'): # Forwards compatibility of unittest.TestCase for some early python versions
def __assertIn(self, a, b, msg=None): #
if a not in b: # pragma: no cover
self.fail(msg or "%r was not found in %r" % (a, b)) if not hasattr(unittest.TestCase, 'assertRaisesRegexp'):
unittest.TestCase.assertIn = __assertIn def assertRaisesRegexp(self, exccls, regexp, fun, *args, **kwargs):
def __assertNotIn(self, a, b, msg=None): try:
if a in b: # pragma: no cover fun(*args, **kwargs)
self.fail(msg or "%r was found in %r" % (a, b)) except exccls as e:
unittest.TestCase.assertNotIn = __assertNotIn if re.search(regexp, e.message) is None:
self.fail('\"%s\" does not match \"%s\"' % (regexp, e.message))
else:
self.fail('%s not raised' % getattr(exccls, '__name__'))
unittest.TestCase.assertRaisesRegexp = assertRaisesRegexp
# always custom following methods, because we use atm better version of both (support generators)
if True: ## if not hasattr(unittest.TestCase, 'assertIn'):
def assertIn(self, a, b, msg=None):
bb = b
wrap = False
if msg is None and hasattr(b, '__iter__') and not isinstance(b, basestring):
b, bb = itertools.tee(b)
wrap = True
if a not in b:
if wrap: bb = list(bb)
msg = msg or "%r was not found in %r" % (a, bb)
self.fail(msg)
unittest.TestCase.assertIn = assertIn
def assertNotIn(self, a, b, msg=None):
bb = b
wrap = False
if msg is None and hasattr(b, '__iter__') and not isinstance(b, basestring):
b, bb = itertools.tee(b)
wrap = True
if a in b:
if wrap: bb = list(bb)
msg = msg or "%r unexpectedly found in %r" % (a, bb)
self.fail(msg)
unittest.TestCase.assertNotIn = assertNotIn
class LogCaptureTestCase(unittest.TestCase): class LogCaptureTestCase(unittest.TestCase):
@ -241,6 +271,7 @@ class LogCaptureTestCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
# print "O: >>%s<<" % self._log.getvalue() # print "O: >>%s<<" % self._log.getvalue()
self.pruneLog()
logSys = getLogger("fail2ban") logSys = getLogger("fail2ban")
logSys.handlers = self._old_handlers logSys.handlers = self._old_handlers
logSys.level = self._old_level logSys.level = self._old_level
@ -248,7 +279,7 @@ class LogCaptureTestCase(unittest.TestCase):
def _is_logged(self, s): def _is_logged(self, s):
return s in self._log.getvalue() return s in self._log.getvalue()
def assertLogged(self, *s): def assertLogged(self, *s, **kwargs):
"""Assert that one of the strings was logged """Assert that one of the strings was logged
Preferable to assertTrue(self._is_logged(..))) Preferable to assertTrue(self._is_logged(..)))
@ -258,14 +289,23 @@ class LogCaptureTestCase(unittest.TestCase):
---------- ----------
s : string or list/set/tuple of strings s : string or list/set/tuple of strings
Test should succeed if string (or any of the listed) is present in the log Test should succeed if string (or any of the listed) is present in the log
all : boolean (default False) if True should fail if any of s not logged
""" """
logged = self._log.getvalue() logged = self._log.getvalue()
if not kwargs.get('all', False):
# at least one entry should be found:
for s_ in s: for s_ in s:
if s_ in logged: if s_ in logged:
return return
raise AssertionError("None among %r was found in the log: %r" % (s, logged)) if True: # pragma: no cover
self.fail("None among %r was found in the log: ===\n%s===" % (s, logged))
else:
# each entry should be found:
for s_ in s:
if s_ not in logged: # pragma: no cover
self.fail("%r was not found in the log: ===\n%s===" % (s_, logged))
def assertNotLogged(self, *s): def assertNotLogged(self, *s, **kwargs):
"""Assert that strings were not logged """Assert that strings were not logged
Parameters Parameters
@ -273,13 +313,22 @@ class LogCaptureTestCase(unittest.TestCase):
s : string or list/set/tuple of strings s : string or list/set/tuple of strings
Test should succeed if the string (or at least one of the listed) is not Test should succeed if the string (or at least one of the listed) is not
present in the log present in the log
all : boolean (default False) if True should fail if any of s logged
""" """
logged = self._log.getvalue() logged = self._log.getvalue()
if not kwargs.get('all', False):
for s_ in s: for s_ in s:
if s_ not in logged: if s_ not in logged:
return return
raise AssertionError("All of the %r were found present in the log: %r" % (s, logged)) if True: # pragma: no cover
self.fail("All of the %r were found present in the log: ===\n%s===" % (s, logged))
else:
for s_ in s:
if s_ in logged: # pragma: no cover
self.fail("%r was found in the log: ===\n%s===" % (s_, logged))
def pruneLog(self):
self._log.truncate(0)
def getLog(self): def getLog(self):
return self._log.getvalue() return self._log.getvalue()