diff --git a/config/action.d/smtp.py b/config/action.d/smtp.py index 4210a920..26754d41 100644 --- a/config/action.d/smtp.py +++ b/config/action.d/smtp.py @@ -53,23 +53,23 @@ Matches for %(ip)s for jail %(jailname)s: class SMTPAction(ActionBase): - def __init__(self, jail, name, initOpts): - super(SMTPAction, self).__init__(jail, name, initOpts) - if initOpts is None: - initOpts = dict() # We have defaults for everything - self.host = initOpts.get('host', "localhost:25") - #TODO: self.ssl = initOpts.get('ssl', "no") == 'yes' + def __init__( + self, jail, name, host="localhost", user=None, password=None, + sendername="Fail2Ban", sender="fail2ban", dest="root", matches=None): - self.user = initOpts.get('user', '') - self.password = initOpts.get('password') + super(SMTPAction, self).__init__(jail, name) - self.fromname = initOpts.get('sendername', "Fail2Ban") - self.fromaddr = initOpts.get('sender', "fail2ban") - self.toaddr = initOpts.get('dest', "root") + self.host = host + #TODO: self.ssl = ssl - self.smtp = smtplib.SMTP() + self.user = user + self.password =password - self.matches = initOpts.get('matches') + self.fromname = sendername + self.fromaddr = sender + self.toaddr = dest + + self.matches = matches self.message_values = CallingMap( jailname = self.jail.getName(), # Doesn't change @@ -84,12 +84,13 @@ class SMTPAction(ActionBase): msg['To'] = self.toaddr msg['Date'] = formatdate() + smtp = smtplib.SMTP() try: self.logSys.debug("Connected to SMTP '%s', response: %i: %s", - *self.smtp.connect(self.host)) + *smtp.connect(self.host)) if self.user and self.password: smtp.login(self.user, self.password) - failed_recipients = self.smtp.sendmail( + failed_recipients = smtp.sendmail( self.fromaddr, self.toaddr, msg.as_string()) except smtplib.SMTPConnectError: self.logSys.error("Error connecting to host '%s'", self.host) @@ -112,7 +113,7 @@ class SMTPAction(ActionBase): self.logSys.debug("Email '%s' successfully sent", subject) finally: try: - self.smtp.quit() + smtp.quit() except smtplib.SMTPServerDisconnected: pass # Not connected diff --git a/fail2ban/server/action.py b/fail2ban/server/action.py index c4465503..01010bb0 100644 --- a/fail2ban/server/action.py +++ b/fail2ban/server/action.py @@ -90,7 +90,7 @@ class ActionBase(object): return False return True - def __init__(self, jail, name, initOpts=None): + def __init__(self, jail, name): self._jail = jail self._name = name self._logSys = logging.getLogger( diff --git a/fail2ban/server/actions.py b/fail2ban/server/actions.py index a8d0946a..c63a08e0 100644 --- a/fail2ban/server/actions.py +++ b/fail2ban/server/actions.py @@ -81,7 +81,7 @@ class Actions(JailThread): raise RuntimeError( "%s module %s does not implment required methods" % ( pythonModule, customActionModule.Action.__name__)) - action = customActionModule.Action(self.jail, name, initOpts) + action = customActionModule.Action(self.jail, name, **initOpts) self.__actions.append(action) ## diff --git a/fail2ban/tests/actionstestcase.py b/fail2ban/tests/actionstestcase.py index 87631dc2..7c0609de 100644 --- a/fail2ban/tests/actionstestcase.py +++ b/fail2ban/tests/actionstestcase.py @@ -86,7 +86,8 @@ class ExecuteActions(LogCaptureTestCase): def testAddActionPython(self): self.__actions.addAction( - "Action", os.path.join(TEST_FILES_DIR, "action.d/action.py"), {}) + "Action", os.path.join(TEST_FILES_DIR, "action.d/action.py"), + {'opt1': 'value'}) self.assertTrue(self._is_logged("TestAction initialised")) @@ -100,3 +101,17 @@ class ExecuteActions(LogCaptureTestCase): self.assertRaises(IOError, self.__actions.addAction, "Action3", "/does/not/exist.py", {}) + + # With optional argument + self.__actions.addAction( + "Action4", os.path.join(TEST_FILES_DIR, "action.d/action.py"), + {'opt1': 'value', 'opt2': 'value2'}) + # With too many arguments + self.assertRaises( + TypeError, self.__actions.addAction, "Action5", + os.path.join(TEST_FILES_DIR, "action.d/action.py"), + {'opt1': 'value', 'opt2': 'value2', 'opt3': 'value3'}) + # Missing required argument + self.assertRaises( + TypeError, self.__actions.addAction, "Action5", + os.path.join(TEST_FILES_DIR, "action.d/action.py"), {}) diff --git a/fail2ban/tests/files/action.d/action.py b/fail2ban/tests/files/action.d/action.py index 5dcafe31..f0535b76 100644 --- a/fail2ban/tests/files/action.d/action.py +++ b/fail2ban/tests/files/action.d/action.py @@ -3,20 +3,20 @@ from fail2ban.server.action import ActionBase class TestAction(ActionBase): - def __init__(self, *args, **kwargs): - super(TestAction, self).__init__(*args, **kwargs) + def __init__(self, jail, name, opt1, opt2=None): + super(TestAction, self).__init__(jail, name) self.logSys.debug("%s initialised" % self.__class__.__name__) - def execActionStart(self, *args, **kwargs): + def execActionStart(self): self.logSys.debug("%s action start" % self.__class__.__name__) - def execActionStop(self, *args, **kwargs): + def execActionStop(self): self.logSys.debug("%s action stop" % self.__class__.__name__) - def execActionBan(self, *args, **kwargs): + def execActionBan(self, aInfo): self.logSys.debug("%s action ban" % self.__class__.__name__) - def execActionUnban(self, *args, **kwargs): + def execActionUnban(self, aInfo): self.logSys.debug("%s action unban" % self.__class__.__name__) Action = TestAction diff --git a/fail2ban/tests/servertestcase.py b/fail2ban/tests/servertestcase.py index a2d90db1..21c6c371 100644 --- a/fail2ban/tests/servertestcase.py +++ b/fail2ban/tests/servertestcase.py @@ -566,7 +566,8 @@ class Transmitter(TransmitterBase): ["set", self.jailName, "delaction", "Doesn't exist"])[0],1) self.assertEqual( self.transm.proceed(["set", self.jailName, "addaction", action, - os.path.join(TEST_FILES_DIR, "action.d", "action.py"), "{}"]), + os.path.join(TEST_FILES_DIR, "action.d", "action.py"), + '{"opt1": "value"}']), (0, action)) for cmd, value in zip(cmdList, cmdValueList): self.assertTrue(