Validate the result of locale charmap

pull/104/head
Sheng 2019-10-05 13:18:53 +08:00
parent 1fe361f601
commit afcf8b52cc
4 changed files with 49 additions and 11 deletions

View File

@ -53,10 +53,11 @@ class Server(paramiko.ServerInterface):
encodings = ['UTF-8', 'GBK', 'UTF-8\r\n', 'GBK\r\n'] encodings = ['UTF-8', 'GBK', 'UTF-8\r\n', 'GBK\r\n']
def __init__(self): def __init__(self, encoding=None):
self.shell_event = threading.Event() self.shell_event = threading.Event()
self.exec_event = threading.Event() self.exec_event = threading.Event()
self.encoding = random.choice(self.encodings) self.encoding = random.choice(self.encodings)
self.bad_encoding = encoding
self.password_verified = False self.password_verified = False
self.key_verified = False self.key_verified = False
@ -126,7 +127,7 @@ class Server(paramiko.ServerInterface):
ret = False ret = False
else: else:
ret = True ret = True
channel.send(self.encoding) channel.send(self.bad_encoding or self.encoding)
channel.shutdown(1) channel.shutdown(1)
self.exec_event.set() self.exec_event.set()
return ret return ret
@ -145,7 +146,7 @@ class Server(paramiko.ServerInterface):
return True return True
def run_ssh_server(port=2200, running=True): def run_ssh_server(port=2200, running=True, encoding=None):
# now connect # now connect
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -159,7 +160,7 @@ def run_ssh_server(port=2200, running=True):
t = paramiko.Transport(client) t = paramiko.Transport(client)
t.load_server_moduli() t.load_server_moduli()
t.add_server_key(host_key) t.add_server_key(host_key)
server = Server() server = Server(encoding)
try: try:
t.start_server(server=server) t.start_server(server=server)
except Exception as e: except Exception as e:

View File

@ -515,6 +515,7 @@ class OtherTestBase(TestAppBase):
tdstream = '' tdstream = ''
maxconn = 20 maxconn = 20
origin = 'same' origin = 'same'
encoding = ''
body = { body = {
'hostname': '127.0.0.1', 'hostname': '127.0.0.1',
'port': '', 'port': '',
@ -543,7 +544,8 @@ class OtherTestBase(TestAppBase):
OtherTestBase.sshserver_port += 1 OtherTestBase.sshserver_port += 1
t = threading.Thread( t = threading.Thread(
target=run_ssh_server, args=(self.sshserver_port, self.running) target=run_ssh_server,
args=(self.sshserver_port, self.running, self.encoding)
) )
t.setDaemon(True) t.setDaemon(True)
t.start() t.start()
@ -762,3 +764,23 @@ class TestAppWithCrossOriginOperation(OtherTestBase):
self.assertEqual( self.assertEqual(
response.headers.get('Access-Control-Allow-Origin'), self.origin response.headers.get('Access-Control-Allow-Origin'), self.origin
) )
class TestAppWithBadEncoding(OtherTestBase):
encoding = b'\xe7\xbc\x96\xe7\xa0\x81'
@tornado.testing.gen_test
def test_app_with_a_bad_encoding(self):
response = yield self.async_post('/', self.body)
self.assertIn(b'Bad encoding', response.body)
class TestAppWithUnknownEncoding(OtherTestBase):
encoding = u'UnknownEncoding'
@tornado.testing.gen_test
def test_app_with_a_bad_encoding(self):
response = yield self.async_post('/', self.body)
self.assertIn(b'Unknown encoding', response.body)

View File

@ -14,7 +14,8 @@ from tornado.options import options
from tornado.process import cpu_count from tornado.process import cpu_count
from webssh.utils import ( from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
to_int, to_ip_address, UnicodeType, is_ip_hostname, is_same_primary_domain to_int, to_ip_address, UnicodeType, is_ip_hostname, is_same_primary_domain,
is_valid_encoding
) )
from webssh.worker import Worker, recycle_worker, clients from webssh.worker import Worker, recycle_worker, clients
@ -392,12 +393,18 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def get_default_encoding(self, ssh): def get_default_encoding(self, ssh):
try: try:
_, stdout, _ = ssh.exec_command('locale charmap') _, stdout, _ = ssh.exec_command('locale charmap')
except paramiko.SSHException: except paramiko.SSHException as exc:
result = None logging.warn(str(exc))
else: return u'utf-8'
result = to_str(stdout.read().strip())
return result if result else 'utf-8' try:
enc = to_str(stdout.read().strip(), 'ascii')
except UnicodeDecodeError:
raise ValueError('Bad encoding')
else:
if not is_valid_encoding(enc):
raise ValueError('Unknown encoding "{}"'.format(enc))
return enc
def ssh_connect(self, args): def ssh_connect(self, args):
ssh = self.ssh_client ssh = self.ssh_client

View File

@ -54,6 +54,14 @@ def is_valid_port(port):
return 0 < port < 65536 return 0 < port < 65536
def is_valid_encoding(encoding, ustr=u'test'):
try:
ustr.encode(encoding)
except LookupError:
return False
return True
def is_ip_hostname(hostname): def is_ip_hostname(hostname):
it = iter(hostname) it = iter(hostname)
if next(it) == '[': if next(it) == '[':