- Added more locking

git-svn-id: https://fail2ban.svn.sourceforge.net/svnroot/fail2ban/trunk@361 a942ae1a-1317-0410-a47c-b1dcaea8d605
0.x
Cyril Jaquier 2006-09-17 22:02:22 +00:00
parent 94d167e620
commit c80164b329
4 changed files with 194 additions and 78 deletions

View File

@ -61,7 +61,9 @@ class BanManager:
# @param value the time # @param value the time
def setBanTime(self, value): def setBanTime(self, value):
self.lock.acquire()
self.banTime = int(value) self.banTime = int(value)
self.lock.release()
## ##
# Get the ban time. # Get the ban time.
@ -70,7 +72,11 @@ class BanManager:
# @return the time # @return the time
def getBanTime(self): def getBanTime(self):
return self.banTime try:
self.lock.acquire()
return self.banTime
finally:
self.lock.release()
## ##
# Set the total number of banned address. # Set the total number of banned address.
@ -78,7 +84,9 @@ class BanManager:
# @param value total number # @param value total number
def setBanTotal(self, value): def setBanTotal(self, value):
self.lock.acquire()
self.banTotal = value self.banTotal = value
self.lock.release()
## ##
# Get the total number of banned address. # Get the total number of banned address.
@ -86,7 +94,11 @@ class BanManager:
# @return the total number # @return the total number
def getBanTotal(self): def getBanTotal(self):
return self.banTotal try:
self.lock.acquire()
return self.banTotal
finally:
self.lock.release()
## ##
# Create a ban ticket. # Create a ban ticket.
@ -98,12 +110,16 @@ class BanManager:
@staticmethod @staticmethod
def createBanTicket(ticket): def createBanTicket(ticket):
ip = ticket.getIP() try:
#lastTime = ticket.getTime() self.lock.acquire()
lastTime = time.time() ip = ticket.getIP()
banTicket = BanTicket(ip, lastTime) #lastTime = ticket.getTime()
banTicket.setAttempt(ticket.getAttempt()) lastTime = time.time()
return banTicket banTicket = BanTicket(ip, lastTime)
banTicket.setAttempt(ticket.getAttempt())
return banTicket
finally:
self.lock.release()
## ##
# Add a ban ticket. # Add a ban ticket.
@ -113,14 +129,15 @@ class BanManager:
# @return True if the IP address is not in the ban list # @return True if the IP address is not in the ban list
def addBanTicket(self, ticket): def addBanTicket(self, ticket):
self.lock.acquire() try:
if not self.inBanList(ticket): self.lock.acquire()
self.banList.append(ticket) if not self.inBanList(ticket):
self.banTotal += 1 self.banList.append(ticket)
self.banTotal += 1
return True
return False
finally:
self.lock.release() self.lock.release()
return True
self.lock.release()
return False
## ##
# Delete a ban ticket. # Delete a ban ticket.
@ -128,7 +145,7 @@ class BanManager:
# Remove a BanTicket from the ban list. # Remove a BanTicket from the ban list.
# @param ticket the ticket # @param ticket the ticket
def delBanTicket(self, ticket): def __delBanTicket(self, ticket):
self.banList.remove(ticket) self.banList.remove(ticket)
## ##
@ -137,7 +154,11 @@ class BanManager:
# @return the size # @return the size
def size(self): def size(self):
return len(self.banList) try:
self.lock.acquire()
return len(self.banList)
finally:
self.lock.release()
## ##
# Check if a ticket is in the list. # Check if a ticket is in the list.
@ -148,10 +169,14 @@ class BanManager:
# @return True if a ticket already exists # @return True if a ticket already exists
def inBanList(self, ticket): def inBanList(self, ticket):
for i in self.banList: try:
if ticket.getIP() == i.getIP(): self.lock.acquire()
return True for i in self.banList:
return False if ticket.getIP() == i.getIP():
return True
return False
finally:
self.lock.release()
## ##
# Get the list of IP address to unban. # Get the list of IP address to unban.
@ -162,14 +187,16 @@ class BanManager:
# @todo Check the delete operation # @todo Check the delete operation
def unBanList(self, time): def unBanList(self, time):
uBList = list() try:
self.lock.acquire() self.lock.acquire()
for ticket in self.banList: uBList = list()
if ticket.getTime() < time - self.banTime: for ticket in self.banList:
uBList.append(ticket) if ticket.getTime() < time - self.banTime:
self.delBanTicket(ticket) uBList.append(ticket)
self.lock.release() self.__delBanTicket(ticket)
return uBList return uBList
finally:
self.lock.release()
## ##
# Flush the ban list. # Flush the ban list.
@ -178,9 +205,10 @@ class BanManager:
# @return the complete ban list # @return the complete ban list
def flushBanList(self): def flushBanList(self):
self.lock.acquire() try:
uBList = self.banList self.lock.acquire()
self.banList = list() uBList = self.banList
self.lock.release() self.banList = list()
return uBList return uBList
finally:
self.lock.release()

