mirror of https://github.com/huashengdun/webssh
Added to_bytes function to utils
parent
cb86682551
commit
e85ae1692e
|
@ -13,6 +13,7 @@ from tests.sshserver import run_ssh_server, banner
|
|||
from tests.utils import encode_multipart_formdata, read_file
|
||||
from webssh.main import make_app, make_handlers
|
||||
from webssh.settings import get_app_settings, max_body_size, base_dir
|
||||
from webssh.utils import to_str
|
||||
|
||||
|
||||
handler.DELAY = 0.1
|
||||
|
@ -22,7 +23,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
|
||||
running = [True]
|
||||
sshserver_port = 2200
|
||||
body = u'hostname=127.0.0.1&port={}&username=robey&password=foo'.format(sshserver_port) # noqa
|
||||
body = 'hostname=127.0.0.1&port={}&username=robey&password=foo'.format(sshserver_port) # noqa
|
||||
body_dict = {
|
||||
'hostname': '127.0.0.1',
|
||||
'port': str(sshserver_port),
|
||||
|
@ -61,37 +62,37 @@ class TestApp(AsyncHTTPTestCase):
|
|||
def test_app_with_invalid_form(self):
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.code, 200)
|
||||
body = u'hostname=&port=&username=&password'
|
||||
body = 'hostname=&port=&username=&password'
|
||||
response = self.fetch('/', method="POST", body=body)
|
||||
self.assertIn(b'"status": "Empty hostname"', response.body)
|
||||
|
||||
body = u'hostname=127.0.0.1&port=&username=&password'
|
||||
body = 'hostname=127.0.0.1&port=&username=&password'
|
||||
response = self.fetch('/', method="POST", body=body)
|
||||
self.assertIn(b'"status": "Empty port"', response.body)
|
||||
|
||||
body = u'hostname=127.0.0.1&port=port&username=&password'
|
||||
body = 'hostname=127.0.0.1&port=port&username=&password'
|
||||
response = self.fetch('/', method="POST", body=body)
|
||||
self.assertIn(b'"status": "Invalid port', response.body)
|
||||
|
||||
body = u'hostname=127.0.0.1&port=70000&username=&password'
|
||||
body = 'hostname=127.0.0.1&port=70000&username=&password'
|
||||
response = self.fetch('/', method="POST", body=body)
|
||||
self.assertIn(b'"status": "Invalid port', response.body)
|
||||
|
||||
body = u'hostname=127.0.0.1&port=7000&username=&password'
|
||||
body = 'hostname=127.0.0.1&port=7000&username=&password'
|
||||
response = self.fetch('/', method="POST", body=body)
|
||||
self.assertIn(b'"status": "Empty username"', response.body)
|
||||
|
||||
def test_app_with_wrong_credentials(self):
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.code, 200)
|
||||
response = self.fetch('/', method="POST", body=self.body + u's')
|
||||
response = self.fetch('/', method="POST", body=self.body + 's')
|
||||
self.assertIn(b'Authentication failed.', response.body)
|
||||
|
||||
def test_app_with_correct_credentials(self):
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.code, 200)
|
||||
response = self.fetch('/', method="POST", body=self.body)
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIsNone(data['status'])
|
||||
self.assertIsNotNone(data['id'])
|
||||
self.assertIsNotNone(data['encoding'])
|
||||
|
@ -104,7 +105,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
self.assertEqual(response.code, 200)
|
||||
|
||||
response = yield client.fetch(url, method="POST", body=self.body)
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIsNone(data['status'])
|
||||
self.assertIsNotNone(data['id'])
|
||||
self.assertIsNotNone(data['encoding'])
|
||||
|
@ -133,7 +134,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
}
|
||||
response = yield client.fetch(url, method="POST", headers=headers,
|
||||
body=body)
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIsNone(data['status'])
|
||||
self.assertIsNotNone(data['id'])
|
||||
self.assertIsNotNone(data['encoding'])
|
||||
|
@ -142,7 +143,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
ws_url = url + 'ws?id=' + data['id']
|
||||
ws = yield tornado.websocket.websocket_connect(ws_url)
|
||||
msg = yield ws.read_message()
|
||||
self.assertEqual(msg.decode(data['encoding']), banner)
|
||||
self.assertEqual(to_str(msg, data['encoding']), banner)
|
||||
ws.close()
|
||||
|
||||
@tornado.testing.gen_test
|
||||
|
@ -153,7 +154,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
self.assertEqual(response.code, 200)
|
||||
|
||||
privatekey = read_file(os.path.join(base_dir, 'tests', 'user_rsa_key'))
|
||||
privatekey = privatekey[:100] + u'bad' + privatekey[100:]
|
||||
privatekey = privatekey[:100] + 'bad' + privatekey[100:]
|
||||
files = [('privatekey', 'user_rsa_key', privatekey)]
|
||||
content_type, body = encode_multipart_formdata(self.body_dict.items(),
|
||||
files)
|
||||
|
@ -162,7 +163,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
}
|
||||
response = yield client.fetch(url, method="POST", headers=headers,
|
||||
body=body)
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIsNotNone(data['status'])
|
||||
self.assertIsNone(data['id'])
|
||||
self.assertIsNone(data['encoding'])
|
||||
|
@ -174,7 +175,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
response = yield client.fetch(url)
|
||||
self.assertEqual(response.code, 200)
|
||||
|
||||
privatekey = u'h' * (2 * max_body_size)
|
||||
privatekey = 'h' * (2 * max_body_size)
|
||||
files = [('privatekey', 'user_rsa_key', privatekey)]
|
||||
content_type, body = encode_multipart_formdata(self.body_dict.items(),
|
||||
files)
|
||||
|
@ -193,7 +194,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
self.assertEqual(response.code, 200)
|
||||
|
||||
response = yield client.fetch(url, method="POST", body=self.body)
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIsNone(data['status'])
|
||||
self.assertIsNotNone(data['id'])
|
||||
self.assertIsNotNone(data['encoding'])
|
||||
|
@ -202,7 +203,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
ws_url = url + 'ws?id=' + data['id']
|
||||
ws = yield tornado.websocket.websocket_connect(ws_url)
|
||||
msg = yield ws.read_message()
|
||||
self.assertEqual(msg.decode(data['encoding']), banner)
|
||||
self.assertEqual(to_str(msg, data['encoding']), banner)
|
||||
ws.close()
|
||||
|
||||
@tornado.testing.gen_test
|
||||
|
@ -214,7 +215,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
|
||||
body = self.body.replace('robey', 'bar')
|
||||
response = yield client.fetch(url, method="POST", body=body)
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIsNone(data['status'])
|
||||
self.assertIsNotNone(data['id'])
|
||||
self.assertIsNotNone(data['encoding'])
|
||||
|
@ -223,7 +224,7 @@ class TestApp(AsyncHTTPTestCase):
|
|||
ws_url = url + 'ws?id=' + data['id']
|
||||
ws = yield tornado.websocket.websocket_connect(ws_url)
|
||||
msg = yield ws.read_message()
|
||||
self.assertEqual(msg.decode(data['encoding']), banner)
|
||||
self.assertEqual(to_str(msg, data['encoding']), banner)
|
||||
|
||||
# messages below will be ignored silently
|
||||
yield ws.write_message('hello')
|
||||
|
|
|
@ -56,7 +56,7 @@ class TestIndexHandler(unittest.TestCase):
|
|||
key = read_file(os.path.join(base_dir, 'tests', fname))
|
||||
pkey = IndexHandler.get_specific_pkey(cls, key, None)
|
||||
self.assertIsInstance(pkey, cls)
|
||||
pkey = IndexHandler.get_specific_pkey(cls, key, b'iginored')
|
||||
pkey = IndexHandler.get_specific_pkey(cls, key, 'iginored')
|
||||
self.assertIsInstance(pkey, cls)
|
||||
pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
|
||||
self.assertIsNone(pkey)
|
||||
|
@ -64,7 +64,7 @@ class TestIndexHandler(unittest.TestCase):
|
|||
def test_get_specific_pkey_with_encrypted_key(self):
|
||||
fname = 'test_rsa_password.key'
|
||||
cls = paramiko.RSAKey
|
||||
password = b'television'
|
||||
password = 'television'
|
||||
|
||||
key = read_file(os.path.join(base_dir, 'tests', fname))
|
||||
pkey = IndexHandler.get_specific_pkey(cls, key, password)
|
||||
|
@ -81,7 +81,7 @@ class TestIndexHandler(unittest.TestCase):
|
|||
key = read_file(os.path.join(base_dir, 'tests', fname))
|
||||
pkey = IndexHandler.get_pkey_obj(key, None)
|
||||
self.assertIsInstance(pkey, cls)
|
||||
pkey = IndexHandler.get_pkey_obj(key, u'iginored')
|
||||
pkey = IndexHandler.get_pkey_obj(key, 'iginored')
|
||||
self.assertIsInstance(pkey, cls)
|
||||
with self.assertRaises(ValueError):
|
||||
pkey = IndexHandler.get_pkey_obj('x'+key, None)
|
||||
|
@ -94,6 +94,6 @@ class TestIndexHandler(unittest.TestCase):
|
|||
pkey = IndexHandler.get_pkey_obj(key, password)
|
||||
self.assertIsInstance(pkey, cls)
|
||||
with self.assertRaises(ValueError):
|
||||
pkey = IndexHandler.get_pkey_obj(key, u'wrongpass')
|
||||
pkey = IndexHandler.get_pkey_obj(key, 'wrongpass')
|
||||
with self.assertRaises(ValueError):
|
||||
pkey = IndexHandler.get_pkey_obj('x'+key, password)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
|
||||
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
|
||||
is_valid_port, to_str)
|
||||
is_valid_port, to_str, to_bytes)
|
||||
|
||||
|
||||
class TestUitls(unittest.TestCase):
|
||||
|
@ -12,6 +12,12 @@ class TestUitls(unittest.TestCase):
|
|||
self.assertEqual(to_str(b), u)
|
||||
self.assertEqual(to_str(u), u)
|
||||
|
||||
def test_to_bytes(self):
|
||||
b = b'hello'
|
||||
u = u'hello'
|
||||
self.assertEqual(to_bytes(b), b)
|
||||
self.assertEqual(to_bytes(u), b)
|
||||
|
||||
def test_is_valid_ipv4_address(self):
|
||||
self.assertFalse(is_valid_ipv4_address('127.0.0'))
|
||||
self.assertFalse(is_valid_ipv4_address(b'127.0.0'))
|
||||
|
|
|
@ -10,10 +10,9 @@ import paramiko
|
|||
import tornado.web
|
||||
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.util import basestring_type
|
||||
from webssh.worker import Worker, recycle_worker, workers
|
||||
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
|
||||
is_valid_port)
|
||||
is_valid_port, to_bytes, to_str, UnicodeType)
|
||||
|
||||
try:
|
||||
from concurrent.futures import Future
|
||||
|
@ -70,7 +69,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
data = self.request.files.get('privatekey')[0]['body']
|
||||
except TypeError:
|
||||
return
|
||||
return data.decode('utf-8')
|
||||
return to_str(data)
|
||||
|
||||
@classmethod
|
||||
def get_specific_pkey(cls, pkeycls, privatekey, password):
|
||||
|
@ -87,7 +86,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
|
||||
@classmethod
|
||||
def get_pkey_obj(cls, privatekey, password):
|
||||
password = password.encode('utf-8') if password else None
|
||||
password = to_bytes(password)
|
||||
|
||||
pkey = cls.get_specific_pkey(paramiko.RSAKey, privatekey, password)\
|
||||
or cls.get_specific_pkey(paramiko.DSSKey, privatekey, password)\
|
||||
|
@ -138,8 +137,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
except paramiko.SSHException:
|
||||
result = None
|
||||
else:
|
||||
data = stdout.read().decode('utf-8')
|
||||
result = parse_encoding(data)
|
||||
data = stdout.read()
|
||||
result = parse_encoding(to_str(data))
|
||||
|
||||
return result if result else 'utf-8'
|
||||
|
||||
|
@ -247,7 +246,7 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
|||
pass
|
||||
|
||||
data = msg.get('data')
|
||||
if data and isinstance(data, basestring_type):
|
||||
if data and isinstance(data, UnicodeType):
|
||||
worker.data_to_dst.append(data)
|
||||
worker.on_write()
|
||||
|
||||
|
|
|
@ -1,10 +1,21 @@
|
|||
import ipaddress
|
||||
|
||||
try:
|
||||
from types import UnicodeType
|
||||
except ImportError:
|
||||
UnicodeType = str
|
||||
|
||||
def to_str(s):
|
||||
if isinstance(s, bytes):
|
||||
return s.decode('utf-8')
|
||||
return s
|
||||
|
||||
def to_str(bstr, encoding='utf-8'):
|
||||
if isinstance(bstr, bytes):
|
||||
return bstr.decode(encoding)
|
||||
return bstr
|
||||
|
||||
|
||||
def to_bytes(ustr, encoding='utf-8'):
|
||||
if isinstance(ustr, UnicodeType):
|
||||
return ustr.encode(encoding)
|
||||
return ustr
|
||||
|
||||
|
||||
def is_valid_ipv4_address(ipstr):
|
||||
|
|
Loading…
Reference in New Issue