Merge pull request #483 from grooverdan/more-tests

More tests and slight RF of tests to provide base log capturing unittest clas
pull/489/merge
Yaroslav Halchenko 2013-12-10 18:28:28 -08:00
commit 60699a6585
6 changed files with 118 additions and 40 deletions

View File

@ -173,7 +173,9 @@ tests.addTest(unittest.makeSuite(misctestcase.CustomDateFormatsTest))
# Filter # Filter
if not opts.no_network: if not opts.no_network:
tests.addTest(unittest.makeSuite(filtertestcase.IgnoreIP)) tests.addTest(unittest.makeSuite(filtertestcase.IgnoreIPDNS))
tests.addTest(unittest.makeSuite(filtertestcase.IgnoreIP))
tests.addTest(unittest.makeSuite(filtertestcase.BasicFilter))
tests.addTest(unittest.makeSuite(filtertestcase.LogFile)) tests.addTest(unittest.makeSuite(filtertestcase.LogFile))
tests.addTest(unittest.makeSuite(filtertestcase.LogFileMonitor)) tests.addTest(unittest.makeSuite(filtertestcase.LogFileMonitor))
if not opts.no_network: if not opts.no_network:

View File

@ -446,7 +446,7 @@ class FileFilter(Filter):
self._delLogPath(path) self._delLogPath(path)
return return
def _delLogPath(self, path): def _delLogPath(self, path): # pragma: no cover - overwritten function
# nothing to do by default # nothing to do by default
# to be overridden by backends # to be overridden by backends
pass pass
@ -568,6 +568,9 @@ class FileContainer:
def getFileName(self): def getFileName(self):
return self.__filename return self.__filename
def getPos(self):
return self.__pos
def open(self): def open(self):
self.__handler = open(self.__filename) self.__handler = open(self.__filename)
# Set the file descriptor to be FD_CLOEXEC # Set the file descriptor to be FD_CLOEXEC

View File

