diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py index dd48ceb3..4e38b6ed 100644 --- a/fail2ban/server/database.py +++ b/fail2ban/server/database.py @@ -45,24 +45,38 @@ if sys.version_info >= (3,): logSys.error('json dumps failed: %s', e) x = '{}' return x + def _json_loads_safe(x): + try: + x = json.loads(x.decode(locale.getpreferredencoding(), 'replace')) + except Exception, e: # pragma: no cover + logSys.error('json loads failed: %s', e) + x = {} + return x else: + def _normalize(x): + if isinstance(x, dict): + return dict((_normalize(k), _normalize(v)) for k, v in x.iteritems()) + elif isinstance(x, list): + return [_normalize(element) for element in x] + elif isinstance(x, unicode): + return x.encode(locale.getpreferredencoding()) + else: + return x def _json_dumps_safe(x): try: - x = json.dumps(x, ensure_ascii=False).decode( + x = json.dumps(_normalize(x), ensure_ascii=False).decode( locale.getpreferredencoding(), 'replace') except Exception, e: # pragma: no cover logSys.error('json dumps failed: %s', e) x = '{}' return x - -def _json_loads_safe(x): - try: - x = json.loads(x.decode( - locale.getpreferredencoding(), 'replace')) - except Exception, e: # pragma: no cover - logSys.error('json loads failed: %s', e) - x = {} - return x + def _json_loads_safe(x): + try: + x = _normalize(json.loads(x.decode(locale.getpreferredencoding(), 'replace'))) + except Exception, e: # pragma: no cover + logSys.error('json loads failed: %s', e) + x = {} + return x sqlite3.register_adapter(dict, _json_dumps_safe) sqlite3.register_converter("JSON", _json_loads_safe) @@ -449,8 +463,8 @@ class Fail2BanDb(object): tickets = [] for ip, timeofban, data in self._getBans(**kwargs): #TODO: Implement data parts once arbitrary match keys completed - tickets.append(FailTicket(ip, timeofban, data['matches'])) - tickets[-1].setAttempt(data['failures']) + tickets.append(FailTicket(ip, timeofban, data.get('matches'))) + tickets[-1].setAttempt(data.get('failures', 1)) return tickets def getBansMerged(self, ip=None, jail=None, bantime=None): @@ -502,8 +516,8 @@ class Fail2BanDb(object): prev_banip = banip matches = [] failures = 0 - matches.extend(data['matches']) - failures += data['failures'] + matches.extend(data.get('matches', [])) + failures += data.get('failures', 1) prev_timeofban = timeofban ticket = FailTicket(banip, prev_timeofban, matches) ticket.setAttempt(failures) diff --git a/fail2ban/tests/databasetestcase.py b/fail2ban/tests/databasetestcase.py index b56414ae..9665e322 100644 --- a/fail2ban/tests/databasetestcase.py +++ b/fail2ban/tests/databasetestcase.py @@ -181,15 +181,31 @@ class DatabaseTest(LogCaptureTestCase): if Fail2BanDb is None: # pragma: no cover return self.testAddJail() - ticket = FailTicket("127.0.0.1", 0, ['... user "\xd1\xe2\xe5\xf2\xe0" ...']) - self.db.addBan(self.jail, ticket) + # invalid + valid, invalid + valid unicode, invalid + valid dual converted (like in filter:readline by fallback) ... + tickets = [ + FailTicket("127.0.0.1", 0, ['user "\xd1\xe2\xe5\xf2\xe0"', 'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"']), + FailTicket("127.0.0.2", 0, ['user "\xd1\xe2\xe5\xf2\xe0"', u'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"']), + FailTicket("127.0.0.3", 0, ['user "\xd1\xe2\xe5\xf2\xe0"', b'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"'.decode('utf-8', 'replace')]) + ] + self.db.addBan(self.jail, tickets[0]) + self.db.addBan(self.jail, tickets[1]) + self.db.addBan(self.jail, tickets[2]) - self.assertEqual(len(self.db.getBans(jail=self.jail)), 1) - readticket = self.db.getBans(jail=self.jail)[0] + readtickets = self.db.getBans(jail=self.jail) + self.assertEqual(len(readtickets), 3) ## python 2 or 3 : + invstr = u'user "\ufffd\ufffd\ufffd\ufffd\ufffd"'.encode('utf-8', 'replace') self.assertTrue( - readticket == FailTicket("127.0.0.1", 0, [u'... user "\ufffd\ufffd\ufffd\ufffd\ufffd" ...']) - or readticket == ticket + readtickets[0] == FailTicket("127.0.0.1", 0, [invstr, 'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"']) + or readtickets[0] == tickets[0] + ) + self.assertTrue( + readtickets[1] == FailTicket("127.0.0.2", 0, [invstr, u'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"'.encode('utf-8', 'replace')]) + or readtickets[1] == tickets[1] + ) + self.assertTrue( + readtickets[2] == FailTicket("127.0.0.3", 0, [invstr, 'user "\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f"']) + or readtickets[2] == tickets[2] ) def testDelBan(self):