Refactored handler.py

pull/38/head
Sheng 2018-10-19 18:18:55 +08:00
parent fbb3e466b2
commit 1f835f5a70
2 changed files with 15 additions and 17 deletions

View File

@ -18,33 +18,33 @@ class TestMixinHandler(unittest.TestCase):
handler = MixinHandler()
options.fbidhttp = True
handler.context = Mock(
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'],
_orig_protocol='http'
)
self.assertTrue(handler.is_forbidden())
self.assertTrue(handler.is_forbidden(context))
handler.context = Mock(
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
self.assertTrue(handler.is_forbidden())
self.assertTrue(handler.is_forbidden(context))
handler.context = Mock(
context = Mock(
address=('192.168.1.1', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
self.assertIsNone(handler.is_forbidden())
self.assertIsNone(handler.is_forbidden(context))
handler.context = Mock(
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='https'
)
self.assertIsNone(handler.is_forbidden())
self.assertIsNone(handler.is_forbidden(context))
def test_get_client_addr(self):
handler = MixinHandler()

View File

@ -45,22 +45,22 @@ class MixinHandler(object):
'Server': 'TornadoServer'
}
def initialize(self):
def initialize(self, loop=None):
conn = self.request.connection
self.context = conn.context
if self.is_forbidden():
if self.is_forbidden(conn.context):
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
conn.stream.write(to_bytes(result))
conn.close()
raise ValueError('Accesss denied')
self.loop = loop
self.context = conn.context
def is_forbidden(self):
def is_forbidden(self, context):
"""
Following requests are forbidden:
* requests not come from trusted_downstream (if set).
* plain http requests from a public network.
"""
context = self.context
ip = context.address[0]
lst = context.trusted_downstream
@ -123,14 +123,13 @@ class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def initialize(self, loop, policy, host_keys_settings):
self.loop = loop
super(IndexHandler, self).initialize(loop)
self.policy = policy
self.host_keys_settings = host_keys_settings
self.ssh_client = self.get_ssh_client()
self.privatekey_filename = None
self.debug = self.settings.get('debug', False)
self.result = dict(id=None, status=None, encoding=None)
super(IndexHandler, self).initialize()
def write_error(self, status_code, **kwargs):
if self.request.method != 'POST' or not swallow_http_errors:
@ -329,9 +328,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
def initialize(self, loop):
self.loop = loop
super(WsockHandler, self).initialize(loop)
self.worker_ref = None
super(WsockHandler, self).initialize()
def open(self):
self.src_addr = self.get_client_addr()