mirror of https://github.com/huashengdun/webssh
Refactored handler.py
parent
fbb3e466b2
commit
1f835f5a70
|
@ -18,33 +18,33 @@ class TestMixinHandler(unittest.TestCase):
|
||||||
handler = MixinHandler()
|
handler = MixinHandler()
|
||||||
options.fbidhttp = True
|
options.fbidhttp = True
|
||||||
|
|
||||||
handler.context = Mock(
|
context = Mock(
|
||||||
address=('8.8.8.8', 8888),
|
address=('8.8.8.8', 8888),
|
||||||
trusted_downstream=['127.0.0.1'],
|
trusted_downstream=['127.0.0.1'],
|
||||||
_orig_protocol='http'
|
_orig_protocol='http'
|
||||||
)
|
)
|
||||||
self.assertTrue(handler.is_forbidden())
|
self.assertTrue(handler.is_forbidden(context))
|
||||||
|
|
||||||
handler.context = Mock(
|
context = Mock(
|
||||||
address=('8.8.8.8', 8888),
|
address=('8.8.8.8', 8888),
|
||||||
trusted_downstream=[],
|
trusted_downstream=[],
|
||||||
_orig_protocol='http'
|
_orig_protocol='http'
|
||||||
)
|
)
|
||||||
self.assertTrue(handler.is_forbidden())
|
self.assertTrue(handler.is_forbidden(context))
|
||||||
|
|
||||||
handler.context = Mock(
|
context = Mock(
|
||||||
address=('192.168.1.1', 8888),
|
address=('192.168.1.1', 8888),
|
||||||
trusted_downstream=[],
|
trusted_downstream=[],
|
||||||
_orig_protocol='http'
|
_orig_protocol='http'
|
||||||
)
|
)
|
||||||
self.assertIsNone(handler.is_forbidden())
|
self.assertIsNone(handler.is_forbidden(context))
|
||||||
|
|
||||||
handler.context = Mock(
|
context = Mock(
|
||||||
address=('8.8.8.8', 8888),
|
address=('8.8.8.8', 8888),
|
||||||
trusted_downstream=[],
|
trusted_downstream=[],
|
||||||
_orig_protocol='https'
|
_orig_protocol='https'
|
||||||
)
|
)
|
||||||
self.assertIsNone(handler.is_forbidden())
|
self.assertIsNone(handler.is_forbidden(context))
|
||||||
|
|
||||||
def test_get_client_addr(self):
|
def test_get_client_addr(self):
|
||||||
handler = MixinHandler()
|
handler = MixinHandler()
|
||||||
|
|
|
@ -45,22 +45,22 @@ class MixinHandler(object):
|
||||||
'Server': 'TornadoServer'
|
'Server': 'TornadoServer'
|
||||||
}
|
}
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self, loop=None):
|
||||||
conn = self.request.connection
|
conn = self.request.connection
|
||||||
self.context = conn.context
|
if self.is_forbidden(conn.context):
|
||||||
if self.is_forbidden():
|
|
||||||
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
|
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
|
||||||
conn.stream.write(to_bytes(result))
|
conn.stream.write(to_bytes(result))
|
||||||
conn.close()
|
conn.close()
|
||||||
raise ValueError('Accesss denied')
|
raise ValueError('Accesss denied')
|
||||||
|
self.loop = loop
|
||||||
|
self.context = conn.context
|
||||||
|
|
||||||
def is_forbidden(self):
|
def is_forbidden(self, context):
|
||||||
"""
|
"""
|
||||||
Following requests are forbidden:
|
Following requests are forbidden:
|
||||||
* requests not come from trusted_downstream (if set).
|
* requests not come from trusted_downstream (if set).
|
||||||
* plain http requests from a public network.
|
* plain http requests from a public network.
|
||||||
"""
|
"""
|
||||||
context = self.context
|
|
||||||
ip = context.address[0]
|
ip = context.address[0]
|
||||||
lst = context.trusted_downstream
|
lst = context.trusted_downstream
|
||||||
|
|
||||||
|
@ -123,14 +123,13 @@ class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
|
||||||
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
|
|
||||||
def initialize(self, loop, policy, host_keys_settings):
|
def initialize(self, loop, policy, host_keys_settings):
|
||||||
self.loop = loop
|
super(IndexHandler, self).initialize(loop)
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
self.host_keys_settings = host_keys_settings
|
self.host_keys_settings = host_keys_settings
|
||||||
self.ssh_client = self.get_ssh_client()
|
self.ssh_client = self.get_ssh_client()
|
||||||
self.privatekey_filename = None
|
self.privatekey_filename = None
|
||||||
self.debug = self.settings.get('debug', False)
|
self.debug = self.settings.get('debug', False)
|
||||||
self.result = dict(id=None, status=None, encoding=None)
|
self.result = dict(id=None, status=None, encoding=None)
|
||||||
super(IndexHandler, self).initialize()
|
|
||||||
|
|
||||||
def write_error(self, status_code, **kwargs):
|
def write_error(self, status_code, **kwargs):
|
||||||
if self.request.method != 'POST' or not swallow_http_errors:
|
if self.request.method != 'POST' or not swallow_http_errors:
|
||||||
|
@ -329,9 +328,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
||||||
|
|
||||||
def initialize(self, loop):
|
def initialize(self, loop):
|
||||||
self.loop = loop
|
super(WsockHandler, self).initialize(loop)
|
||||||
self.worker_ref = None
|
self.worker_ref = None
|
||||||
super(WsockHandler, self).initialize()
|
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
self.src_addr = self.get_client_addr()
|
self.src_addr = self.get_client_addr()
|
||||||
|
|
Loading…
Reference in New Issue