From afcf8b52cc9c14f0c93671e379a8736f8df0c885 Mon Sep 17 00:00:00 2001 From: Sheng Date: Sat, 5 Oct 2019 13:18:53 +0800 Subject: [PATCH] Validate the result of locale charmap --- tests/sshserver.py | 9 +++++---- tests/test_app.py | 24 +++++++++++++++++++++++- webssh/handler.py | 19 +++++++++++++------ webssh/utils.py | 8 ++++++++ 4 files changed, 49 insertions(+), 11 deletions(-) diff --git a/tests/sshserver.py b/tests/sshserver.py index 03f4973..807c78c 100644 --- a/tests/sshserver.py +++ b/tests/sshserver.py @@ -53,10 +53,11 @@ class Server(paramiko.ServerInterface): encodings = ['UTF-8', 'GBK', 'UTF-8\r\n', 'GBK\r\n'] - def __init__(self): + def __init__(self, encoding=None): self.shell_event = threading.Event() self.exec_event = threading.Event() self.encoding = random.choice(self.encodings) + self.bad_encoding = encoding self.password_verified = False self.key_verified = False @@ -126,7 +127,7 @@ class Server(paramiko.ServerInterface): ret = False else: ret = True - channel.send(self.encoding) + channel.send(self.bad_encoding or self.encoding) channel.shutdown(1) self.exec_event.set() return ret @@ -145,7 +146,7 @@ class Server(paramiko.ServerInterface): return True -def run_ssh_server(port=2200, running=True): +def run_ssh_server(port=2200, running=True, encoding=None): # now connect sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -159,7 +160,7 @@ def run_ssh_server(port=2200, running=True): t = paramiko.Transport(client) t.load_server_moduli() t.add_server_key(host_key) - server = Server() + server = Server(encoding) try: t.start_server(server=server) except Exception as e: diff --git a/tests/test_app.py b/tests/test_app.py index 2daffbe..1cb682d 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -515,6 +515,7 @@ class OtherTestBase(TestAppBase): tdstream = '' maxconn = 20 origin = 'same' + encoding = '' body = { 'hostname': '127.0.0.1', 'port': '', @@ -543,7 +544,8 @@ class OtherTestBase(TestAppBase): OtherTestBase.sshserver_port += 1 t = threading.Thread( - target=run_ssh_server, args=(self.sshserver_port, self.running) + target=run_ssh_server, + args=(self.sshserver_port, self.running, self.encoding) ) t.setDaemon(True) t.start() @@ -762,3 +764,23 @@ class TestAppWithCrossOriginOperation(OtherTestBase): self.assertEqual( response.headers.get('Access-Control-Allow-Origin'), self.origin ) + + +class TestAppWithBadEncoding(OtherTestBase): + + encoding = b'\xe7\xbc\x96\xe7\xa0\x81' + + @tornado.testing.gen_test + def test_app_with_a_bad_encoding(self): + response = yield self.async_post('/', self.body) + self.assertIn(b'Bad encoding', response.body) + + +class TestAppWithUnknownEncoding(OtherTestBase): + + encoding = u'UnknownEncoding' + + @tornado.testing.gen_test + def test_app_with_a_bad_encoding(self): + response = yield self.async_post('/', self.body) + self.assertIn(b'Unknown encoding', response.body) diff --git a/webssh/handler.py b/webssh/handler.py index 7eade72..8c39f29 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -14,7 +14,8 @@ from tornado.options import options from tornado.process import cpu_count from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, - to_int, to_ip_address, UnicodeType, is_ip_hostname, is_same_primary_domain + to_int, to_ip_address, UnicodeType, is_ip_hostname, is_same_primary_domain, + is_valid_encoding ) from webssh.worker import Worker, recycle_worker, clients @@ -392,12 +393,18 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): def get_default_encoding(self, ssh): try: _, stdout, _ = ssh.exec_command('locale charmap') - except paramiko.SSHException: - result = None - else: - result = to_str(stdout.read().strip()) + except paramiko.SSHException as exc: + logging.warn(str(exc)) + return u'utf-8' - return result if result else 'utf-8' + try: + enc = to_str(stdout.read().strip(), 'ascii') + except UnicodeDecodeError: + raise ValueError('Bad encoding') + else: + if not is_valid_encoding(enc): + raise ValueError('Unknown encoding "{}"'.format(enc)) + return enc def ssh_connect(self, args): ssh = self.ssh_client diff --git a/webssh/utils.py b/webssh/utils.py index 5bb21e3..3bdfda4 100644 --- a/webssh/utils.py +++ b/webssh/utils.py @@ -54,6 +54,14 @@ def is_valid_port(port): return 0 < port < 65536 +def is_valid_encoding(encoding, ustr=u'test'): + try: + ustr.encode(encoding) + except LookupError: + return False + return True + + def is_ip_hostname(hostname): it = iter(hostname) if next(it) == '[':