diff --git a/tests/test_app.py b/tests/test_app.py index 75357a3..6be47f3 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -11,7 +11,9 @@ from tornado.options import options from tests.sshserver import run_ssh_server, banner from tests.utils import encode_multipart_formdata, read_file, make_tests_data_path # noqa from webssh.main import make_app, make_handlers -from webssh.settings import get_app_settings, max_body_size +from webssh.settings import ( + get_app_settings, max_body_size, swallow_http_errors +) from webssh.utils import to_str try: @@ -65,59 +67,56 @@ class TestApp(AsyncHTTPTestCase): options.update(max_body_size=max_body_size) return options + def my_assertIn(self, part, whole): + if swallow_http_errors: + self.assertIn(part, whole) + else: + self.assertIn(b'Bad Request', whole) + def test_app_with_invalid_form_for_missing_argument(self): response = self.fetch('/') self.assertEqual(response.code, 200) body = 'port=7000&username=admin&password' response = self.fetch('/', method='POST', body=body) - self.assertEqual(response.code, 400) - self.assertIn(b'Missing argument hostname', response.body) + self.my_assertIn(b'Missing argument hostname', response.body) body = 'hostname=127.0.0.1&username=admin&password' - self.assertEqual(response.code, 400) response = self.fetch('/', method='POST', body=body) - self.assertIn(b'Missing argument port', response.body) + self.my_assertIn(b'Missing argument port', response.body) body = 'hostname=127.0.0.1&port=7000&password' - self.assertEqual(response.code, 400) response = self.fetch('/', method='POST', body=body) - self.assertIn(b'Missing argument username', response.body) + self.my_assertIn(b'Missing argument username', response.body) body = 'hostname=&port=&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertEqual(response.code, 400) - self.assertIn(b'Missing value hostname', response.body) + self.my_assertIn(b'Missing value hostname', response.body) body = 'hostname=127.0.0.1&port=&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertEqual(response.code, 400) - self.assertIn(b'Missing value port', response.body) + self.my_assertIn(b'Missing value port', response.body) body = 'hostname=127.0.0.1&port=7000&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertEqual(response.code, 400) - self.assertIn(b'Missing value username', response.body) + self.my_assertIn(b'Missing value username', response.body) def test_app_with_invalid_form_for_invalid_value(self): body = 'hostname=127.0.0&port=22&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertIn(b'Invalid hostname', response.body) + self.my_assertIn(b'Invalid hostname', response.body) body = 'hostname=http://www.googe.com&port=22&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertEqual(response.code, 400) - self.assertIn(b'Invalid hostname', response.body) + self.my_assertIn(b'Invalid hostname', response.body) body = 'hostname=127.0.0.1&port=port&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertEqual(response.code, 400) - self.assertIn(b'Invalid port', response.body) + self.my_assertIn(b'Invalid port', response.body) body = 'hostname=127.0.0.1&port=70000&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertEqual(response.code, 400) - self.assertIn(b'Invalid port', response.body) + self.my_assertIn(b'Invalid port', response.body) def test_app_with_wrong_hostname_ip(self): body = 'hostname=127.0.0.1&port=7000&username=admin' @@ -370,10 +369,16 @@ class TestApp(AsyncHTTPTestCase): headers = { 'Content-Type': content_type, 'content-length': str(len(body)) } - with self.assertRaises(HTTPError) as ctx: - yield client.fetch(url, method='POST', headers=headers, body=body) - self.assertEqual(ctx.exception.code, 400) - self.assertIn('Invalid private key', ctx.exception.message) + + if swallow_http_errors: + response = yield client.fetch(url, method='POST', headers=headers, + body=body) + self.assertIn(b'Invalid private key', response.body) + else: + with self.assertRaises(HTTPError) as ctx: + yield client.fetch(url, method='POST', headers=headers, + body=body) + self.assertIn('Bad Request', ctx.exception.message) @tornado.testing.gen_test def test_app_auth_with_pubkey_exceeds_key_max_size(self): @@ -389,10 +394,15 @@ class TestApp(AsyncHTTPTestCase): headers = { 'Content-Type': content_type, 'content-length': str(len(body)) } - with self.assertRaises(HTTPError) as ctx: - yield client.fetch(url, method='POST', headers=headers, body=body) - self.assertEqual(ctx.exception.code, 400) - self.assertIn('Invalid private key', ctx.exception.message) + if swallow_http_errors: + response = yield client.fetch(url, method='POST', headers=headers, + body=body) + self.assertIn(b'Invalid private key', response.body) + else: + with self.assertRaises(HTTPError) as ctx: + yield client.fetch(url, method='POST', headers=headers, + body=body) + self.assertIn('Bad Request', ctx.exception.message) @tornado.testing.gen_test def test_app_auth_with_pubkey_cannot_be_decoded_by_multipart_form(self): @@ -411,10 +421,15 @@ class TestApp(AsyncHTTPTestCase): headers = { 'Content-Type': content_type, 'content-length': str(len(body)) } - with self.assertRaises(HTTPError) as ctx: - yield client.fetch(url, method='POST', headers=headers, body=body) - self.assertEqual(ctx.exception.code, 400) - self.assertIn('Invalid unicode', ctx.exception.message) + if swallow_http_errors: + response = yield client.fetch(url, method='POST', headers=headers, + body=body) + self.assertIn(b'Invalid unicode', response.body) + else: + with self.assertRaises(HTTPError) as ctx: + yield client.fetch(url, method='POST', headers=headers, + body=body) + self.assertIn('Bad Request', ctx.exception.message) @tornado.testing.gen_test def test_app_post_form_with_large_body_size_by_multipart_form(self): @@ -432,8 +447,8 @@ class TestApp(AsyncHTTPTestCase): } with self.assertRaises(HTTPError) as ctx: - yield client.fetch(url, method='POST', headers=headers, body=body) - self.assertEqual(ctx.exception.code, 400) + yield client.fetch(url, method='POST', headers=headers, + body=body) self.assertIn('Bad Request', ctx.exception.message) @tornado.testing.gen_test @@ -447,7 +462,6 @@ class TestApp(AsyncHTTPTestCase): body = self.body + '&privatekey=' + privatekey with self.assertRaises(HTTPError) as ctx: yield client.fetch(url, method='POST', body=body) - self.assertEqual(ctx.exception.code, 400) self.assertIn('Bad Request', ctx.exception.message) @tornado.testing.gen_test diff --git a/tests/test_handler.py b/tests/test_handler.py index e86b2a3..4d35607 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -4,7 +4,7 @@ import paramiko from tornado.httputil import HTTPServerRequest from tests.utils import read_file, make_tests_data_path from webssh.handler import ( - MixinHandler, IndexHandler, parse_encoding, InvalidException + MixinHandler, IndexHandler, parse_encoding, InvalidValueError ) @@ -83,7 +83,7 @@ class TestIndexHandler(unittest.TestCase): self.assertIsInstance(pkey, cls) pkey = IndexHandler.get_pkey_obj(key, 'iginored', fname) self.assertIsInstance(pkey, cls) - with self.assertRaises(InvalidException) as exc: + with self.assertRaises(InvalidValueError) as exc: pkey = IndexHandler.get_pkey_obj('x'+key, None, fname) self.assertIn('Invalid private key', str(exc)) @@ -94,9 +94,9 @@ class TestIndexHandler(unittest.TestCase): key = read_file(make_tests_data_path(fname)) pkey = IndexHandler.get_pkey_obj(key, password, fname) self.assertIsInstance(pkey, cls) - with self.assertRaises(InvalidException) as exc: + with self.assertRaises(InvalidValueError) as exc: pkey = IndexHandler.get_pkey_obj(key, 'wrongpass', fname) self.assertIn('Wrong password', str(exc)) - with self.assertRaises(InvalidException) as exc: + with self.assertRaises(InvalidValueError) as exc: pkey = IndexHandler.get_pkey_obj('x'+key, password, fname) self.assertIn('Invalid private key', str(exc)) diff --git a/webssh/handler.py b/webssh/handler.py index 41c0720..fbba347 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -10,11 +10,12 @@ import paramiko import tornado.web from tornado.ioloop import IOLoop -from webssh.worker import Worker, recycle_worker, workers +from webssh.settings import swallow_http_errors from webssh.utils import ( is_valid_ipv4_address, is_valid_ipv6_address, is_valid_port, is_valid_hostname, to_bytes, to_str, UnicodeType ) +from webssh.worker import Worker, recycle_worker, workers try: from concurrent.futures import Future @@ -38,34 +39,24 @@ def parse_encoding(data): return s.strip('"').split('.')[-1] -class InvalidException(Exception): +class InvalidValueError(Exception): pass class MixinHandler(object): - formater = 'Missing value {}' - - def write_error(self, status_code, **kwargs): - exc_info = kwargs.get('exc_info') - if exc_info and len(exc_info) > 1: - info = str(exc_info[1]) - if info: - self._reason = info.split(':', 1)[-1].strip() - super(MixinHandler, self).write_error(status_code, **kwargs) - def get_value(self, name): value = self.get_argument(name) if not value: - raise InvalidException(self.formater.format(name)) + raise InvalidValueError('Missing value {}'.format(name)) return value def get_real_client_addr(self): ip = self.request.headers.get('X-Real-Ip') port = self.request.headers.get('X-Real-Port') - if ip is None and port is None: # suppose the server doesn't use nginx - return + if ip is None and port is None: + return # suppose this app doesn't run after an nginx server if is_valid_ipv4_address(ip) or is_valid_ipv6_address(ip): try: @@ -87,19 +78,33 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): self.policy = policy self.host_keys_settings = host_keys_settings self.filename = None + self.result = dict(id=None, status=None, encoding=None) + + def write_error(self, status_code, **kwargs): + if self.settings.get('serve_traceback') or status_code == 500 or \ + not swallow_http_errors: + super(MixinHandler, self).write_error(status_code, **kwargs) + else: + exc_info = kwargs.get('exc_info') + if exc_info: + self._reason = exc_info[1].log_message + self.result.update(status=self._reason) + self.set_status(200) + self.finish(self.result) def get_privatekey(self): - lst = self.request.files.get('privatekey') # multipart form + name = 'privatekey' + lst = self.request.files.get(name) # multipart form if not lst: - return self.get_argument('privatekey', u'') # urlencoded form + return self.get_argument(name, u'') # urlencoded form else: self.filename = lst[0]['filename'] data = lst[0]['body'] if len(data) > KEY_MAX_SIZE: - raise InvalidException( + raise InvalidValueError( 'Invalid private key: {}'.format(self.filename) ) - return self.decode_argument(data, name=self.filename) + return self.decode_argument(data, name=name) @classmethod def get_specific_pkey(cls, pkeycls, privatekey, password): @@ -130,7 +135,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): error = ( 'Wrong password {!r} for decrypting the private key.' ) .format(password) - raise InvalidException(error) + raise InvalidValueError(error) return pkey @@ -138,7 +143,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): value = self.get_value('hostname') if not (is_valid_hostname(value) | is_valid_ipv4_address(value) | is_valid_ipv6_address(value)): - raise InvalidException('Invalid hostname: {}'.format(value)) + raise InvalidValueError('Invalid hostname: {}'.format(value)) return value def get_port(self): @@ -151,7 +156,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): if is_valid_port(port): return port - raise InvalidException('Invalid port: {}'.format(value)) + raise InvalidValueError('Invalid port: {}'.format(value)) def get_args(self): hostname = self.get_hostname() @@ -189,7 +194,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): try: args = self.get_args() - except InvalidException as exc: + except InvalidValueError as exc: raise tornado.web.HTTPError(400, str(exc)) dst_addr = (args[0], args[1]) @@ -227,10 +232,6 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): @tornado.gen.coroutine def post(self): - worker_id = None - status = None - encoding = None - future = Future() t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,)) t.setDaemon(True) @@ -239,20 +240,17 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): try: worker = yield future except (ValueError, paramiko.SSHException) as exc: - status = str(exc) + self.result.update(status=str(exc)) else: - worker_id = worker.id - workers[worker_id] = worker + workers[worker.id] = worker self.loop.call_later(DELAY, recycle_worker, worker) - encoding = worker.encoding + self.result.update(id=worker.id, encoding=worker.encoding) - self.write(dict(id=worker_id, status=status, encoding=encoding)) + self.write(self.result) class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): - formater = 'Bad Request (Missing value {})' - def initialize(self, loop): self.loop = loop self.worker_ref = None @@ -265,8 +263,8 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): logging.info('Connected from {}:{}'.format(*self.src_addr)) try: worker_id = self.get_value('id') - except (tornado.web.MissingArgumentError, InvalidException) as exc: - self.close(reason=str(exc).split(':', 1)[-1].strip()) + except (tornado.web.MissingArgumentError, InvalidValueError) as exc: + self.close(reason=str(exc)) else: worker = workers.get(worker_id) if worker and worker.src_addr[0] == self.src_addr[0]: diff --git a/webssh/settings.py b/webssh/settings.py index 0ea22f7..d5b91f2 100644 --- a/webssh/settings.py +++ b/webssh/settings.py @@ -29,6 +29,7 @@ define('version', type=bool, help='Show version information', base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) max_body_size = 1 * 1024 * 1024 +swallow_http_errors = True def get_app_settings(options): diff --git a/webssh/worker.py b/webssh/worker.py index 569c74b..a53eb86 100644 --- a/webssh/worker.py +++ b/webssh/worker.py @@ -94,7 +94,7 @@ class Worker(object): def close(self, reason=None): logging.info( - 'Closing worker {} with reason: {}'.format(self.id, reason) + 'Closing worker {} with reason: {}'.format(self.id, reason) ) if self.handler: self.loop.remove_handler(self.fd)