Try to detect the encoding set by the user

pull/104/head
Sheng 2019-10-06 15:18:23 +08:00
parent afcf8b52cc
commit 049baad909
3 changed files with 57 additions and 28 deletions

View File

@ -51,16 +51,24 @@ class Server(paramiko.ServerInterface):
b'UWT10hcuO4Ks8=') b'UWT10hcuO4Ks8=')
good_pub_key = paramiko.RSAKey(data=decodebytes(data)) good_pub_key = paramiko.RSAKey(data=decodebytes(data))
commands = [
b'$SHELL -ilc "locale charmap"',
b'$SHELL -ic "locale charmap"'
]
encodings = ['UTF-8', 'GBK', 'UTF-8\r\n', 'GBK\r\n'] encodings = ['UTF-8', 'GBK', 'UTF-8\r\n', 'GBK\r\n']
def __init__(self, encoding=None): def __init__(self, encodings=[]):
self.shell_event = threading.Event() self.shell_event = threading.Event()
self.exec_event = threading.Event() self.exec_event = threading.Event()
self.encoding = random.choice(self.encodings) self.cmd_to_enc = self.get_cmd2enc(encodings)
self.bad_encoding = encoding
self.password_verified = False self.password_verified = False
self.key_verified = False self.key_verified = False
def get_cmd2enc(self, encodings):
while len(encodings) < 2:
encodings.append(random.choice(self.encodings))
return dict(zip(self.commands, encodings[0:2]))
def check_channel_request(self, kind, chanid): def check_channel_request(self, kind, chanid):
if kind == 'session': if kind == 'session':
return paramiko.OPEN_SUCCEEDED return paramiko.OPEN_SUCCEEDED
@ -123,11 +131,12 @@ class Server(paramiko.ServerInterface):
return 'password,publickey' return 'password,publickey'
def check_channel_exec_request(self, channel, command): def check_channel_exec_request(self, channel, command):
if command != b'locale charmap': if command not in self.commands:
ret = False ret = False
else: else:
ret = True ret = True
channel.send(self.bad_encoding or self.encoding) self.encoding = self.cmd_to_enc[command]
channel.send(self.encoding)
channel.shutdown(1) channel.shutdown(1)
self.exec_event.set() self.exec_event.set()
return ret return ret
@ -146,7 +155,7 @@ class Server(paramiko.ServerInterface):
return True return True
def run_ssh_server(port=2200, running=True, encoding=None): def run_ssh_server(port=2200, running=True, encodings=[]):
# now connect # now connect
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -160,7 +169,7 @@ def run_ssh_server(port=2200, running=True, encoding=None):
t = paramiko.Transport(client) t = paramiko.Transport(client)
t.load_server_moduli() t.load_server_moduli()
t.add_server_key(host_key) t.add_server_key(host_key)
server = Server(encoding) server = Server(encodings)
try: try:
t.start_server(server=server) t.start_server(server=server)
except Exception as e: except Exception as e:
@ -188,7 +197,12 @@ def run_ssh_server(port=2200, running=True, encoding=None):
# chan.send('\r\n\r\nWelcome!\r\n\r\n') # chan.send('\r\n\r\nWelcome!\r\n\r\n')
print(server.encoding) print(server.encoding)
chan.send(banner.encode(server.encoding.strip())) try:
banner_encoded = banner.encode(server.encoding)
except (ValueError, LookupError):
continue
chan.send(banner_encoded)
if username == 'bar': if username == 'bar':
msg = chan.recv(1024) msg = chan.recv(1024)
chan.send(msg) chan.send(msg)

View File

@ -515,7 +515,7 @@ class OtherTestBase(TestAppBase):
tdstream = '' tdstream = ''
maxconn = 20 maxconn = 20
origin = 'same' origin = 'same'
encoding = '' encodings = []
body = { body = {
'hostname': '127.0.0.1', 'hostname': '127.0.0.1',
'port': '', 'port': '',
@ -545,7 +545,7 @@ class OtherTestBase(TestAppBase):
t = threading.Thread( t = threading.Thread(
target=run_ssh_server, target=run_ssh_server,
args=(self.sshserver_port, self.running, self.encoding) args=(self.sshserver_port, self.running, self.encodings)
) )
t.setDaemon(True) t.setDaemon(True)
t.start() t.start()
@ -768,19 +768,24 @@ class TestAppWithCrossOriginOperation(OtherTestBase):
class TestAppWithBadEncoding(OtherTestBase): class TestAppWithBadEncoding(OtherTestBase):
encoding = b'\xe7\xbc\x96\xe7\xa0\x81' encodings = [u'\u7f16\u7801']
@tornado.testing.gen_test @tornado.testing.gen_test
def test_app_with_a_bad_encoding(self): def test_app_with_a_bad_encoding(self):
response = yield self.async_post('/', self.body) response = yield self.async_post('/', self.body)
self.assertIn(b'Bad encoding', response.body) dic = json.loads(to_str(response.body))
self.assert_status_none(dic)
self.assertIn(dic['encoding'], ['UTF-8', 'GBK'])
class TestAppWithUnknownEncoding(OtherTestBase): class TestAppWithUnknownEncoding(OtherTestBase):
encoding = u'UnknownEncoding' encodings = [u'\u7f16\u7801', u'UnknownEncoding']
@tornado.testing.gen_test @tornado.testing.gen_test
def test_app_with_a_bad_encoding(self): def test_app_with_a_unknown_encoding(self):
response = yield self.async_post('/', self.body) response = yield self.async_post('/', self.body)
self.assertIn(b'Unknown encoding', response.body) self.assert_status_none(json.loads(to_str(response.body)))
dic = json.loads(to_str(response.body))
self.assert_status_none(dic)
self.assertEqual(dic['encoding'], 'utf-8')

View File

@ -390,21 +390,31 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
return args return args
def get_default_encoding(self, ssh): def parse_encoding(self, data):
try: try:
_, stdout, _ = ssh.exec_command('locale charmap') encoding = to_str(data, 'ascii')
except paramiko.SSHException as exc:
logging.warn(str(exc))
return u'utf-8'
try:
enc = to_str(stdout.read().strip(), 'ascii')
except UnicodeDecodeError: except UnicodeDecodeError:
raise ValueError('Bad encoding') return
else:
if not is_valid_encoding(enc): if is_valid_encoding(encoding):
raise ValueError('Unknown encoding "{}"'.format(enc)) return encoding
return enc
def get_default_encoding(self, ssh):
commands = [
'$SHELL -ilc "locale charmap"',
'$SHELL -ic "locale charmap"'
]
for command in commands:
_, stdout, _ = ssh.exec_command(command, get_pty=True)
data = stdout.read().strip()
logging.debug('encoding: {}'.format(data))
result = self.parse_encoding(data)
if result:
return result
logging.warn('Could not detect the default ecnoding.')
return 'utf-8'
def ssh_connect(self, args): def ssh_connect(self, args):
ssh = self.ssh_client ssh = self.ssh_client