mirror of https://github.com/fail2ban/fail2ban
testcases extended and observer optimized to run test cases faster;
code reviewpull/716/head
parent
e7bd8ed619
commit
bb0a181056
|
@ -292,7 +292,7 @@ class Actions(JailThread, Mapping):
|
|||
|
||||
if self.__banManager.addBanTicket(bTicket):
|
||||
# report ticket to observer, to check time should be increased and hereafter observer writes ban to database (asynchronous)
|
||||
if not bTicket.getRestored():
|
||||
if Observers.Main and not bTicket.getRestored():
|
||||
Observers.Main.add('banFound', bTicket, self._jail, btime)
|
||||
logSys.notice("[%s] %sBan %s (%d # %s -> %s)", self._jail.name, ('' if not bTicket.getRestored() else 'Restore '),
|
||||
aInfo["ip"], bTicket.getBanCount()+(1 if not bTicket.getRestored() else 0), *logtime)
|
||||
|
|
|
@ -400,12 +400,12 @@ class Fail2BanDb(object):
|
|||
#TODO: Implement data parts once arbitrary match keys completed
|
||||
cur.execute(
|
||||
"INSERT INTO bans(jail, ip, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)",
|
||||
(jail.name, ticket.getIP(), ticket.getTime(), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount() + 1,
|
||||
(jail.name, ticket.getIP(), ticket.getTime(), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount(),
|
||||
{"matches": ticket.getMatches(),
|
||||
"failures": ticket.getAttempt()}))
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO bips(ip, jail, timeofban, bantime, bancount, data) VALUES(?, ?, ?, ?, ?, ?)",
|
||||
(ticket.getIP(), jail.name, ticket.getTime(), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount() + 1,
|
||||
(ticket.getIP(), jail.name, ticket.getTime(), ticket.getBanTime(jail.actions.getBanTime()), ticket.getBanCount(),
|
||||
{"matches": ticket.getMatches(),
|
||||
"failures": ticket.getAttempt()}))
|
||||
|
||||
|
@ -558,11 +558,6 @@ class Fail2BanDb(object):
|
|||
return cur.execute(query, queryArgs)
|
||||
|
||||
def getCurrentBans(self, jail = None, ip = None, forbantime=None, fromtime=None):
|
||||
if forbantime is None and jail is not None:
|
||||
cacheKey = (ip, jail)
|
||||
if cacheKey in self._bansMergedCache:
|
||||
return self._bansMergedCache[cacheKey]
|
||||
|
||||
tickets = []
|
||||
ticket = None
|
||||
|
||||
|
@ -584,8 +579,6 @@ class Fail2BanDb(object):
|
|||
ticket.setAttempt(failures)
|
||||
tickets.append(ticket)
|
||||
|
||||
if forbantime is None and jail is not None:
|
||||
self._bansMergedCache[cacheKey] = tickets if ip is None else ticket
|
||||
return tickets if ip is None else ticket
|
||||
|
||||
def _cleanjails(self, cur):
|
||||
|
|
|
@ -427,7 +427,8 @@ class Filter(JailThread):
|
|||
tick = FailTicket(ip, unixTime, lines)
|
||||
self.failManager.addFailure(tick)
|
||||
# report to observer - failure was found, for possibly increasing of it retry counter (asynchronous)
|
||||
Observers.Main.add('failureFound', self.failManager, self.jail, tick)
|
||||
if Observers.Main:
|
||||
Observers.Main.add('failureFound', self.failManager, self.jail, tick)
|
||||
|
||||
##
|
||||
# Returns true if the line should be ignored.
|
||||
|
|
|
@ -70,7 +70,6 @@ class ObserverThread(threading.Thread):
|
|||
## but so we can later do some service "events" occurred infrequently directly in main loop of observer (not using queue)
|
||||
self.sleeptime = 60
|
||||
#
|
||||
self._started = False
|
||||
self._timers = {}
|
||||
self._paused = False
|
||||
self.__db = None
|
||||
|
@ -124,11 +123,11 @@ class ObserverThread(threading.Thread):
|
|||
t.start()
|
||||
|
||||
def pulse_notify(self):
|
||||
"""Notify wakeup (sets and resets notify event)
|
||||
"""Notify wakeup (sets /and resets/ notify event)
|
||||
"""
|
||||
if not self._paused and self._notify:
|
||||
self._notify.set()
|
||||
self._notify.clear()
|
||||
#self._notify.clear()
|
||||
|
||||
def add(self, *event):
|
||||
"""Add a event to queue and notify thread to wake up.
|
||||
|
@ -138,6 +137,13 @@ class ObserverThread(threading.Thread):
|
|||
self._queue.append(event)
|
||||
self.pulse_notify()
|
||||
|
||||
def add_wn(self, *event):
|
||||
"""Add a event to queue withouth notifying thread to wake up.
|
||||
"""
|
||||
## lock and add new event to queue:
|
||||
with self._queue_lock:
|
||||
self._queue.append(event)
|
||||
|
||||
def call_lambda(self, l, *args):
|
||||
l(*args)
|
||||
|
||||
|
@ -168,6 +174,7 @@ class ObserverThread(threading.Thread):
|
|||
'is_active': self.is_active,
|
||||
'start': self.start,
|
||||
'stop': self.stop,
|
||||
'nop': lambda:(),
|
||||
'shutdown': lambda:()
|
||||
}
|
||||
try:
|
||||
|
@ -177,13 +184,13 @@ class ObserverThread(threading.Thread):
|
|||
while self.active:
|
||||
## going sleep, wait for events (in queue)
|
||||
self.idle = True
|
||||
self._notify.wait(self.sleeptime)
|
||||
# does not clear notify event here - we use pulse (and clear it inside) ...
|
||||
# ## wake up - reset signal now (we don't need it so long as we reed from queue)
|
||||
# if self._notify:
|
||||
# self._notify.clear()
|
||||
if self._paused:
|
||||
continue
|
||||
n = self._notify
|
||||
if n:
|
||||
n.wait(self.sleeptime)
|
||||
## wake up - reset signal now (we don't need it so long as we reed from queue)
|
||||
n.clear()
|
||||
if self._paused:
|
||||
continue
|
||||
self.idle = False
|
||||
## check events available and execute all events from queue
|
||||
while not self._paused:
|
||||
|
@ -203,13 +210,14 @@ class ObserverThread(threading.Thread):
|
|||
#logSys.error('%s', e, exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
|
||||
logSys.error('%s', e, exc_info=True)
|
||||
## end of main loop - exit
|
||||
logSys.info("Observer stopped, %s events remaining.", len(self._queue))
|
||||
#print("Observer stopped, %s events remaining." % len(self._queue))
|
||||
except Exception as e:
|
||||
logSys.error('Observer stopped after error: %s', e, exc_info=True)
|
||||
#print("Observer stopped with error: %s" % str(e))
|
||||
self.idle = True
|
||||
return True
|
||||
logSys.info("Observer stopped, %s events remaining.", len(self._queue))
|
||||
#print("Observer stopped, %s events remaining." % len(self._queue))
|
||||
# clear all events - exit, for possible calls of wait_empty:
|
||||
with self._queue_lock:
|
||||
self._queue = []
|
||||
self.idle = True
|
||||
return True
|
||||
|
||||
|
@ -230,16 +238,22 @@ class ObserverThread(threading.Thread):
|
|||
super(ObserverThread, self).start()
|
||||
|
||||
def stop(self):
|
||||
logSys.info("Observer stop ...")
|
||||
#print("Observer stop ....")
|
||||
self.active = False
|
||||
if self._notify:
|
||||
if self.active and self._notify:
|
||||
wtime = 5
|
||||
logSys.info("Observer stop ... try to end queue %s seconds", wtime)
|
||||
#print("Observer stop ....")
|
||||
# just add shutdown job to make possible wait later until full (events remaining)
|
||||
self.add('shutdown')
|
||||
self.pulse_notify()
|
||||
self.add_wn('shutdown')
|
||||
#don't pulse - just set, because we will delete it hereafter (sometimes not wakeup)
|
||||
n = self._notify
|
||||
self._notify.set()
|
||||
#self.pulse_notify()
|
||||
self._notify = None
|
||||
# wait max 5 seconds until full (events remaining)
|
||||
self.wait_empty(5)
|
||||
self.active = False
|
||||
# wait max wtime seconds until full (events remaining)
|
||||
self.wait_empty(wtime)
|
||||
n.clear()
|
||||
self.wait_idle(0.5)
|
||||
|
||||
@property
|
||||
def is_full(self):
|
||||
|
@ -249,14 +263,16 @@ class ObserverThread(threading.Thread):
|
|||
def wait_empty(self, sleeptime=None):
|
||||
"""Wait observer is running and returns if observer has no more events (queue is empty)
|
||||
"""
|
||||
if not self.is_full:
|
||||
return True
|
||||
# block queue with not operation to be sure all really jobs are executed if nop goes from queue :
|
||||
self._queue.append(('nop',))
|
||||
if sleeptime is not None:
|
||||
e = MyTime.time() + sleeptime
|
||||
while self.is_full:
|
||||
if sleeptime is not None and MyTime.time() > e:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
time.sleep(0.01)
|
||||
# wait idle to be sure the last queue element is processed (because pop event before processing it) :
|
||||
self.wait_idle(0.01)
|
||||
return not self.is_full
|
||||
|
||||
|
||||
|
@ -271,7 +287,7 @@ class ObserverThread(threading.Thread):
|
|||
while not self.idle:
|
||||
if sleeptime is not None and MyTime.time() > e:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
time.sleep(0.01)
|
||||
return self.idle
|
||||
|
||||
@property
|
||||
|
@ -443,6 +459,8 @@ class ObserverThread(threading.Thread):
|
|||
return False
|
||||
else:
|
||||
logtime = ('permanent', 'infinite')
|
||||
# increment count:
|
||||
ticket.incrBanCount()
|
||||
# if ban time was prolonged - log again with new ban time:
|
||||
if btime != oldbtime:
|
||||
logSys.notice("[%s] Increase Ban %s (%d # %s -> %s)", jail.name,
|
||||
|
|
|
@ -88,6 +88,9 @@ class Ticket:
|
|||
def setBanCount(self, value):
|
||||
self.__banCount = value;
|
||||
|
||||
def incrBanCount(self, value = 1):
|
||||
self.__banCount += value;
|
||||
|
||||
def getBanCount(self):
|
||||
return self.__banCount;
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ import time
|
|||
|
||||
from ..server.mytime import MyTime
|
||||
from ..server.ticket import FailTicket
|
||||
from ..server.failmanager import FailManager
|
||||
from ..server.observer import Observers, ObserverThread
|
||||
from .utils import LogCaptureTestCase
|
||||
from .dummyjail import DummyJail
|
||||
|
@ -174,7 +175,8 @@ class BanTimeIncr(LogCaptureTestCase):
|
|||
a.setBanTimeExtra('rndtime', None)
|
||||
|
||||
|
||||
class BanTimeIncrDB(LogCaptureTestCase):
|
||||
class BanTimeIncrDB(unittest.TestCase):
|
||||
#class BanTimeIncrDB(LogCaptureTestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Call before every test case."""
|
||||
|
@ -187,8 +189,10 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
return
|
||||
_, self.dbFilename = tempfile.mkstemp(".db", "fail2ban_")
|
||||
self.db = Fail2BanDb(self.dbFilename)
|
||||
self.jail = None
|
||||
self.jail = DummyJail()
|
||||
self.jail.database = self.db
|
||||
self.Observer = ObserverThread()
|
||||
Observers.Main = self.Observer
|
||||
|
||||
def tearDown(self):
|
||||
"""Call after every test case."""
|
||||
|
@ -196,6 +200,8 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
if Fail2BanDb is None: # pragma: no cover
|
||||
return
|
||||
# Cleanup
|
||||
self.Observer.stop()
|
||||
Observers.Main = None
|
||||
os.remove(self.dbFilename)
|
||||
|
||||
def incrBanTime(self, ticket, banTime=None):
|
||||
|
@ -211,9 +217,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
def testBanTimeIncr(self):
|
||||
if Fail2BanDb is None: # pragma: no cover
|
||||
return
|
||||
jail = DummyJail()
|
||||
self.jail = jail
|
||||
jail.database = self.db
|
||||
jail = self.jail
|
||||
self.db.addJail(jail)
|
||||
# we tests with initial ban time = 10 seconds:
|
||||
jail.actions.setBanTime(10)
|
||||
|
@ -229,6 +233,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
[10, 10, 10]
|
||||
)
|
||||
# add a ticket banned
|
||||
ticket.incrBanCount()
|
||||
self.db.addBan(jail, ticket)
|
||||
# get a ticket already banned in this jail:
|
||||
self.assertEqual(
|
||||
|
@ -238,6 +243,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
# incr time and ban a ticket again :
|
||||
ticket.setTime(stime + 15)
|
||||
self.assertEqual(self.incrBanTime(ticket, 10), 20)
|
||||
ticket.incrBanCount()
|
||||
self.db.addBan(jail, ticket)
|
||||
# get a ticket already banned in this jail:
|
||||
self.assertEqual(
|
||||
|
@ -274,6 +280,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
ticket.setTime(stime + lastBanTime + 5)
|
||||
banTime = self.incrBanTime(ticket, 10)
|
||||
self.assertEqual(banTime, lastBanTime * 2)
|
||||
ticket.incrBanCount()
|
||||
self.db.addBan(jail, ticket)
|
||||
lastBanTime = banTime
|
||||
# increase again, but the last multiplier reached (time not increased):
|
||||
|
@ -281,15 +288,18 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
banTime = self.incrBanTime(ticket, 10)
|
||||
self.assertNotEqual(banTime, lastBanTime * 2)
|
||||
self.assertEqual(banTime, lastBanTime)
|
||||
ticket.incrBanCount()
|
||||
self.db.addBan(jail, ticket)
|
||||
lastBanTime = banTime
|
||||
# add two tickets from yesterday: one unbanned (bantime already out-dated):
|
||||
ticket2 = FailTicket(ip+'2', stime-24*60*60, [])
|
||||
ticket2.setBanTime(12*60*60)
|
||||
ticket2.incrBanCount()
|
||||
self.db.addBan(jail, ticket2)
|
||||
# and one from yesterday also, but still currently banned :
|
||||
ticket2 = FailTicket(ip+'1', stime-24*60*60, [])
|
||||
ticket2.setBanTime(36*60*60)
|
||||
ticket2.incrBanCount()
|
||||
self.db.addBan(jail, ticket2)
|
||||
# search currently banned:
|
||||
restored_tickets = self.db.getCurrentBans(fromtime=stime)
|
||||
|
@ -331,6 +341,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
|
||||
# get currently banned pis with permanent one:
|
||||
ticket.setBanTime(-1)
|
||||
ticket.incrBanCount()
|
||||
self.db.addBan(jail, ticket)
|
||||
restored_tickets = self.db.getCurrentBans(fromtime=stime)
|
||||
self.assertEqual(len(restored_tickets), 3)
|
||||
|
@ -344,6 +355,7 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
self.assertEqual(len(restored_tickets), 3)
|
||||
# set short time and purge again:
|
||||
ticket.setBanTime(600)
|
||||
ticket.incrBanCount()
|
||||
self.db.addBan(jail, ticket)
|
||||
self.db.purge()
|
||||
# this old ticket should be removed now:
|
||||
|
@ -373,10 +385,12 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
self.db.addJail(jail2)
|
||||
ticket1 = FailTicket(ip, stime, [])
|
||||
ticket1.setBanTime(6000)
|
||||
ticket1.incrBanCount()
|
||||
self.db.addBan(jail1, ticket1)
|
||||
ticket2 = FailTicket(ip, stime-6000, [])
|
||||
ticket2.setBanTime(12000)
|
||||
ticket2.setBanCount(1)
|
||||
ticket2.incrBanCount()
|
||||
self.db.addBan(jail2, ticket2)
|
||||
restored_tickets = self.db.getCurrentBans(jail=jail1, fromtime=stime)
|
||||
self.assertEqual(len(restored_tickets), 1)
|
||||
|
@ -402,6 +416,86 @@ class BanTimeIncrDB(LogCaptureTestCase):
|
|||
self.assertEqual(row, (3, stime, 18000))
|
||||
break
|
||||
|
||||
def testObserver(self):
|
||||
if Fail2BanDb is None: # pragma: no cover
|
||||
return
|
||||
jail = self.jail
|
||||
self.db.addJail(jail)
|
||||
# we tests with initial ban time = 10 seconds:
|
||||
jail.actions.setBanTime(10)
|
||||
jail.setBanTimeExtra('increment', 'true')
|
||||
# observer / database features:
|
||||
obs = Observers.Main
|
||||
obs.start()
|
||||
obs.db_set(self.db)
|
||||
# wait for start ready
|
||||
obs.add('nop')
|
||||
obs.wait_empty(5)
|
||||
# purge database right now, but using timer, to test it also:
|
||||
self.db._purgeAge = -240*60*60
|
||||
obs.add_named_timer('DB_PURGE', 0.001, 'db_purge')
|
||||
# wait for timer ready
|
||||
time.sleep(0.025)
|
||||
# wait for ready
|
||||
obs.add('nop')
|
||||
obs.wait_empty(5)
|
||||
|
||||
stime = int(MyTime.time())
|
||||
# completelly empty ?
|
||||
tickets = self.db.getBans()
|
||||
self.assertEqual(tickets, [])
|
||||
|
||||
# add failure:
|
||||
ip = "127.0.0.2"
|
||||
ticket = FailTicket(ip, stime-120, [])
|
||||
failManager = FailManager()
|
||||
failManager.setMaxRetry(3)
|
||||
for i in xrange(3):
|
||||
failManager.addFailure(ticket)
|
||||
obs.add('failureFound', failManager, jail, ticket)
|
||||
obs.wait_empty(5)
|
||||
self.assertEqual(ticket.getBanCount(), 0)
|
||||
# check still not ban :
|
||||
self.assertTrue(not jail.getFailTicket())
|
||||
# add manually 4th times banned (added to bips - make ip bad):
|
||||
ticket.setBanCount(4)
|
||||
self.db.addBan(self.jail, ticket)
|
||||
restored_tickets = self.db.getCurrentBans(jail=jail, fromtime=stime-120)
|
||||
self.assertEqual(len(restored_tickets), 1)
|
||||
# check again, new ticket, new failmanager:
|
||||
ticket = FailTicket(ip, stime, [])
|
||||
failManager = FailManager()
|
||||
failManager.setMaxRetry(3)
|
||||
# add once only - but bad - should be banned:
|
||||
failManager.addFailure(ticket)
|
||||
obs.add('failureFound', failManager, self.jail, ticket)
|
||||
obs.wait_empty(5)
|
||||
# wait until ticket transfered from failmanager into jail:
|
||||
i = 50
|
||||
while True:
|
||||
ticket2 = jail.getFailTicket()
|
||||
if ticket2:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
# check ticket and failure count:
|
||||
self.assertFalse(not ticket2)
|
||||
self.assertEqual(ticket2.getAttempt(), failManager.getMaxRetry())
|
||||
|
||||
# add this ticket to ban (use observer only without ban manager):
|
||||
obs.add('banFound', ticket2, jail, 10)
|
||||
obs.wait_empty(5)
|
||||
# increased?
|
||||
self.assertEqual(ticket2.getBanTime(), 160)
|
||||
self.assertEqual(ticket2.getBanCount(), 5)
|
||||
|
||||
# check prolonged in database also :
|
||||
restored_tickets = self.db.getCurrentBans(jail=jail, fromtime=stime)
|
||||
self.assertEqual(len(restored_tickets), 1)
|
||||
self.assertEqual(restored_tickets[0].getBanTime(), 160)
|
||||
self.assertEqual(restored_tickets[0].getBanCount(), 5)
|
||||
|
||||
# stop observer
|
||||
obs.stop()
|
||||
|
||||
class ObserverTest(unittest.TestCase):
|
||||
|
||||
|
@ -419,16 +513,17 @@ class ObserverTest(unittest.TestCase):
|
|||
obs.start()
|
||||
# wait for idle
|
||||
obs.wait_idle(0.1)
|
||||
# observer will sleep 0.5 second (in busy state):
|
||||
# observer will replace test set:
|
||||
o = set(['test'])
|
||||
obs.add('call', o.clear)
|
||||
obs.add('call', o.add, 'test2')
|
||||
# wait for observer ready:
|
||||
obs.wait_empty(1)
|
||||
self.assertFalse(obs.is_full)
|
||||
self.assertEqual(o, set(['test2']))
|
||||
# observer makes pause
|
||||
obs.paused = True
|
||||
# observer will sleep 0.5 second after pause ends:
|
||||
# observer will replace test set, but first after pause ends:
|
||||
obs.add('call', o.clear)
|
||||
obs.add('call', o.add, 'test3')
|
||||
obs.wait_empty(0.25)
|
||||
|
|
Loading…
Reference in New Issue