@ -24,40 +24,23 @@ __author__ = "Cyril Jaquier"
__copyright__ = "Copyright (c) 2004 Cyril Jaquier" __copyright__ = "Copyright (c) 2004 Cyril Jaquier"
__license__ = "GPL" __license__ = "GPL"
import unittest, time import time
import logging, sys import logging, sys
from server.action import Action from server.action import Action
from StringIO import StringIO from utils import LogCaptureTestCase
class ExecuteAction(unittest.TestCase): class ExecuteAction(LogCaptureTestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
self.__action = Action("Test") self.__action = Action("Test")
LogCaptureTestCase.setUp(self)
# For extended testing of what gets output into logging
# system, we will redirect it to a string
logSys = logging.getLogger("fail2ban")
# Keep old settings
self._old_level = logSys.level
self._old_handlers = logSys.handlers
# Let's log everything into a string
self._log = StringIO()
logSys.handlers = [logging.StreamHandler(self._log)]
logSys.setLevel(getattr(logging, 'DEBUG'))
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
# print "O: >>%s<<" % self._log.getvalue() LogCaptureTestCase.tearDown(self)
logSys = logging.getLogger("fail2ban")
logSys.handlers = self._old_handlers
logSys.level = self._old_level
self.__action.execActionStop() self.__action.execActionStop()
def _is_logged(self, s):
return s in self._log.getvalue()
def testNameChange(self): def testNameChange(self):
self.assertEqual(self.__action.getName(), "Test") self.assertEqual(self.__action.getName(), "Test")
self.__action.setName("Tricky Test") self.__action.setName("Tricky Test")

View File

@ -31,15 +31,16 @@ import tempfile
from server.jail import Jail from server.jail import Jail
from server.filterpoll import FilterPoll from server.filterpoll import FilterPoll
from server.filter import FileFilter, DNSUtils from server.filter import Filter, FileFilter, DNSUtils
from server.failmanager import FailManager from server.failmanager import FailManager
from server.failmanager import FailManagerEmpty from server.failmanager import FailManagerEmpty
from dummyjail import DummyJail
# #
# Useful helpers # Useful helpers
# #
from utils import mtimesleep from utils import mtimesleep, LogCaptureTestCase
# yoh: per Steven Hiscocks's insight while troubleshooting # yoh: per Steven Hiscocks's insight while troubleshooting
# https://github.com/fail2ban/fail2ban/issues/103#issuecomment-15542836 # https://github.com/fail2ban/fail2ban/issues/103#issuecomment-15542836
@ -144,14 +145,27 @@ def _copy_lines_between_files(fin, fout, n=None, skip=0, mode='a', terminal_line
# Actual tests # Actual tests
# #
class IgnoreIP(unittest.TestCase): class BasicFilter(unittest.TestCase):
def setUp(self):
self.filter = Filter('name')
def testGetSetUseDNS(self):
# default is warn
self.assertEqual(self.filter.getUseDns(), 'warn')
self.filter.setUseDns(True)
self.assertEqual(self.filter.getUseDns(), 'yes')
self.filter.setUseDns(False)
self.assertEqual(self.filter.getUseDns(), 'no')
class IgnoreIP(LogCaptureTestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
self.filter = FileFilter(None) LogCaptureTestCase.setUp(self)
self.jail = DummyJail()
def tearDown(self): self.filter = FileFilter(self.jail)
"""Call after every test case."""
def testIgnoreIPOK(self): def testIgnoreIPOK(self):
ipList = "127.0.0.1", "192.168.0.1", "255.255.255.255", "99.99.99.99" ipList = "127.0.0.1", "192.168.0.1", "255.255.255.255", "99.99.99.99"
@ -159,19 +173,47 @@ class IgnoreIP(unittest.TestCase):
self.filter.addIgnoreIP(ip) self.filter.addIgnoreIP(ip)
self.assertTrue(self.filter.inIgnoreIPList(ip)) self.assertTrue(self.filter.inIgnoreIPList(ip))
# Test DNS
self.filter.addIgnoreIP("www.epfl.ch")
self.assertTrue(self.filter.inIgnoreIPList("128.178.50.12"))
def testIgnoreIPNOK(self): def testIgnoreIPNOK(self):
ipList = "", "999.999.999.999", "abcdef", "192.168.0." ipList = "", "999.999.999.999", "abcdef", "192.168.0."
for ip in ipList: for ip in ipList:
self.filter.addIgnoreIP(ip) self.filter.addIgnoreIP(ip)
self.assertFalse(self.filter.inIgnoreIPList(ip)) self.assertFalse(self.filter.inIgnoreIPList(ip))
def testIgnoreIPCIDR(self):
self.filter.addIgnoreIP('192.168.1.0/25')
self.assertTrue(self.filter.inIgnoreIPList('192.168.1.0'))
self.assertTrue(self.filter.inIgnoreIPList('192.168.1.1'))
self.assertTrue(self.filter.inIgnoreIPList('192.168.1.127'))
self.assertFalse(self.filter.inIgnoreIPList('192.168.1.128'))
self.assertFalse(self.filter.inIgnoreIPList('192.168.1.255'))
self.assertFalse(self.filter.inIgnoreIPList('192.168.0.255'))
def testIgnoreInProcessLine(self):
self.filter.addIgnoreIP('192.168.1.0/25')
self.filter.addFailRegex('<HOST>')
self.filter.processLineAndAdd('Thu Jul 11 01:21:43 2013 192.168.1.32')
self.assertTrue(self._is_logged('Ignore 192.168.1.32'))
def testIgnoreAddBannedIP(self):
self.filter.addIgnoreIP('192.168.1.0/25')
self.filter.addBannedIP('192.168.1.32')
self.assertFalse(self._is_logged('Ignore 192.168.1.32'))
self.assertTrue(self._is_logged('Requested to manually ban an ignored IP 192.168.1.32. User knows best. Proceeding to ban it.'))
class IgnoreIPDNS(IgnoreIP):
def testIgnoreIPDNSOK(self):
self.filter.addIgnoreIP("www.epfl.ch")
self.assertTrue(self.filter.inIgnoreIPList("128.178.50.12"))
def testIgnoreIPDNSNOK(self):
# Test DNS # Test DNS
self.filter.addIgnoreIP("www.epfl.ch") self.filter.addIgnoreIP("www.epfl.ch")
self.assertFalse(self.filter.inIgnoreIPList("127.177.50.10")) self.assertFalse(self.filter.inIgnoreIPList("127.177.50.10"))
self.assertFalse(self.filter.inIgnoreIPList("128.178.50.11"))
self.assertFalse(self.filter.inIgnoreIPList("128.178.50.13"))
class LogFile(unittest.TestCase): class LogFile(unittest.TestCase):
@ -194,11 +236,12 @@ class LogFile(unittest.TestCase):
self.assertTrue(self.filter.isModified(LogFile.FILENAME)) self.assertTrue(self.filter.isModified(LogFile.FILENAME))
class LogFileMonitor(unittest.TestCase): class LogFileMonitor(LogCaptureTestCase):
"""Few more tests for FilterPoll API """Few more tests for FilterPoll API
""" """
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
LogCaptureTestCase.setUp(self)
self.filter = self.name = 'NA' self.filter = self.name = 'NA'
_, self.name = tempfile.mkstemp('fail2ban', 'monitorfailures') _, self.name = tempfile.mkstemp('fail2ban', 'monitorfailures')
self.file = open(self.name, 'a') self.file = open(self.name, 'a')
@ -208,6 +251,7 @@ class LogFileMonitor(unittest.TestCase):
self.filter.addFailRegex("(?:(?:Authentication failure|Failed [-/\w+]+) for(?: [iI](?:llegal|nvalid) user)?|[Ii](?:llegal|nvalid) user|ROOT LOGIN REFUSED) .*(?: from|FROM) <HOST>") self.filter.addFailRegex("(?:(?:Authentication failure|Failed [-/\w+]+) for(?: [iI](?:llegal|nvalid) user)?|[Ii](?:llegal|nvalid) user|ROOT LOGIN REFUSED) .*(?: from|FROM) <HOST>")
def tearDown(self): def tearDown(self):
LogCaptureTestCase.tearDown(self)
_killfile(self.file, self.name) _killfile(self.file, self.name)
pass pass
@ -225,6 +269,21 @@ class LogFileMonitor(unittest.TestCase):
# shorter wait time for not modified status # shorter wait time for not modified status
return not self.isModified(0.4) return not self.isModified(0.4)
def testNoLogFile(self):
os.chmod(self.name, 0)
self.filter.getFailures(self.name)
self.assertTrue(self._is_logged('Unable to open %s' % self.name))
def testRemovingFailRegex(self):
self.filter.delFailRegex(0)
self.assertFalse(self._is_logged('Cannot remove regular expression. Index 0 is not valid'))
self.filter.delFailRegex(0)
self.assertTrue(self._is_logged('Cannot remove regular expression. Index 0 is not valid'))
def testRemovingIgnoreRegex(self):
self.filter.delIgnoreRegex(0)
self.assertTrue(self._is_logged('Cannot remove regular expression. Index 0 is not valid'))
def testNewChangeViaIsModified(self): def testNewChangeViaIsModified(self):
# it is a brand new one -- so first we think it is modified # it is a brand new one -- so first we think it is modified
self.assertTrue(self.isModified()) self.assertTrue(self.isModified())
@ -306,7 +365,6 @@ class LogFileMonitor(unittest.TestCase):
from threading import Lock from threading import Lock
from dummyjail import DummyJail
def get_monitor_failures_testcase(Filter_): def get_monitor_failures_testcase(Filter_):
"""Generator of TestCase's for different filters/backends """Generator of TestCase's for different filters/backends
@ -545,7 +603,13 @@ class GetFailures(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
def testTail(self):
self.filter.addLogPath(LogFile.FILENAME, tail=True)
self.assertEqual(self.filter.getLogPath()[-1].getPos(), 1653)
self.filter.getLogPath()[-1].close()
self.assertEqual(self.filter.getLogPath()[-1].readline(), "")
self.filter.delLogPath(LogFile.FILENAME)
self.assertEqual(self.filter.getLogPath(),[])
def testGetFailures01(self, filename=None, failures=None): def testGetFailures01(self, filename=None, failures=None):
filename = filename or GetFailures.FILENAME_01 filename = filename or GetFailures.FILENAME_01

View File

@ -123,7 +123,6 @@ def testSampleRegexsFactory(name):
regexsUsed.add(failregex) regexsUsed.add(failregex)
# TODO: Remove exception handling once all regexs have samples
for failRegexIndex, failRegex in enumerate(self.filter.getFailRegex()): for failRegexIndex, failRegex in enumerate(self.filter.getFailRegex()):
self.assertTrue( self.assertTrue(
failRegexIndex in regexsUsed, failRegexIndex in regexsUsed,

View File

@ -22,8 +22,9 @@ __author__ = "Yaroslav Halchenko"
__copyright__ = "Copyright (c) 2013 Yaroslav Halchenko" __copyright__ = "Copyright (c) 2013 Yaroslav Halchenko"
__license__ = "GPL" __license__ = "GPL"
import logging, os, re, tempfile, sys, time, traceback import unittest, logging, os, re, tempfile, sys, time, traceback
from os.path import basename, dirname from os.path import basename, dirname
from StringIO import StringIO
# #
# Following "traceback" functions are adopted from PyMVPA distributed # Following "traceback" functions are adopted from PyMVPA distributed
@ -105,3 +106,29 @@ def mtimesleep():
# no sleep now should be necessary since polling tracks now not only # no sleep now should be necessary since polling tracks now not only
# mtime but also ino and size # mtime but also ino and size
pass pass
class LogCaptureTestCase(unittest.TestCase):
def setUp(self):
# For extended testing of what gets output into logging
# system, we will redirect it to a string
logSys = logging.getLogger("fail2ban")
# Keep old settings
self._old_level = logSys.level
self._old_handlers = logSys.handlers
# Let's log everything into a string
self._log = StringIO()
logSys.handlers = [logging.StreamHandler(self._log)]
logSys.setLevel(getattr(logging, 'DEBUG'))
def tearDown(self):
"""Call after every test case."""
# print "O: >>%s<<" % self._log.getvalue()
logSys = logging.getLogger("fail2ban")
logSys.handlers = self._old_handlers
logSys.level = self._old_level
def _is_logged(self, s):
return s in self._log.getvalue()