mirror of https://github.com/huashengdun/webssh
Validate the result of locale charmap
parent
1fe361f601
commit
afcf8b52cc
|
@ -53,10 +53,11 @@ class Server(paramiko.ServerInterface):
|
||||||
|
|
||||||
encodings = ['UTF-8', 'GBK', 'UTF-8\r\n', 'GBK\r\n']
|
encodings = ['UTF-8', 'GBK', 'UTF-8\r\n', 'GBK\r\n']
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, encoding=None):
|
||||||
self.shell_event = threading.Event()
|
self.shell_event = threading.Event()
|
||||||
self.exec_event = threading.Event()
|
self.exec_event = threading.Event()
|
||||||
self.encoding = random.choice(self.encodings)
|
self.encoding = random.choice(self.encodings)
|
||||||
|
self.bad_encoding = encoding
|
||||||
self.password_verified = False
|
self.password_verified = False
|
||||||
self.key_verified = False
|
self.key_verified = False
|
||||||
|
|
||||||
|
@ -126,7 +127,7 @@ class Server(paramiko.ServerInterface):
|
||||||
ret = False
|
ret = False
|
||||||
else:
|
else:
|
||||||
ret = True
|
ret = True
|
||||||
channel.send(self.encoding)
|
channel.send(self.bad_encoding or self.encoding)
|
||||||
channel.shutdown(1)
|
channel.shutdown(1)
|
||||||
self.exec_event.set()
|
self.exec_event.set()
|
||||||
return ret
|
return ret
|
||||||
|
@ -145,7 +146,7 @@ class Server(paramiko.ServerInterface):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def run_ssh_server(port=2200, running=True):
|
def run_ssh_server(port=2200, running=True, encoding=None):
|
||||||
# now connect
|
# now connect
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
@ -159,7 +160,7 @@ def run_ssh_server(port=2200, running=True):
|
||||||
t = paramiko.Transport(client)
|
t = paramiko.Transport(client)
|
||||||
t.load_server_moduli()
|
t.load_server_moduli()
|
||||||
t.add_server_key(host_key)
|
t.add_server_key(host_key)
|
||||||
server = Server()
|
server = Server(encoding)
|
||||||
try:
|
try:
|
||||||
t.start_server(server=server)
|
t.start_server(server=server)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -515,6 +515,7 @@ class OtherTestBase(TestAppBase):
|
||||||
tdstream = ''
|
tdstream = ''
|
||||||
maxconn = 20
|
maxconn = 20
|
||||||
origin = 'same'
|
origin = 'same'
|
||||||
|
encoding = ''
|
||||||
body = {
|
body = {
|
||||||
'hostname': '127.0.0.1',
|
'hostname': '127.0.0.1',
|
||||||
'port': '',
|
'port': '',
|
||||||
|
@ -543,7 +544,8 @@ class OtherTestBase(TestAppBase):
|
||||||
OtherTestBase.sshserver_port += 1
|
OtherTestBase.sshserver_port += 1
|
||||||
|
|
||||||
t = threading.Thread(
|
t = threading.Thread(
|
||||||
target=run_ssh_server, args=(self.sshserver_port, self.running)
|
target=run_ssh_server,
|
||||||
|
args=(self.sshserver_port, self.running, self.encoding)
|
||||||
)
|
)
|
||||||
t.setDaemon(True)
|
t.setDaemon(True)
|
||||||
t.start()
|
t.start()
|
||||||
|
@ -762,3 +764,23 @@ class TestAppWithCrossOriginOperation(OtherTestBase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
response.headers.get('Access-Control-Allow-Origin'), self.origin
|
response.headers.get('Access-Control-Allow-Origin'), self.origin
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppWithBadEncoding(OtherTestBase):
|
||||||
|
|
||||||
|
encoding = b'\xe7\xbc\x96\xe7\xa0\x81'
|
||||||
|
|
||||||
|
@tornado.testing.gen_test
|
||||||
|
def test_app_with_a_bad_encoding(self):
|
||||||
|
response = yield self.async_post('/', self.body)
|
||||||
|
self.assertIn(b'Bad encoding', response.body)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppWithUnknownEncoding(OtherTestBase):
|
||||||
|
|
||||||
|
encoding = u'UnknownEncoding'
|
||||||
|
|
||||||
|
@tornado.testing.gen_test
|
||||||
|
def test_app_with_a_bad_encoding(self):
|
||||||
|
response = yield self.async_post('/', self.body)
|
||||||
|
self.assertIn(b'Unknown encoding', response.body)
|
||||||
|
|
|
@ -14,7 +14,8 @@ from tornado.options import options
|
||||||
from tornado.process import cpu_count
|
from tornado.process import cpu_count
|
||||||
from webssh.utils import (
|
from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
||||||
to_int, to_ip_address, UnicodeType, is_ip_hostname, is_same_primary_domain
|
to_int, to_ip_address, UnicodeType, is_ip_hostname, is_same_primary_domain,
|
||||||
|
is_valid_encoding
|
||||||
)
|
)
|
||||||
from webssh.worker import Worker, recycle_worker, clients
|
from webssh.worker import Worker, recycle_worker, clients
|
||||||
|
|
||||||
|
@ -392,12 +393,18 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
def get_default_encoding(self, ssh):
|
def get_default_encoding(self, ssh):
|
||||||
try:
|
try:
|
||||||
_, stdout, _ = ssh.exec_command('locale charmap')
|
_, stdout, _ = ssh.exec_command('locale charmap')
|
||||||
except paramiko.SSHException:
|
except paramiko.SSHException as exc:
|
||||||
result = None
|
logging.warn(str(exc))
|
||||||
else:
|
return u'utf-8'
|
||||||
result = to_str(stdout.read().strip())
|
|
||||||
|
|
||||||
return result if result else 'utf-8'
|
try:
|
||||||
|
enc = to_str(stdout.read().strip(), 'ascii')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise ValueError('Bad encoding')
|
||||||
|
else:
|
||||||
|
if not is_valid_encoding(enc):
|
||||||
|
raise ValueError('Unknown encoding "{}"'.format(enc))
|
||||||
|
return enc
|
||||||
|
|
||||||
def ssh_connect(self, args):
|
def ssh_connect(self, args):
|
||||||
ssh = self.ssh_client
|
ssh = self.ssh_client
|
||||||
|
|
|
@ -54,6 +54,14 @@ def is_valid_port(port):
|
||||||
return 0 < port < 65536
|
return 0 < port < 65536
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_encoding(encoding, ustr=u'test'):
|
||||||
|
try:
|
||||||
|
ustr.encode(encoding)
|
||||||
|
except LookupError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def is_ip_hostname(hostname):
|
def is_ip_hostname(hostname):
|
||||||
it = iter(hostname)
|
it = iter(hostname)
|
||||||
if next(it) == '[':
|
if next(it) == '[':
|
||||||
|
|
Loading…
Reference in New Issue