View File

@ -42,22 +42,40 @@ class FailManager:
self.failTotal = 0 self.failTotal = 0
def setFailTotal(self, value): def setFailTotal(self, value):
self.lock.acquire()
self.failTotal = value self.failTotal = value
self.lock.release()
def getFailTotal(self): def getFailTotal(self):
return self.failTotal try:
self.lock.acquire()
return self.failTotal
finally:
self.lock.release()
def setMaxRetry(self, value): def setMaxRetry(self, value):
self.lock.acquire()
self.maxRetry = value self.maxRetry = value
self.lock.release()
def getMaxRetry(self): def getMaxRetry(self):
return self.maxRetry try:
self.lock.acquire()
return self.maxRetry
finally:
self.lock.release()
def setMaxTime(self, value): def setMaxTime(self, value):
self.lock.acquire()
self.maxTime = value self.maxTime = value
self.lock.release()
def getMaxTime(self): def getMaxTime(self):
return self.maxTime try:
self.lock.acquire()
return self.maxTime
finally:
self.lock.release()
def addFailure(self, ticket): def addFailure(self, ticket):
self.lock.acquire() self.lock.acquire()
@ -76,33 +94,38 @@ class FailManager:
self.lock.release() self.lock.release()
def size(self): def size(self):
return len(self.failList) try:
self.lock.acquire()
return len(self.failList)
finally:
self.lock.release()
def cleanup(self, time): def cleanup(self, time):
self.lock.acquire() self.lock.acquire()
tmp = self.failList.copy() tmp = self.failList.copy()
for item in tmp: for item in tmp:
if tmp[item].getLastTime() < time - self.maxTime: if tmp[item].getLastTime() < time - self.maxTime:
self.delFailure(item) self.__delFailure(item)
self.lock.release() self.lock.release()
def delFailure(self, ip): def __delFailure(self, ip):
if self.failList.has_key(ip): if self.failList.has_key(ip):
del self.failList[ip] del self.failList[ip]
def toBan(self): def toBan(self):
self.lock.acquire() try:
for ip in self.failList: self.lock.acquire()
data = self.failList[ip] for ip in self.failList:
if data.getRetry() >= self.maxRetry: data = self.failList[ip]
self.delFailure(ip) if data.getRetry() >= self.maxRetry:
self.lock.release() self.delFailure(ip)
# Create a FailTicket from BanData # Create a FailTicket from BanData
failTicket = FailTicket(ip, data.getLastTime()) failTicket = FailTicket(ip, data.getLastTime())
failTicket.setAttempt(data.getRetry()) failTicket.setAttempt(data.getRetry())
return failTicket return failTicket
self.lock.release() raise FailManagerEmpty
raise FailManagerEmpty finally:
self.lock.release()
class FailManagerEmpty(Exception): class FailManagerEmpty(Exception):
pass pass

View File

@ -27,6 +27,7 @@ __license__ = "GPL"
import Queue, logging import Queue, logging
from actions import Actions from actions import Actions
from threading import Lock
# Gets the instance of the logger. # Gets the instance of the logger.
logSys = logging.getLogger("fail2ban.jail") logSys = logging.getLogger("fail2ban.jail")
@ -34,6 +35,7 @@ logSys = logging.getLogger("fail2ban.jail")
class Jail: class Jail:
def __init__(self, name): def __init__(self, name):
self.lock = Lock()
self.name = name self.name = name
self.queue = Queue.Queue() self.queue = Queue.Queue()
try: try:
@ -48,58 +50,99 @@ class Jail:
self.action = Actions(self) self.action = Actions(self)
def setName(self, name): def setName(self, name):
self.lock.acquire()
self.name = name self.name = name
self.lock.release()
def getName(self): def getName(self):
return self.name try:
self.lock.acquire()
return self.name
finally:
self.lock.release()
def setFilter(self, filter): def setFilter(self, filter):
self.lock.acquire()
self.filter = filter self.filter = filter
self.lock.release()
def getFilter(self): def getFilter(self):
return self.filter try:
self.lock.acquire()
return self.filter
finally:
self.lock.release()
def setAction(self, action): def setAction(self, action):
self.lock.acquire()
self.action = action self.action = action
self.lock.release()
def getAction(self): def getAction(self):
return self.action try:
self.lock.acquire()
return self.action
finally:
self.lock.release()
def putFailTicket(self, ticket): def putFailTicket(self, ticket):
self.lock.acquire()
self.queue.put(ticket) self.queue.put(ticket)
self.lock.release()
def getFailTicket(self): def getFailTicket(self):
try: try:
return self.queue.get(False) try:
except Queue.Empty: self.lock.acquire()
return False return self.queue.get(False)
except Queue.Empty:
return False
finally:
self.lock.release()
def start(self): def start(self):
self.lock.acquire()
self.filter.start() self.filter.start()
self.action.start() self.action.start()
self.lock.release()
def stop(self): def stop(self):
self.lock.acquire()
self.filter.stop() self.filter.stop()
self.action.stop() self.action.stop()
self.lock.release()
self.filter.join() self.filter.join()
self.action.join() self.action.join()
def isActive(self): def isActive(self):
isActive0 = self.filter.isActive() try:
isActive1 = self.action.isActive() self.lock.acquire()
return isActive0 or isActive1 isActive0 = self.filter.isActive()
isActive1 = self.action.isActive()
return isActive0 or isActive1
finally:
self.lock.release()
def setIdle(self, value): def setIdle(self, value):
self.lock.acquire()
self.filter.setIdle(value) self.filter.setIdle(value)
self.action.setIdle(value) self.action.setIdle(value)
self.lock.release()
def getIdle(self): def getIdle(self):
return self.filter.getIdle() or self.action.getIdle() try:
self.lock.acquire()
return self.filter.getIdle() or self.action.getIdle()
finally:
self.lock.release()
def getStatus(self): def getStatus(self):
fStatus = self.filter.status() try:
aStatus = self.action.status() self.lock.acquire()
ret = [("filter", fStatus), fStatus = self.filter.status()
("action", aStatus)] aStatus = self.action.status()
return ret ret = [("filter", fStatus),
("action", aStatus)]
return ret
finally:
self.lock.release()

