testcases extended and observer optimized to run test cases faster;

code review
pull/716/head
sebres 2014-06-07 04:37:06 +02:00
parent e7bd8ed619
commit bb0a181056
6 changed files with 154 additions and 44 deletions

View File

@ -292,7 +292,7 @@ class Actions(JailThread, Mapping):
if self.__banManager.addBanTicket(bTicket): if self.__banManager.addBanTicket(bTicket):
# report ticket to observer, to check time should be increased and hereafter observer writes ban to database (asynchronous) # report ticket to observer, to check time should be increased and hereafter observer writes ban to database (asynchronous)
if not bTicket.getRestored(): if Observers.Main and not bTicket.getRestored():
Observers.Main.add('banFound', bTicket, self._jail, btime) Observers.Main.add('banFound', bTicket, self._jail, btime)
logSys.notice("[%s] %sBan %s (%d # %s -> %s)", self._jail.name, ('' if not bTicket.getRestored() else 'Restore '), logSys.notice("[%s] %sBan %s (%d # %s -> %s)", self._jail.name, ('' if not bTicket.getRestored() else 'Restore '),
aInfo["ip"], bTicket.getBanCount()+(1 if not bTicket.getRestored() else 0), *logtime) aInfo["ip"], bTicket.getBanCount()+(1 if not bTicket.getRestored() else 0), *logtime)

View File

@ -400,12 +400,12 @@ class Fail2BanDb(object):
#TODO: Implement data parts once arbitrary match keys completed #TODO: Implement data parts once arbitrary match keys completed
cur.execute( cur.execute(
"INSERT INTO bans(jail, ip, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)", "INSERT INTO bans(jail, ip, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)",
(jail.name, ticket.getIP(), ticket.getTime(), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount() + 1, (jail.name, ticket.getIP(), ticket.getTime(), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount(),
{"matches": ticket.getMatches(), {"matches": ticket.getMatches(),
"failures": ticket.getAttempt()})) "failures": ticket.getAttempt()}))
cur.execute( cur.execute(
"INSERT OR REPLACE INTO bips(ip, jail, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)", "INSERT OR REPLACE INTO bips(ip, jail, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)",
(ticket.getIP(), jail.name, ticket.getTime(), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount() + 1, (ticket.getIP(), jail.name, ticket.getTime(), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount(),
{"matches": ticket.getMatches(), {"matches": ticket.getMatches(),
"failures": ticket.getAttempt()})) "failures": ticket.getAttempt()}))
@ -558,11 +558,6 @@ class Fail2BanDb(object):
return cur.execute(query, queryArgs) return cur.execute(query, queryArgs)
def getCurrentBans(self, jail = None, ip = None, forbantime=None, fromtime=None): def getCurrentBans(self, jail = None, ip = None, forbantime=None, fromtime=None):
if forbantime is None and jail is not None:
cacheKey = (ip, jail)
if cacheKey in self._bansMergedCache:
return self._bansMergedCache[cacheKey]
tickets = [] tickets = []
ticket = None ticket = None
@ -584,8 +579,6 @@ class Fail2BanDb(object):
ticket.setAttempt(failures) ticket.setAttempt(failures)
tickets.append(ticket) tickets.append(ticket)
if forbantime is None and jail is not None:
self._bansMergedCache[cacheKey] = tickets if ip is None else ticket
return tickets if ip is None else ticket return tickets if ip is None else ticket
def _cleanjails(self, cur): def _cleanjails(self, cur):

View File

@ -427,7 +427,8 @@ class Filter(JailThread):
tick = FailTicket(ip, unixTime, lines) tick = FailTicket(ip, unixTime, lines)
self.failManager.addFailure(tick) self.failManager.addFailure(tick)
# report to observer - failure was found, for possibly increasing of it retry counter (asynchronous) # report to observer - failure was found, for possibly increasing of it retry counter (asynchronous)
Observers.Main.add('failureFound', self.failManager, self.jail, tick) if Observers.Main:
Observers.Main.add('failureFound', self.failManager, self.jail, tick)
## ##
# Returns true if the line should be ignored. # Returns true if the line should be ignored.

View File

