Refactored handler.py

pull/38/head
Sheng 2018-10-19 18:18:55 +08:00
parent fbb3e466b2
commit 1f835f5a70
2 changed files with 15 additions and 17 deletions

View File

@ -18,33 +18,33 @@ class TestMixinHandler(unittest.TestCase):
handler = MixinHandler() handler = MixinHandler()
options.fbidhttp = True options.fbidhttp = True
handler.context = Mock( context = Mock(
address=('8.8.8.8', 8888), address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'], trusted_downstream=['127.0.0.1'],
_orig_protocol='http' _orig_protocol='http'
) )
self.assertTrue(handler.is_forbidden()) self.assertTrue(handler.is_forbidden(context))
handler.context = Mock( context = Mock(
address=('8.8.8.8', 8888), address=('8.8.8.8', 8888),
trusted_downstream=[], trusted_downstream=[],
_orig_protocol='http' _orig_protocol='http'
) )
self.assertTrue(handler.is_forbidden()) self.assertTrue(handler.is_forbidden(context))
handler.context = Mock( context = Mock(
address=('192.168.1.1', 8888), address=('192.168.1.1', 8888),
trusted_downstream=[], trusted_downstream=[],
_orig_protocol='http' _orig_protocol='http'
) )
self.assertIsNone(handler.is_forbidden()) self.assertIsNone(handler.is_forbidden(context))
handler.context = Mock( context = Mock(
address=('8.8.8.8', 8888), address=('8.8.8.8', 8888),
trusted_downstream=[], trusted_downstream=[],
_orig_protocol='https' _orig_protocol='https'
) )
self.assertIsNone(handler.is_forbidden()) self.assertIsNone(handler.is_forbidden(context))
def test_get_client_addr(self): def test_get_client_addr(self):
handler = MixinHandler() handler = MixinHandler()

View File

@ -45,22 +45,22 @@ class MixinHandler(object):
'Server': 'TornadoServer' 'Server': 'TornadoServer'
} }
def initialize(self): def initialize(self, loop=None):
conn = self.request.connection conn = self.request.connection
self.context = conn.context if self.is_forbidden(conn.context):
if self.is_forbidden():
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version) result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
conn.stream.write(to_bytes(result)) conn.stream.write(to_bytes(result))
conn.close() conn.close()
raise ValueError('Accesss denied') raise ValueError('Accesss denied')
self.loop = loop
self.context = conn.context
def is_forbidden(self): def is_forbidden(self, context):
""" """
Following requests are forbidden: Following requests are forbidden:
* requests not come from trusted_downstream (if set). * requests not come from trusted_downstream (if set).
* plain http requests from a public network. * plain http requests from a public network.
""" """
context = self.context
ip = context.address[0] ip = context.address[0]
lst = context.trusted_downstream lst = context.trusted_downstream
@ -123,14 +123,13 @@ class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
class IndexHandler(MixinHandler, tornado.web.RequestHandler): class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def initialize(self, loop, policy, host_keys_settings): def initialize(self, loop, policy, host_keys_settings):
self.loop = loop super(IndexHandler, self).initialize(loop)
self.policy = policy self.policy = policy
self.host_keys_settings = host_keys_settings self.host_keys_settings = host_keys_settings
self.ssh_client = self.get_ssh_client() self.ssh_client = self.get_ssh_client()
self.privatekey_filename = None self.privatekey_filename = None
self.debug = self.settings.get('debug', False) self.debug = self.settings.get('debug', False)
self.result = dict(id=None, status=None, encoding=None) self.result = dict(id=None, status=None, encoding=None)
super(IndexHandler, self).initialize()
def write_error(self, status_code, **kwargs): def write_error(self, status_code, **kwargs):
if self.request.method != 'POST' or not swallow_http_errors: if self.request.method != 'POST' or not swallow_http_errors:
@ -329,9 +328,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
def initialize(self, loop): def initialize(self, loop):
self.loop = loop super(WsockHandler, self).initialize(loop)
self.worker_ref = None self.worker_ref = None
super(WsockHandler, self).initialize()
def open(self): def open(self):
self.src_addr = self.get_client_addr() self.src_addr = self.get_client_addr()