filtertestcase.py: byte related copy of lines in tests (locale independent); closes gh-2936

pull/3435/merge
sebres 2023-04-04 12:46:35 +02:00
parent a9b30eb86e
commit 56485c8548
2 changed files with 37 additions and 27 deletions

View File

@ -98,6 +98,8 @@ if sys.version_info >= (3,): # pragma: 2.x no cover
if not isinstance(x, bytes): if not isinstance(x, bytes):
return str(x) return str(x)
return x.decode(PREFER_ENC, 'replace') return x.decode(PREFER_ENC, 'replace')
def uni_bytes(x):
return bytes(x, 'UTF-8')
else: # pragma: 3.x no cover else: # pragma: 3.x no cover
def uni_decode(x, enc=PREFER_ENC, errors='strict'): def uni_decode(x, enc=PREFER_ENC, errors='strict'):
try: try:
@ -115,6 +117,7 @@ else: # pragma: 3.x no cover
return x.encode(PREFER_ENC, 'replace') return x.encode(PREFER_ENC, 'replace')
else: else:
uni_string = str uni_string = str
uni_bytes = bytes
def _as_bool(val): def _as_bool(val):

View File

@ -36,6 +36,7 @@ try:
except ImportError: except ImportError:
journal = None journal = None
from ..helpers import uni_bytes
from ..server.jail import Jail from ..server.jail import Jail
from ..server.filterpoll import FilterPoll from ..server.filterpoll import FilterPoll
from ..server.filter import FailTicket, Filter, FileFilter, FileContainer from ..server.filter import FailTicket, Filter, FileFilter, FileContainer
@ -109,6 +110,7 @@ class _tmSerial():
return "%s%02u" % (c._str_s, sec) return "%s%02u" % (c._str_s, sec)
_tm = _tmSerial._tm _tm = _tmSerial._tm
_tmb = lambda t: uni_bytes(_tm(t))
def _assert_equal_entries(utest, found, output, count=None): def _assert_equal_entries(utest, found, output, count=None):
@ -204,6 +206,8 @@ def _copy_lines_between_files(in_, fout, n=None, skip=0, mode='a', terminal_line
# on old Python st_mtime is int, so we should give at least 1 sec so # on old Python st_mtime is int, so we should give at least 1 sec so
# polling filter could detect the change # polling filter could detect the change
mtimesleep() mtimesleep()
if terminal_line is not None:
terminal_line = uni_bytes(terminal_line)
if isinstance(in_, str): # pragma: no branch - only used with str in test cases if isinstance(in_, str): # pragma: no branch - only used with str in test cases
fin = open(in_, 'rb') fin = open(in_, 'rb')
else: else:
@ -213,18 +217,21 @@ def _copy_lines_between_files(in_, fout, n=None, skip=0, mode='a', terminal_line
fin.readline() fin.readline()
# Read # Read
i = 0 i = 0
if not lines: lines = [] if lines:
lines = map(uni_bytes, lines)
else:
lines = []
while n is None or i < n: while n is None or i < n:
l = fin.readline().decode('UTF-8', 'replace').rstrip('\r\n') l = fin.readline().rstrip(b'\r\n')
if terminal_line is not None and l == terminal_line: if terminal_line is not None and l == terminal_line:
break break
lines.append(l) lines.append(l)
i += 1 i += 1
# Write: all at once and flush # Write: all at once and flush
if isinstance(fout, str): if isinstance(fout, str):
fout = open(fout, mode) fout = open(fout, mode+'b')
DefLogSys.debug(' ++ write %d test lines', len(lines)) DefLogSys.debug(' ++ write %d test lines', len(lines))
fout.write('\n'.join(lines)+'\n') fout.write(b'\n'.join(lines)+b'\n')
fout.flush() fout.flush()
if isinstance(in_, str): # pragma: no branch - only used with str in test cases if isinstance(in_, str): # pragma: no branch - only used with str in test cases
# Opened earlier, therefore must close it # Opened earlier, therefore must close it
@ -711,7 +718,7 @@ class LogFileFilterPoll(unittest.TestCase):
self.filter.setDatePattern(r'^%ExY-%Exm-%Exd %ExH:%ExM:%ExS') self.filter.setDatePattern(r'^%ExY-%Exm-%Exd %ExH:%ExM:%ExS')
fname = tempfile.mktemp(prefix='tmp_fail2ban', suffix='.log') fname = tempfile.mktemp(prefix='tmp_fail2ban', suffix='.log')
time = 1417512352 time = 1417512352
f = open(fname, 'w') f = open(fname, 'wb')
fc = None fc = None
try: try:
fc = FileContainer(fname, self.filter.getLogEncoding()) fc = FileContainer(fname, self.filter.getLogEncoding())
@ -722,7 +729,7 @@ class LogFileFilterPoll(unittest.TestCase):
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
self.assertEqual(fc.getPos(), 0) self.assertEqual(fc.getPos(), 0)
# one entry with exact time: # one entry with exact time:
f.write("%s [sshd] error: PAM: failure len 1\n" % _tm(time)) f.write(b"%s [sshd] error: PAM: failure len 1\n" % _tmb(time))
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
@ -734,7 +741,7 @@ class LogFileFilterPoll(unittest.TestCase):
fc.open() fc.open()
# no time - nothing should be found : # no time - nothing should be found :
for i in xrange(10): for i in xrange(10):
f.write("[sshd] error: PAM: failure len 1\n") f.write(b"[sshd] error: PAM: failure len 1\n")
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
@ -745,38 +752,38 @@ class LogFileFilterPoll(unittest.TestCase):
fc = FileContainer(fname, self.filter.getLogEncoding()) fc = FileContainer(fname, self.filter.getLogEncoding())
fc.open() fc.open()
# one entry with smaller time: # one entry with smaller time:
f.write("%s [sshd] error: PAM: failure len 2\n" % _tm(time - 10)) f.write(b"%s [sshd] error: PAM: failure len 2\n" % _tmb(time - 10))
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
self.assertEqual(fc.getPos(), 53) self.assertEqual(fc.getPos(), 53)
# two entries with smaller time: # two entries with smaller time:
f.write("%s [sshd] error: PAM: failure len 3 2 1\n" % _tm(time - 9)) f.write(b"%s [sshd] error: PAM: failure len 3 2 1\n" % _tmb(time - 9))
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
self.assertEqual(fc.getPos(), 110) self.assertEqual(fc.getPos(), 110)
# check move after end (all of time smaller): # check move after end (all of time smaller):
f.write("%s [sshd] error: PAM: failure\n" % _tm(time - 1)) f.write(b"%s [sshd] error: PAM: failure\n" % _tmb(time - 1))
f.flush() f.flush()
self.assertEqual(fc.getFileSize(), 157) self.assertEqual(fc.getFileSize(), 157)
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
self.assertEqual(fc.getPos(), 157) self.assertEqual(fc.getPos(), 157)
# stil one exact line: # stil one exact line:
f.write("%s [sshd] error: PAM: Authentication failure\n" % _tm(time)) f.write(b"%s [sshd] error: PAM: Authentication failure\n" % _tmb(time))
f.write("%s [sshd] error: PAM: failure len 1\n" % _tm(time)) f.write(b"%s [sshd] error: PAM: failure len 1\n" % _tmb(time))
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
self.assertEqual(fc.getPos(), 157) self.assertEqual(fc.getPos(), 157)
# add something hereafter: # add something hereafter:
f.write("%s [sshd] error: PAM: failure len 3 2 1\n" % _tm(time + 2)) f.write(b"%s [sshd] error: PAM: failure len 3 2 1\n" % _tmb(time + 2))
f.write("%s [sshd] error: PAM: Authentication failure\n" % _tm(time + 3)) f.write(b"%s [sshd] error: PAM: Authentication failure\n" % _tmb(time + 3))
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
self.assertEqual(fc.getPos(), 157) self.assertEqual(fc.getPos(), 157)
# add something hereafter: # add something hereafter:
f.write("%s [sshd] error: PAM: failure\n" % _tm(time + 9)) f.write(b"%s [sshd] error: PAM: failure\n" % _tmb(time + 9))
f.write("%s [sshd] error: PAM: failure len 4 3 2\n" % _tm(time + 9)) f.write(b"%s [sshd] error: PAM: failure len 4 3 2\n" % _tmb(time + 9))
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
self.assertEqual(fc.getPos(), 157) self.assertEqual(fc.getPos(), 157)
@ -797,7 +804,7 @@ class LogFileFilterPoll(unittest.TestCase):
self.filter.setDatePattern(r'^%ExY-%Exm-%Exd %ExH:%ExM:%ExS') self.filter.setDatePattern(r'^%ExY-%Exm-%Exd %ExH:%ExM:%ExS')
fname = tempfile.mktemp(prefix='tmp_fail2ban', suffix='.log') fname = tempfile.mktemp(prefix='tmp_fail2ban', suffix='.log')
time = 1417512352 time = 1417512352
f = open(fname, 'w') f = open(fname, 'wb')
fc = None fc = None
count = 1000 if unittest.F2B.fast else 10000 count = 1000 if unittest.F2B.fast else 10000
try: try:
@ -808,14 +815,14 @@ class LogFileFilterPoll(unittest.TestCase):
# write lines with smaller as search time: # write lines with smaller as search time:
t = time - count - 1 t = time - count - 1
for i in xrange(count): for i in xrange(count):
f.write("%s [sshd] error: PAM: failure\n" % _tm(t)) f.write(b"%s [sshd] error: PAM: failure\n" % _tmb(t))
t += 1 t += 1
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
self.assertEqual(fc.getPos(), 47*count) self.assertEqual(fc.getPos(), 47*count)
# write lines with exact search time: # write lines with exact search time:
for i in xrange(10): for i in xrange(10):
f.write("%s [sshd] error: PAM: failure\n" % _tm(time)) f.write(b"%s [sshd] error: PAM: failure\n" % _tmb(time))
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
self.assertEqual(fc.getPos(), 47*count) self.assertEqual(fc.getPos(), 47*count)
@ -825,7 +832,7 @@ class LogFileFilterPoll(unittest.TestCase):
t = time+1 t = time+1
for i in xrange(count//500): for i in xrange(count//500):
for j in xrange(500): for j in xrange(500):
f.write("%s [sshd] error: PAM: failure\n" % _tm(t)) f.write(b"%s [sshd] error: PAM: failure\n" % _tmb(t))
t += 1 t += 1
f.flush() f.flush()
fc.setPos(0); self.filter.seekToTime(fc, time) fc.setPos(0); self.filter.seekToTime(fc, time)
@ -847,7 +854,7 @@ class LogFileMonitor(LogCaptureTestCase):
LogCaptureTestCase.setUp(self) LogCaptureTestCase.setUp(self)
self.filter = self.name = 'NA' self.filter = self.name = 'NA'
_, self.name = tempfile.mkstemp('fail2ban', 'monitorfailures') _, self.name = tempfile.mkstemp('fail2ban', 'monitorfailures')
self.file = open(self.name, 'a') self.file = open(self.name, 'ab')
self.filter = FilterPoll(DummyJail()) self.filter = FilterPoll(DummyJail())
self.filter.addLogPath(self.name, autoSeek=False) self.filter.addLogPath(self.name, autoSeek=False)
self.filter.active = True self.filter.active = True
@ -899,7 +906,7 @@ class LogFileMonitor(LogCaptureTestCase):
_org_processLine = self.filter.processLine _org_processLine = self.filter.processLine
self.filter.processLine = None self.filter.processLine = None
for i in range(100): for i in range(100):
self.file.write("line%d\n" % 1) self.file.write(b"line%d\n" % 1)
self.file.flush() self.file.flush()
for i in range(100): for i in range(100):
self.filter.getFailures(self.name) self.filter.getFailures(self.name)
@ -910,7 +917,7 @@ class LogFileMonitor(LogCaptureTestCase):
self.filter.idle = False self.filter.idle = False
self.filter.getFailures(self.name) self.filter.getFailures(self.name)
self.filter.processLine = _org_processLine self.filter.processLine = _org_processLine
self.file.write("line%d\n" % 1) self.file.write(b"line%d\n" % 1)
self.file.flush() self.file.flush()
self.filter.getFailures(self.name) self.filter.getFailures(self.name)
self.assertNotLogged('Failed to process line:') self.assertNotLogged('Failed to process line:')
@ -934,7 +941,7 @@ class LogFileMonitor(LogCaptureTestCase):
mtimesleep() # to guarantee freshier mtime mtimesleep() # to guarantee freshier mtime
for i in range(4): # few changes for i in range(4): # few changes
# unless we write into it # unless we write into it
self.file.write("line%d\n" % i) self.file.write(b"line%d\n" % i)
self.file.flush() self.file.flush()
self.assertTrue(self.isModified()) self.assertTrue(self.isModified())
self.assertTrue(self.notModified()) self.assertTrue(self.notModified())
@ -943,11 +950,11 @@ class LogFileMonitor(LogCaptureTestCase):
# we are not signaling as modified whenever # we are not signaling as modified whenever
# it gets away # it gets away
self.assertTrue(self.notModified(1)) self.assertTrue(self.notModified(1))
f = open(self.name, 'a') f = open(self.name, 'ab')
self.assertTrue(self.isModified()) self.assertTrue(self.isModified())
self.assertTrue(self.notModified()) self.assertTrue(self.notModified())
mtimesleep() mtimesleep()
f.write("line%d\n" % i) f.write(b"line%d\n" % i)
f.flush() f.flush()
self.assertTrue(self.isModified()) self.assertTrue(self.isModified())
self.assertTrue(self.notModified()) self.assertTrue(self.notModified())
@ -1077,7 +1084,7 @@ def get_monitor_failures_testcase(Filter_):
self.filter = self.name = 'NA' self.filter = self.name = 'NA'
self.name = '%s-%d' % (testclass_name, self.count) self.name = '%s-%d' % (testclass_name, self.count)
MonitorFailures.count += 1 # so we have unique filenames across tests MonitorFailures.count += 1 # so we have unique filenames across tests
self.file = open(self.name, 'a') self.file = open(self.name, 'ab')
self.jail = DummyJail() self.jail = DummyJail()
self.filter = Filter_(self.jail) self.filter = Filter_(self.jail)
# mock-up common error to find catched unhandled exceptions: # mock-up common error to find catched unhandled exceptions: