more stable handling of json dump/load different encoded strings for older python versions;

extended test cases (more precise, python version insensitive, etc.)
pull/975/head
sebres 2015-02-25 18:16:06 +01:00 committed by sebres
parent 2bfe22aa66
commit 5ab30c88c2
2 changed files with 50 additions and 20 deletions

View File

@ -45,24 +45,38 @@ if sys.version_info >= (3,):
logSys.error('json dumps failed: %s', e) logSys.error('json dumps failed: %s', e)
x = '{}' x = '{}'
return x return x
def _json_loads_safe(x):
try:
x = json.loads(x.decode(locale.getpreferredencoding(), 'replace'))
except Exception, e: # pragma: no cover
logSys.error('json loads failed: %s', e)
x = {}
return x
else: else:
def _normalize(x):
if isinstance(x, dict):
return dict((_normalize(k), _normalize(v)) for k, v in x.iteritems())
elif isinstance(x, list):
return [_normalize(element) for element in x]
elif isinstance(x, unicode):
return x.encode(locale.getpreferredencoding())
else:
return x
def _json_dumps_safe(x): def _json_dumps_safe(x):
try: try:
x = json.dumps(x, ensure_ascii=False).decode( x = json.dumps(_normalize(x), ensure_ascii=False).decode(
locale.getpreferredencoding(), 'replace') locale.getpreferredencoding(), 'replace')
except Exception, e: # pragma: no cover except Exception, e: # pragma: no cover
logSys.error('json dumps failed: %s', e) logSys.error('json dumps failed: %s', e)
x = '{}' x = '{}'
return x return x
def _json_loads_safe(x):
def _json_loads_safe(x): try:
try: x = _normalize(json.loads(x.decode(locale.getpreferredencoding(), 'replace')))
x = json.loads(x.decode( except Exception, e: # pragma: no cover
locale.getpreferredencoding(), 'replace')) logSys.error('json loads failed: %s', e)
except Exception, e: # pragma: no cover x = {}
logSys.error('json loads failed: %s', e) return x
x = {}
return x
sqlite3.register_adapter(dict, _json_dumps_safe) sqlite3.register_adapter(dict, _json_dumps_safe)
sqlite3.register_converter("JSON", _json_loads_safe) sqlite3.register_converter("JSON", _json_loads_safe)
@ -449,8 +463,8 @@ class Fail2BanDb(object):
tickets = [] tickets = []
for ip, timeofban, data in self._getBans(**kwargs): for ip, timeofban, data in self._getBans(**kwargs):
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
tickets.append(FailTicket(ip, timeofban, data['matches'])) tickets.append(FailTicket(ip, timeofban, data.get('matches')))
tickets[-1].setAttempt(data['failures']) tickets[-1].setAttempt(data.get('failures', 1))
return tickets return tickets
def getBansMerged(self, ip=None, jail=None, bantime=None): def getBansMerged(self, ip=None, jail=None, bantime=None):
@ -502,8 +516,8 @@ class Fail2BanDb(object):
prev_banip = banip prev_banip = banip
matches = [] matches = []
failures = 0 failures = 0
matches.extend(data['matches']) matches.extend(data.get('matches', []))
failures += data['failures'] failures += data.get('failures', 1)
prev_timeofban = timeofban prev_timeofban = timeofban
ticket = FailTicket(banip, prev_timeofban, matches) ticket = FailTicket(banip, prev_timeofban, matches)
ticket.setAttempt(failures) ticket.setAttempt(failures)

View File

@ -181,15 +181,31 @@ class DatabaseTest(LogCaptureTestCase):
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return
self.testAddJail() self.testAddJail()
ticket = FailTicket("127.0.0.1", 0, ['... user "\xd1\xe2\xe5\xf2\xe0" ...']) # invalid + valid, invalid + valid unicode, invalid + valid dual converted (like in filter:readline by fallback) ...
self.db.addBan(self.jail, ticket) tickets = [
FailTicket("127.0.0.1", 0, ['user "\xd1\xe2\xe5\xf2\xe0"', 'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"']),
FailTicket("127.0.0.2", 0, ['user "\xd1\xe2\xe5\xf2\xe0"', u'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"']),
FailTicket("127.0.0.3", 0, ['user "\xd1\xe2\xe5\xf2\xe0"', b'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"'.decode('utf-8', 'replace')])
]
self.db.addBan(self.jail, tickets[0])
self.db.addBan(self.jail, tickets[1])
self.db.addBan(self.jail, tickets[2])
self.assertEqual(len(self.db.getBans(jail=self.jail)), 1) readtickets = self.db.getBans(jail=self.jail)
readticket = self.db.getBans(jail=self.jail)[0] self.assertEqual(len(readtickets), 3)
## python 2 or 3 : ## python 2 or 3 :
invstr = u'user "\ufffd\ufffd\ufffd\ufffd\ufffd"'.encode('utf-8', 'replace')
self.assertTrue( self.assertTrue(
readticket == FailTicket("127.0.0.1", 0, [u'... user "\ufffd\ufffd\ufffd\ufffd\ufffd" ...']) readtickets[0] == FailTicket("127.0.0.1", 0, [invstr, 'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"'])
or readticket == ticket or readtickets[0] == tickets[0]
)
self.assertTrue(
readtickets[1] == FailTicket("127.0.0.2", 0, [invstr, u'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"'.encode('utf-8', 'replace')])
or readtickets[1] == tickets[1]
)
self.assertTrue(
readtickets[2] == FailTicket("127.0.0.3", 0, [invstr, 'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"'])
or readtickets[2] == tickets[2]
) )
def testDelBan(self): def testDelBan(self):