mirror of https://github.com/huashengdun/webssh
Move method get_value to MixinHandler
parent
f3d9d297bb
commit
48acf09f21
|
@ -66,6 +66,23 @@ class TestApp(AsyncHTTPTestCase):
|
||||||
def test_app_with_invalid_form(self):
|
def test_app_with_invalid_form(self):
|
||||||
response = self.fetch('/')
|
response = self.fetch('/')
|
||||||
self.assertEqual(response.code, 200)
|
self.assertEqual(response.code, 200)
|
||||||
|
|
||||||
|
body = 'port=7000&username=admin&password'
|
||||||
|
response = self.fetch('/', method='POST', body=body)
|
||||||
|
self.assertIn(b'Missing argument hostname', response.body)
|
||||||
|
|
||||||
|
body = 'hostname=127.0.0.1&username=admin&password'
|
||||||
|
response = self.fetch('/', method='POST', body=body)
|
||||||
|
self.assertIn(b'Missing argument port', response.body)
|
||||||
|
|
||||||
|
body = 'hostname=127.0.0.1&port=7000&password'
|
||||||
|
response = self.fetch('/', method='POST', body=body)
|
||||||
|
self.assertIn(b'Missing argument username', response.body)
|
||||||
|
|
||||||
|
body = 'hostname=127.0.0.1&port=7000&username=admin'
|
||||||
|
response = self.fetch('/', method='POST', body=body)
|
||||||
|
self.assertIn(b'Missing argument password', response.body)
|
||||||
|
|
||||||
body = 'hostname=&port=&username=&password'
|
body = 'hostname=&port=&username=&password'
|
||||||
response = self.fetch('/', method='POST', body=body)
|
response = self.fetch('/', method='POST', body=body)
|
||||||
self.assertIn(b'The hostname field is required', response.body)
|
self.assertIn(b'The hostname field is required', response.body)
|
||||||
|
@ -74,6 +91,10 @@ class TestApp(AsyncHTTPTestCase):
|
||||||
response = self.fetch('/', method='POST', body=body)
|
response = self.fetch('/', method='POST', body=body)
|
||||||
self.assertIn(b'The port field is required', response.body)
|
self.assertIn(b'The port field is required', response.body)
|
||||||
|
|
||||||
|
body = 'hostname=127.0.0.1&port=7000&username=&password'
|
||||||
|
response = self.fetch('/', method='POST', body=body)
|
||||||
|
self.assertIn(b'The username field is required', response.body)
|
||||||
|
|
||||||
body = 'hostname=127.0.0&port=22&username=&password'
|
body = 'hostname=127.0.0&port=22&username=&password'
|
||||||
response = self.fetch('/', method='POST', body=body)
|
response = self.fetch('/', method='POST', body=body)
|
||||||
self.assertIn(b'Invalid hostname', response.body)
|
self.assertIn(b'Invalid hostname', response.body)
|
||||||
|
@ -90,10 +111,6 @@ class TestApp(AsyncHTTPTestCase):
|
||||||
response = self.fetch('/', method='POST', body=body)
|
response = self.fetch('/', method='POST', body=body)
|
||||||
self.assertIn(b'Invalid port', response.body)
|
self.assertIn(b'Invalid port', response.body)
|
||||||
|
|
||||||
body = 'hostname=127.0.0.1&port=7000&username=&password'
|
|
||||||
response = self.fetch('/', method='POST', body=body)
|
|
||||||
self.assertIn(b'The username field is required', response.body) # noqa
|
|
||||||
|
|
||||||
def test_app_with_wrong_credentials(self):
|
def test_app_with_wrong_credentials(self):
|
||||||
response = self.fetch('/')
|
response = self.fetch('/')
|
||||||
self.assertEqual(response.code, 200)
|
self.assertEqual(response.code, 200)
|
||||||
|
@ -150,6 +167,66 @@ class TestApp(AsyncHTTPTestCase):
|
||||||
self.assertEqual(to_str(msg, data['encoding']), banner)
|
self.assertEqual(to_str(msg, data['encoding']), banner)
|
||||||
ws.close()
|
ws.close()
|
||||||
|
|
||||||
|
@tornado.testing.gen_test
|
||||||
|
def test_app_with_correct_credentials_but_without_id_argument(self):
|
||||||
|
url = self.get_url('/')
|
||||||
|
client = self.get_http_client()
|
||||||
|
response = yield client.fetch(url)
|
||||||
|
self.assertEqual(response.code, 200)
|
||||||
|
|
||||||
|
response = yield client.fetch(url, method='POST', body=self.body)
|
||||||
|
data = json.loads(to_str(response.body))
|
||||||
|
self.assertIsNone(data['status'])
|
||||||
|
self.assertIsNotNone(data['id'])
|
||||||
|
self.assertIsNotNone(data['encoding'])
|
||||||
|
|
||||||
|
url = url.replace('http', 'ws')
|
||||||
|
ws_url = url + 'ws'
|
||||||
|
ws = yield tornado.websocket.websocket_connect(ws_url)
|
||||||
|
msg = yield ws.read_message()
|
||||||
|
self.assertIsNone(msg)
|
||||||
|
self.assertIn('Missing argument id', ws.close_reason)
|
||||||
|
|
||||||
|
@tornado.testing.gen_test
|
||||||
|
def test_app_with_correct_credentials_but_epmpty_id(self):
|
||||||
|
url = self.get_url('/')
|
||||||
|
client = self.get_http_client()
|
||||||
|
response = yield client.fetch(url)
|
||||||
|
self.assertEqual(response.code, 200)
|
||||||
|
|
||||||
|
response = yield client.fetch(url, method='POST', body=self.body)
|
||||||
|
data = json.loads(to_str(response.body))
|
||||||
|
self.assertIsNone(data['status'])
|
||||||
|
self.assertIsNotNone(data['id'])
|
||||||
|
self.assertIsNotNone(data['encoding'])
|
||||||
|
|
||||||
|
url = url.replace('http', 'ws')
|
||||||
|
ws_url = url + 'ws?id='
|
||||||
|
ws = yield tornado.websocket.websocket_connect(ws_url)
|
||||||
|
msg = yield ws.read_message()
|
||||||
|
self.assertIsNone(msg)
|
||||||
|
self.assertIn('field is required', ws.close_reason)
|
||||||
|
|
||||||
|
@tornado.testing.gen_test
|
||||||
|
def test_app_with_correct_credentials_but_wrong_id(self):
|
||||||
|
url = self.get_url('/')
|
||||||
|
client = self.get_http_client()
|
||||||
|
response = yield client.fetch(url)
|
||||||
|
self.assertEqual(response.code, 200)
|
||||||
|
|
||||||
|
response = yield client.fetch(url, method='POST', body=self.body)
|
||||||
|
data = json.loads(to_str(response.body))
|
||||||
|
self.assertIsNone(data['status'])
|
||||||
|
self.assertIsNotNone(data['id'])
|
||||||
|
self.assertIsNotNone(data['encoding'])
|
||||||
|
|
||||||
|
url = url.replace('http', 'ws')
|
||||||
|
ws_url = url + 'ws?id=1' + data['id']
|
||||||
|
ws = yield tornado.websocket.websocket_connect(ws_url)
|
||||||
|
msg = yield ws.read_message()
|
||||||
|
self.assertIsNone(msg)
|
||||||
|
self.assertIn('Websocket authentication failed', ws.close_reason)
|
||||||
|
|
||||||
@tornado.testing.gen_test
|
@tornado.testing.gen_test
|
||||||
def test_app_with_correct_credentials_user_bar(self):
|
def test_app_with_correct_credentials_user_bar(self):
|
||||||
url = self.get_url('/')
|
url = self.get_url('/')
|
||||||
|
|
|
@ -40,6 +40,22 @@ def parse_encoding(data):
|
||||||
|
|
||||||
class MixinHandler(object):
|
class MixinHandler(object):
|
||||||
|
|
||||||
|
arguments_required = {} # agruments must be deliverd
|
||||||
|
empty_allowed = {} # emtpy value alllowed
|
||||||
|
|
||||||
|
def get_value(self, name):
|
||||||
|
is_required = name in self.arguments_required
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = self.get_argument(name)
|
||||||
|
except tornado.web.MissingArgumentError:
|
||||||
|
if is_required:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
if not value and is_required and name not in self.empty_allowed:
|
||||||
|
raise ValueError('The {} field is required.'.format(name))
|
||||||
|
return value
|
||||||
|
|
||||||
def get_real_client_addr(self):
|
def get_real_client_addr(self):
|
||||||
ip = self.request.headers.get('X-Real-Ip')
|
ip = self.request.headers.get('X-Real-Ip')
|
||||||
port = self.request.headers.get('X-Real-Port')
|
port = self.request.headers.get('X-Real-Port')
|
||||||
|
@ -62,6 +78,9 @@ class MixinHandler(object):
|
||||||
|
|
||||||
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
|
|
||||||
|
arguments_required = {'hostname', 'port', 'username', 'password'}
|
||||||
|
empty_allowed = {'password'}
|
||||||
|
|
||||||
def initialize(self, loop, policy, host_keys_settings):
|
def initialize(self, loop, policy, host_keys_settings):
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
|
@ -71,10 +90,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
def get_privatekey(self):
|
def get_privatekey(self):
|
||||||
lst = self.request.files.get('privatekey') # multipart form
|
lst = self.request.files.get('privatekey') # multipart form
|
||||||
if not lst:
|
if not lst:
|
||||||
try:
|
return self.get_value('privatekey') # urlencoded form
|
||||||
return self.get_argument('privatekey') # urlencoded form
|
|
||||||
except tornado.web.MissingArgumentError:
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
self.filename = lst[0]['filename']
|
self.filename = lst[0]['filename']
|
||||||
data = lst[0]['body']
|
data = lst[0]['body']
|
||||||
|
@ -136,17 +152,11 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
|
|
||||||
raise ValueError('Invalid port: {}'.format(value))
|
raise ValueError('Invalid port: {}'.format(value))
|
||||||
|
|
||||||
def get_value(self, name):
|
|
||||||
value = self.get_argument(name)
|
|
||||||
if not value:
|
|
||||||
raise ValueError('The {} field is required.'.format(name))
|
|
||||||
return value
|
|
||||||
|
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
hostname = self.get_hostname()
|
hostname = self.get_hostname()
|
||||||
port = self.get_port()
|
port = self.get_port()
|
||||||
username = self.get_value('username')
|
username = self.get_value('username')
|
||||||
password = self.get_argument('password')
|
password = self.get_value('password')
|
||||||
privatekey = self.get_privatekey()
|
privatekey = self.get_privatekey()
|
||||||
pkey = self.get_pkey_obj(privatekey, password, self.filename) \
|
pkey = self.get_pkey_obj(privatekey, password, self.filename) \
|
||||||
if privatekey else None
|
if privatekey else None
|
||||||
|
@ -234,6 +244,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
|
|
||||||
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
||||||
|
|
||||||
|
arguments_required = {'id'}
|
||||||
|
|
||||||
def initialize(self, loop):
|
def initialize(self, loop):
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.worker_ref = None
|
self.worker_ref = None
|
||||||
|
@ -244,7 +256,12 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
||||||
def open(self):
|
def open(self):
|
||||||
self.src_addr = self.get_client_addr()
|
self.src_addr = self.get_client_addr()
|
||||||
logging.info('Connected from {}:{}'.format(*self.src_addr))
|
logging.info('Connected from {}:{}'.format(*self.src_addr))
|
||||||
worker = workers.get(self.get_argument('id'))
|
try:
|
||||||
|
worker_id = self.get_value('id')
|
||||||
|
except (tornado.web.MissingArgumentError, ValueError) as exc:
|
||||||
|
self.close(reason=str(exc))
|
||||||
|
else:
|
||||||
|
worker = workers.get(worker_id)
|
||||||
if worker and worker.src_addr[0] == self.src_addr[0]:
|
if worker and worker.src_addr[0] == self.src_addr[0]:
|
||||||
workers.pop(worker.id)
|
workers.pop(worker.id)
|
||||||
self.set_nodelay(True)
|
self.set_nodelay(True)
|
||||||
|
|
Loading…
Reference in New Issue