View File

@ -26,6 +26,7 @@ __license__ = "GPL"
from ssocket import SSocket from ssocket import SSocket
from ssocket import SSocketErrorException from ssocket import SSocketErrorException
from threading import Lock
import re, pickle, logging import re, pickle, logging
# Gets the instance of the logger. # Gets the instance of the logger.
@ -34,6 +35,7 @@ logSys = logging.getLogger("fail2ban.comm")
class Transmitter: class Transmitter:
def __init__(self, server): def __init__(self, server):
self.lock = Lock()
self.server = server self.server = server
self.socket = SSocket(self) self.socket = SSocket(self)
@ -44,11 +46,14 @@ class Transmitter:
def start(self, force): def start(self, force):
try: try:
self.lock.acquire()
self.socket.initialize(force) self.socket.initialize(force)
self.socket.start() self.socket.start()
self.lock.release()
self.socket.join() self.socket.join()
except SSocketErrorException: except SSocketErrorException:
logSys.error("Could not start server") logSys.error("Could not start server")
self.lock.release()
## ##
# Stop the transmitter. # Stop the transmitter.
@ -58,18 +63,22 @@ class Transmitter:
def stop(self): def stop(self):
self.socket.stop() self.socket.stop()
#self.socket.join() self.socket.join()
def proceed(self, action): def proceed(self, action):
# Deserialize object # Deserialize object
logSys.debug("Action: " + `action`)
try: try:
ret = self.actionHandler(action) self.lock.acquire()
ack = 0, ret logSys.debug("Action: " + `action`)
except Exception, e: try:
logSys.warn("Invalid command: " + `action`) ret = self.actionHandler(action)
ack = 1, e ack = 0, ret
return ack except Exception, e:
logSys.warn("Invalid command: " + `action`)
ack = 1, e
return ack
finally:
self.lock.release()
## ##
# Handle an action. # Handle an action.
@ -129,6 +138,14 @@ class Transmitter:
self.server.setIdleJail(name, False) self.server.setIdleJail(name, False)
return self.server.getIdleJail(name) return self.server.getIdleJail(name)
# Filter # Filter
elif action[1] == "addignoreip":
value = action[2]
self.server.addIgnoreIP(name, value)
return self.server.getIgnoreIP(name)
elif action[1] == "delignoreip":
value = action[2]
self.server.delIgnoreIP(name, value)
return self.server.getIgnoreIP(name)
elif action[1] == "addlogpath": elif action[1] == "addlogpath":
value = action[2:] value = action[2:]
for path in value: for path in value:
@ -217,8 +234,13 @@ class Transmitter:
# Logging # Logging
if name == "loglevel": if name == "loglevel":
return self.server.getLogLevel() return self.server.getLogLevel()
elif name == "logtarget":
return self.server.getLogTarget()
# Filter
elif action[1] == "logpath": elif action[1] == "logpath":
return self.server.getLogPath(name) return self.server.getLogPath(name)
elif action[1] == "ignoreip":
return self.server.getIgnoreIP(name)
elif action[1] == "timeregex": elif action[1] == "timeregex":
return self.server.getTimeRegex(name) return self.server.getTimeRegex(name)
elif action[1] == "timepattern": elif action[1] == "timepattern":
@ -231,7 +253,7 @@ class Transmitter:
return self.server.getFindTime(name) return self.server.getFindTime(name)
elif action[1] == "maxretry": elif action[1] == "maxretry":
return self.server.getMaxRetry(name) return self.server.getMaxRetry(name)
# Filter # Action
elif action[1] == "bantime": elif action[1] == "bantime":
return self.server.getBanTime(name) return self.server.getBanTime(name)
elif action[1] == "addaction": elif action[1] == "addaction":