@ -70,7 +70,6 @@ class ObserverThread(threading.Thread):
## but so we can later do some service "events" occurred infrequently directly in main loop of observer (not using queue) ## but so we can later do some service "events" occurred infrequently directly in main loop of observer (not using queue)
self.sleeptime = 60 self.sleeptime = 60
# #
self._started = False
self._timers = {} self._timers = {}
self._paused = False self._paused = False
self.__db = None self.__db = None
@ -124,11 +123,11 @@ class ObserverThread(threading.Thread):
t.start() t.start()
def pulse_notify(self): def pulse_notify(self):
"""Notify wakeup (sets and resets notify event) """Notify wakeup (sets /and resets/ notify event)
""" """
if not self._paused and self._notify: if not self._paused and self._notify:
self._notify.set() self._notify.set()
self._notify.clear() #self._notify.clear()
def add(self, *event): def add(self, *event):
"""Add a event to queue and notify thread to wake up. """Add a event to queue and notify thread to wake up.
@ -138,6 +137,13 @@ class ObserverThread(threading.Thread):
self._queue.append(event) self._queue.append(event)
self.pulse_notify() self.pulse_notify()
def add_wn(self, *event):
"""Add a event to queue withouth notifying thread to wake up.
"""
## lock and add new event to queue:
with self._queue_lock:
self._queue.append(event)
def call_lambda(self, l, *args): def call_lambda(self, l, *args):
l(*args) l(*args)
@ -168,6 +174,7 @@ class ObserverThread(threading.Thread):
'is_active': self.is_active, 'is_active': self.is_active,
'start': self.start, 'start': self.start,
'stop': self.stop, 'stop': self.stop,
'nop': lambda:(),
'shutdown': lambda:() 'shutdown': lambda:()
} }
try: try:
@ -177,13 +184,13 @@ class ObserverThread(threading.Thread):
while self.active: while self.active:
## going sleep, wait for events (in queue) ## going sleep, wait for events (in queue)
self.idle = True self.idle = True
self._notify.wait(self.sleeptime) n = self._notify
# does not clear notify event here - we use pulse (and clear it inside) ... if n:
# ## wake up - reset signal now (we don't need it so long as we reed from queue) n.wait(self.sleeptime)
# if self._notify: ## wake up - reset signal now (we don't need it so long as we reed from queue)
# self._notify.clear() n.clear()
if self._paused: if self._paused:
continue continue
self.idle = False self.idle = False
## check events available and execute all events from queue ## check events available and execute all events from queue
while not self._paused: while not self._paused:
@ -203,13 +210,14 @@ class ObserverThread(threading.Thread):
#logSys.error('%s', e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG) #logSys.error('%s', e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
logSys.error('%s', e, exc_info=True) logSys.error('%s', e, exc_info=True)
## end of main loop - exit ## end of main loop - exit
logSys.info("Observer stopped, %s events remaining.", len(self._queue))
#print("Observer stopped, %s events remaining." % len(self._queue))
except Exception as e: except Exception as e:
logSys.error('Observer stopped after error: %s', e, exc_info=True) logSys.error('Observer stopped after error: %s', e, exc_info=True)
#print("Observer stopped with error: %s" % str(e)) #print("Observer stopped with error: %s" % str(e))
self.idle = True # clear all events - exit, for possible calls of wait_empty:
return True with self._queue_lock:
logSys.info("Observer stopped, %s events remaining.", len(self._queue)) self._queue = []
#print("Observer stopped, %s events remaining." % len(self._queue))
self.idle = True self.idle = True
return True return True
@ -230,16 +238,22 @@ class ObserverThread(threading.Thread):
super(ObserverThread, self).start() super(ObserverThread, self).start()
def stop(self): def stop(self):
logSys.info("Observer stop ...") if self.active and self._notify:
#print("Observer stop ....") wtime = 5
self.active = False logSys.info("Observer stop ... try to end queue %s seconds", wtime)
if self._notify: #print("Observer stop ....")
# just add shutdown job to make possible wait later until full (events remaining) # just add shutdown job to make possible wait later until full (events remaining)
self.add('shutdown') self.add_wn('shutdown')
self.pulse_notify() #don't pulse - just set, because we will delete it hereafter (sometimes not wakeup)
n = self._notify
self._notify.set()
#self.pulse_notify()
self._notify = None self._notify = None
# wait max 5 seconds until full (events remaining) self.active = False
self.wait_empty(5) # wait max wtime seconds until full (events remaining)
self.wait_empty(wtime)
n.clear()
self.wait_idle(0.5)
@property @property
def is_full(self): def is_full(self):
@ -249,14 +263,16 @@ class ObserverThread(threading.Thread):
def wait_empty(self, sleeptime=None): def wait_empty(self, sleeptime=None):
"""Wait observer is running and returns if observer has no more events (queue is empty) """Wait observer is running and returns if observer has no more events (queue is empty)
""" """
if not self.is_full: # block queue with not operation to be sure all really jobs are executed if nop goes from queue :
return True self._queue.append(('nop',))
if sleeptime is not None: if sleeptime is not None:
e = MyTime.time() + sleeptime e = MyTime.time() + sleeptime
while self.is_full: while self.is_full:
if sleeptime is not None and MyTime.time() > e: if sleeptime is not None and MyTime.time() > e:
break break
time.sleep(0.1) time.sleep(0.01)
# wait idle to be sure the last queue element is processed (because pop event before processing it) :
self.wait_idle(0.01)
return not self.is_full return not self.is_full
@ -271,7 +287,7 @@ class ObserverThread(threading.Thread):
while not self.idle: while not self.idle:
if sleeptime is not None and MyTime.time() > e: if sleeptime is not None and MyTime.time() > e:
break break
time.sleep(0.1) time.sleep(0.01)
return self.idle return self.idle
@property @property
@ -443,6 +459,8 @@ class ObserverThread(threading.Thread):
return False return False
else: else:
logtime = ('permanent', 'infinite') logtime = ('permanent', 'infinite')
# increment count:
ticket.incrBanCount()
# if ban time was prolonged - log again with new ban time: # if ban time was prolonged - log again with new ban time:
if btime != oldbtime: if btime != oldbtime:
logSys.notice("[%s] Increase Ban %s (%d # %s -> %s)", jail.name, logSys.notice("[%s] Increase Ban %s (%d # %s -> %s)", jail.name,

View File

@ -88,6 +88,9 @@ class Ticket:
def setBanCount(self, value): def setBanCount(self, value):
self.__banCount = value; self.__banCount = value;
def incrBanCount(self, value = 1):
self.__banCount += value;
def getBanCount(self): def getBanCount(self):
return self.__banCount; return self.__banCount;

View File

@ -32,6 +32,7 @@ import time
from ..server.mytime import MyTime from ..server.mytime import MyTime
from ..server.ticket import FailTicket from ..server.ticket import FailTicket
from ..server.failmanager import FailManager
from ..server.observer import Observers, ObserverThread from ..server.observer import Observers, ObserverThread
from .utils import LogCaptureTestCase from .utils import LogCaptureTestCase
from .dummyjail import DummyJail from .dummyjail import DummyJail
@ -174,7 +175,8 @@ class BanTimeIncr(LogCaptureTestCase):
a.setBanTimeExtra('rndtime', None) a.setBanTimeExtra('rndtime', None)
class BanTimeIncrDB(LogCaptureTestCase): class BanTimeIncrDB(unittest.TestCase):
#class BanTimeIncrDB(LogCaptureTestCase):
def setUp(self): def setUp(self):
"""Call before every test case.""" """Call before every test case."""
@ -187,8 +189,10 @@ class BanTimeIncrDB(LogCaptureTestCase):
return return
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_") _, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
self.db = Fail2BanDb(self.dbFilename) self.db = Fail2BanDb(self.dbFilename)
self.jail = None self.jail = DummyJail()
self.jail.database = self.db
self.Observer = ObserverThread() self.Observer = ObserverThread()
Observers.Main = self.Observer
def tearDown(self): def tearDown(self):
"""Call after every test case.""" """Call after every test case."""
@ -196,6 +200,8 @@ class BanTimeIncrDB(LogCaptureTestCase):
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return
# Cleanup # Cleanup
self.Observer.stop()
Observers.Main = None
os.remove(self.dbFilename) os.remove(self.dbFilename)
def incrBanTime(self, ticket, banTime=None): def incrBanTime(self, ticket, banTime=None):
@ -211,9 +217,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
def testBanTimeIncr(self): def testBanTimeIncr(self):
if Fail2BanDb is None: # pragma: no cover if Fail2BanDb is None: # pragma: no cover
return return
jail = DummyJail() jail = self.jail
self.jail = jail
jail.database = self.db
self.db.addJail(jail) self.db.addJail(jail)
# we tests with initial ban time = 10 seconds: # we tests with initial ban time = 10 seconds:
jail.actions.setBanTime(10) jail.actions.setBanTime(10)
@ -229,6 +233,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
[10, 10, 10] [10, 10, 10]
) )
# add a ticket banned # add a ticket banned
ticket.incrBanCount()
self.db.addBan(jail, ticket) self.db.addBan(jail, ticket)
# get a ticket already banned in this jail: # get a ticket already banned in this jail:
self.assertEqual( self.assertEqual(
@ -238,6 +243,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
# incr time and ban a ticket again : # incr time and ban a ticket again :
ticket.setTime(stime + 15) ticket.setTime(stime + 15)
self.assertEqual(self.incrBanTime(ticket, 10), 20) self.assertEqual(self.incrBanTime(ticket, 10), 20)
ticket.incrBanCount()
self.db.addBan(jail, ticket) self.db.addBan(jail, ticket)
# get a ticket already banned in this jail: # get a ticket already banned in this jail:
self.assertEqual( self.assertEqual(
@ -274,6 +280,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
ticket.setTime(stime + lastBanTime + 5) ticket.setTime(stime + lastBanTime + 5)
banTime = self.incrBanTime(ticket, 10) banTime = self.incrBanTime(ticket, 10)
self.assertEqual(banTime, lastBanTime * 2) self.assertEqual(banTime, lastBanTime * 2)
ticket.incrBanCount()
self.db.addBan(jail, ticket) self.db.addBan(jail, ticket)
lastBanTime = banTime lastBanTime = banTime
# increase again, but the last multiplier reached (time not increased): # increase again, but the last multiplier reached (time not increased):
@ -281,15 +288,18 @@ class BanTimeIncrDB(LogCaptureTestCase):
banTime = self.incrBanTime(ticket, 10) banTime = self.incrBanTime(ticket, 10)
self.assertNotEqual(banTime, lastBanTime * 2) self.assertNotEqual(banTime, lastBanTime * 2)
self.assertEqual(banTime, lastBanTime) self.assertEqual(banTime, lastBanTime)
ticket.incrBanCount()
self.db.addBan(jail, ticket) self.db.addBan(jail, ticket)
lastBanTime = banTime lastBanTime = banTime
# add two tickets from yesterday: one unbanned (bantime already out-dated): # add two tickets from yesterday: one unbanned (bantime already out-dated):
ticket2 = FailTicket(ip+'2', stime-24*60*60, []) ticket2 = FailTicket(ip+'2', stime-24*60*60, [])
ticket2.setBanTime(12*60*60) ticket2.setBanTime(12*60*60)
ticket2.incrBanCount()
self.db.addBan(jail, ticket2) self.db.addBan(jail, ticket2)
# and one from yesterday also, but still currently banned : # and one from yesterday also, but still currently banned :
ticket2 = FailTicket(ip+'1', stime-24*60*60, []) ticket2 = FailTicket(ip+'1', stime-24*60*60, [])
ticket2.setBanTime(36*60*60) ticket2.setBanTime(36*60*60)
ticket2.incrBanCount()
self.db.addBan(jail, ticket2) self.db.addBan(jail, ticket2)
# search currently banned: # search currently banned:
restored_tickets = self.db.getCurrentBans(fromtime=stime) restored_tickets = self.db.getCurrentBans(fromtime=stime)
@ -331,6 +341,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
# get currently banned pis with permanent one: # get currently banned pis with permanent one:
ticket.setBanTime(-1) ticket.setBanTime(-1)
ticket.incrBanCount()
self.db.addBan(jail, ticket) self.db.addBan(jail, ticket)
restored_tickets = self.db.getCurrentBans(fromtime=stime) restored_tickets = self.db.getCurrentBans(fromtime=stime)
self.assertEqual(len(restored_tickets), 3) self.assertEqual(len(restored_tickets), 3)
@ -344,6 +355,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
self.assertEqual(len(restored_tickets), 3) self.assertEqual(len(restored_tickets), 3)
# set short time and purge again: # set short time and purge again:
ticket.setBanTime(600) ticket.setBanTime(600)
ticket.incrBanCount()
self.db.addBan(jail, ticket) self.db.addBan(jail, ticket)
self.db.purge() self.db.purge()
# this old ticket should be removed now: # this old ticket should be removed now:
@ -373,10 +385,12 @@ class BanTimeIncrDB(LogCaptureTestCase):
self.db.addJail(jail2) self.db.addJail(jail2)
ticket1 = FailTicket(ip, stime, []) ticket1 = FailTicket(ip, stime, [])
ticket1.setBanTime(6000) ticket1.setBanTime(6000)
ticket1.incrBanCount()
self.db.addBan(jail1, ticket1) self.db.addBan(jail1, ticket1)
ticket2 = FailTicket(ip, stime-6000, []) ticket2 = FailTicket(ip, stime-6000, [])
ticket2.setBanTime(12000) ticket2.setBanTime(12000)
ticket2.setBanCount(1) ticket2.setBanCount(1)
ticket2.incrBanCount()
self.db.addBan(jail2, ticket2) self.db.addBan(jail2, ticket2)
restored_tickets = self.db.getCurrentBans(jail=jail1, fromtime=stime) restored_tickets = self.db.getCurrentBans(jail=jail1, fromtime=stime)
self.assertEqual(len(restored_tickets), 1) self.assertEqual(len(restored_tickets), 1)
@ -402,6 +416,86 @@ class BanTimeIncrDB(LogCaptureTestCase):
self.assertEqual(row, (3, stime, 18000)) self.assertEqual(row, (3, stime, 18000))
break break
def testObserver(self):
if Fail2BanDb is None: # pragma: no cover
return
jail = self.jail
self.db.addJail(jail)
# we tests with initial ban time = 10 seconds:
jail.actions.setBanTime(10)
jail.setBanTimeExtra('increment', 'true')
# observer / database features:
obs = Observers.Main
obs.start()
obs.db_set(self.db)
# wait for start ready
obs.add('nop')
obs.wait_empty(5)
# purge database right now, but using timer, to test it also:
self.db._purgeAge = -240*60*60
obs.add_named_timer('DB_PURGE', 0.001, 'db_purge')
# wait for timer ready
time.sleep(0.025)
# wait for ready
obs.add('nop')
obs.wait_empty(5)
stime = int(MyTime.time())
# completelly empty ?
tickets = self.db.getBans()
self.assertEqual(tickets, [])
# add failure:
ip = "127.0.0.2"
ticket = FailTicket(ip, stime-120, [])
failManager = FailManager()
failManager.setMaxRetry(3)
for i in xrange(3):
failManager.addFailure(ticket)
obs.add('failureFound', failManager, jail, ticket)
obs.wait_empty(5)
self.assertEqual(ticket.getBanCount(), 0)
# check still not ban :
self.assertTrue(not jail.getFailTicket())
# add manually 4th times banned (added to bips - make ip bad):
ticket.setBanCount(4)
self.db.addBan(self.jail, ticket)
restored_tickets = self.db.getCurrentBans(jail=jail, fromtime=stime-120)
self.assertEqual(len(restored_tickets), 1)
# check again, new ticket, new failmanager:
ticket = FailTicket(ip, stime, [])
failManager = FailManager()
failManager.setMaxRetry(3)
# add once only - but bad - should be banned:
failManager.addFailure(ticket)
obs.add('failureFound', failManager, self.jail, ticket)
obs.wait_empty(5)
# wait until ticket transfered from failmanager into jail:
i = 50
while True:
ticket2 = jail.getFailTicket()
if ticket2:
break
time.sleep(0.1)
# check ticket and failure count:
self.assertFalse(not ticket2)
self.assertEqual(ticket2.getAttempt(), failManager.getMaxRetry())
# add this ticket to ban (use observer only without ban manager):
obs.add('banFound', ticket2, jail, 10)
obs.wait_empty(5)
# increased?
self.assertEqual(ticket2.getBanTime(), 160)
self.assertEqual(ticket2.getBanCount(), 5)
# check prolonged in database also :
restored_tickets = self.db.getCurrentBans(jail=jail, fromtime=stime)
self.assertEqual(len(restored_tickets), 1)
self.assertEqual(restored_tickets[0].getBanTime(), 160)
self.assertEqual(restored_tickets[0].getBanCount(), 5)
# stop observer
obs.stop()
class ObserverTest(unittest.TestCase): class ObserverTest(unittest.TestCase):
@ -419,16 +513,17 @@ class ObserverTest(unittest.TestCase):
obs.start() obs.start()
# wait for idle # wait for idle
obs.wait_idle(0.1) obs.wait_idle(0.1)
# observer will sleep 0.5 second (in busy state): # observer will replace test set:
o = set(['test']) o = set(['test'])
obs.add('call', o.clear) obs.add('call', o.clear)
obs.add('call', o.add, 'test2') obs.add('call', o.add, 'test2')
# wait for observer ready:
obs.wait_empty(1) obs.wait_empty(1)
self.assertFalse(obs.is_full) self.assertFalse(obs.is_full)
self.assertEqual(o, set(['test2'])) self.assertEqual(o, set(['test2']))
# observer makes pause # observer makes pause
obs.paused = True obs.paused = True
# observer will sleep 0.5 second after pause ends: # observer will replace test set, but first after pause ends:
obs.add('call', o.clear) obs.add('call', o.clear)
obs.add('call', o.add, 'test3') obs.add('call', o.add, 'test3')
obs.wait_empty(0.25) obs.wait_empty(0.25)