mirror of https://github.com/huashengdun/webssh
Added PrivateKey class
parent
6d62642c7f
commit
2b8b978ca2
|
@ -358,7 +358,7 @@ class TestAppBasic(TestAppBase):
|
|||
|
||||
if swallow_http_errors:
|
||||
response = yield self.async_post(url, body, headers=headers)
|
||||
self.assertIn(b'Invalid private key', response.body)
|
||||
self.assertIn(b'Invalid key', response.body)
|
||||
else:
|
||||
with self.assertRaises(HTTPError) as ctx:
|
||||
yield self.async_post(url, body, headers=headers)
|
||||
|
@ -367,7 +367,7 @@ class TestAppBasic(TestAppBase):
|
|||
@tornado.testing.gen_test
|
||||
def test_app_auth_with_pubkey_exceeds_key_max_size(self):
|
||||
url = self.get_url('/')
|
||||
privatekey = 'h' * (handler.KEY_MAX_SIZE * 2)
|
||||
privatekey = 'h' * (handler.PrivateKey.max_length + 1)
|
||||
files = [('privatekey', 'user_rsa_key', privatekey)]
|
||||
content_type, body = encode_multipart_formdata(self.body_dict.items(),
|
||||
files)
|
||||
|
@ -376,7 +376,7 @@ class TestAppBasic(TestAppBase):
|
|||
}
|
||||
if swallow_http_errors:
|
||||
response = yield self.async_post(url, body, headers=headers)
|
||||
self.assertIn(b'Invalid private key', response.body)
|
||||
self.assertIn(b'Invalid key', response.body)
|
||||
else:
|
||||
with self.assertRaises(HTTPError) as ctx:
|
||||
yield self.async_post(url, body, headers=headers)
|
||||
|
|
|
@ -6,7 +6,7 @@ from tornado.options import options
|
|||
from tests.utils import read_file, make_tests_data_path
|
||||
from webssh import handler
|
||||
from webssh.handler import (
|
||||
MixinHandler, IndexHandler, WsockHandler, InvalidValueError
|
||||
MixinHandler, WsockHandler, PrivateKey, InvalidValueError
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -142,73 +142,59 @@ class TestMixinHandler(unittest.TestCase):
|
|||
(x_real_ip, x_real_port))
|
||||
|
||||
|
||||
class TestIndexHandler(unittest.TestCase):
|
||||
class TestPrivateKey(unittest.TestCase):
|
||||
|
||||
def test_get_specific_pkey_with_plain_key(self):
|
||||
fname = 'test_rsa.key'
|
||||
cls = paramiko.RSAKey
|
||||
def get_pk_obj(self, fname, password=None):
|
||||
key = read_file(make_tests_data_path(fname))
|
||||
return PrivateKey(key, password=password, filename=fname)
|
||||
|
||||
pkey = IndexHandler.get_specific_pkey(cls, key, None)
|
||||
self.assertIsInstance(pkey, cls)
|
||||
|
||||
pkey = IndexHandler.get_specific_pkey(cls, key, 'iginored')
|
||||
self.assertIsInstance(pkey, cls)
|
||||
|
||||
pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
|
||||
self.assertIsNone(pkey)
|
||||
|
||||
def test_get_specific_pkey_with_encrypted_key(self):
|
||||
fname = 'test_rsa_password.key'
|
||||
cls = paramiko.RSAKey
|
||||
password = 'television'
|
||||
|
||||
key = read_file(make_tests_data_path(fname))
|
||||
pkey = IndexHandler.get_specific_pkey(cls, key, password)
|
||||
self.assertIsInstance(pkey, cls)
|
||||
|
||||
pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
|
||||
self.assertIsNone(pkey)
|
||||
|
||||
def _test_with_encrypted_key(self, fname, password, klass):
|
||||
pk = self.get_pk_obj(fname, password='')
|
||||
with self.assertRaises(InvalidValueError) as ctx:
|
||||
pkey = IndexHandler.get_specific_pkey(cls, key, None)
|
||||
pk.get_pkey_obj()
|
||||
self.assertIn('Need a password', str(ctx.exception))
|
||||
|
||||
def test_get_pkey_obj_with_plain_key(self):
|
||||
fname = 'test_ed25519.key'
|
||||
cls = paramiko.Ed25519Key
|
||||
key = read_file(make_tests_data_path(fname))
|
||||
pk = self.get_pk_obj(fname, password='wrongpass')
|
||||
with self.assertRaises(InvalidValueError) as ctx:
|
||||
pk.get_pkey_obj()
|
||||
self.assertIn('wrong password', str(ctx.exception))
|
||||
|
||||
pkey = IndexHandler.get_pkey_obj(key, None, fname)
|
||||
self.assertIsInstance(pkey, cls)
|
||||
pk = self.get_pk_obj(fname, password=password)
|
||||
self.assertIsInstance(pk.get_pkey_obj(), klass)
|
||||
|
||||
pkey = IndexHandler.get_pkey_obj(key, 'iginored', fname)
|
||||
self.assertIsInstance(pkey, cls)
|
||||
def test_class_with_invalid_key_length(self):
|
||||
key = u'a' * (PrivateKey.max_length + 1)
|
||||
|
||||
with self.assertRaises(InvalidValueError) as ctx:
|
||||
pkey = IndexHandler.get_pkey_obj('x'+key, None, fname)
|
||||
self.assertIn('Invalid private key', str(ctx.exception))
|
||||
PrivateKey(key)
|
||||
self.assertIn('Invalid key length', str(ctx.exception))
|
||||
|
||||
def test_get_pkey_obj_with_encrypted_key(self):
|
||||
def test_get_pkey_obj_with_invalid_key(self):
|
||||
key = u'a b c'
|
||||
fname = 'abc'
|
||||
|
||||
pk = PrivateKey(key, filename=fname)
|
||||
with self.assertRaises(InvalidValueError) as ctx:
|
||||
pk.get_pkey_obj()
|
||||
self.assertIn('Invalid key {}'.format(fname), str(ctx.exception))
|
||||
|
||||
def test_get_pkey_obj_with_plain_rsa_key(self):
|
||||
pk = self.get_pk_obj('test_rsa.key')
|
||||
self.assertIsInstance(pk.get_pkey_obj(), paramiko.RSAKey)
|
||||
|
||||
def test_get_pkey_obj_with_plain_ed25519_key(self):
|
||||
pk = self.get_pk_obj('test_ed25519.key')
|
||||
self.assertIsInstance(pk.get_pkey_obj(), paramiko.Ed25519Key)
|
||||
|
||||
def test_get_pkey_obj_with_encrypted_rsa_key(self):
|
||||
fname = 'test_rsa_password.key'
|
||||
password = 'television'
|
||||
self._test_with_encrypted_key(fname, password, paramiko.RSAKey)
|
||||
|
||||
def test_get_pkey_obj_with_encrypted_ed25519_key(self):
|
||||
fname = 'test_ed25519_password.key'
|
||||
password = 'abc123'
|
||||
cls = paramiko.Ed25519Key
|
||||
key = read_file(make_tests_data_path(fname))
|
||||
|
||||
pkey = IndexHandler.get_pkey_obj(key, password, fname)
|
||||
self.assertIsInstance(pkey, cls)
|
||||
|
||||
with self.assertRaises(InvalidValueError) as ctx:
|
||||
pkey = IndexHandler.get_pkey_obj(key, 'wrongpass', fname)
|
||||
self.assertIn('Wrong password', str(ctx.exception))
|
||||
|
||||
with self.assertRaises(InvalidValueError) as ctx:
|
||||
pkey = IndexHandler.get_pkey_obj('x'+key, '', fname)
|
||||
self.assertIn('Invalid private key', str(ctx.exception))
|
||||
|
||||
with self.assertRaises(InvalidValueError) as ctx:
|
||||
pkey = IndexHandler.get_specific_pkey(cls, key, None)
|
||||
self.assertIn('Need a password', str(ctx.exception))
|
||||
self._test_with_encrypted_key(fname, password, paramiko.Ed25519Key)
|
||||
|
||||
|
||||
class TestWsockHandler(unittest.TestCase):
|
||||
|
|
|
@ -30,7 +30,6 @@ except ImportError:
|
|||
|
||||
|
||||
DELAY = 3
|
||||
KEY_MAX_SIZE = 16384
|
||||
DEFAULT_PORT = 22
|
||||
|
||||
swallow_http_errors = True
|
||||
|
@ -41,6 +40,53 @@ class InvalidValueError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class PrivateKey(object):
|
||||
|
||||
max_length = 16384 # rough number
|
||||
|
||||
tag_to_name = {
|
||||
'RSA': 'RSA',
|
||||
'DSA': 'DSS',
|
||||
'EC': 'ECDSA',
|
||||
'OPENSSH': 'Ed25519'
|
||||
}
|
||||
|
||||
def __init__(self, privatekey, password=None, filename=''):
|
||||
self.privatekey = privatekey.strip()
|
||||
self.filename = filename
|
||||
self.password = password
|
||||
self.check_length()
|
||||
|
||||
def check_length(self):
|
||||
if len(self.privatekey) > self.max_length:
|
||||
raise InvalidValueError('Invalid key length.')
|
||||
|
||||
def get_name(self):
|
||||
lst = self.privatekey.split(' ', 2)
|
||||
if len(lst) > 1:
|
||||
return self.tag_to_name.get(lst[1])
|
||||
|
||||
def get_pkey_obj(self):
|
||||
name = self.get_name()
|
||||
if not name:
|
||||
raise InvalidValueError('Invalid key {}.'.format(self.filename))
|
||||
|
||||
logging.info('Parsing {} key'.format(name))
|
||||
pkeycls = getattr(paramiko, name+'Key')
|
||||
password = to_bytes(self.password) if self.password else None
|
||||
try:
|
||||
return pkeycls.from_private_key(io.StringIO(self.privatekey),
|
||||
password=password)
|
||||
except paramiko.PasswordRequiredException:
|
||||
raise InvalidValueError('Need a password to decrypt the key.')
|
||||
except paramiko.SSHException as exc:
|
||||
logging.error(str(exc))
|
||||
raise InvalidValueError(
|
||||
'Invalid key or wrong password "{}" for decrypting it.'
|
||||
.format(self.password)
|
||||
)
|
||||
|
||||
|
||||
class MixinHandler(object):
|
||||
|
||||
custom_headers = {
|
||||
|
@ -176,7 +222,6 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
self.policy = policy
|
||||
self.host_keys_settings = host_keys_settings
|
||||
self.ssh_client = self.get_ssh_client()
|
||||
self.privatekey_filename = None
|
||||
self.debug = self.settings.get('debug', False)
|
||||
self.result = dict(id=None, status=None, encoding=None)
|
||||
|
||||
|
@ -206,53 +251,15 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
lst = self.request.files.get(name)
|
||||
if lst:
|
||||
# multipart form
|
||||
self.privatekey_filename = lst[0]['filename']
|
||||
filename = lst[0]['filename']
|
||||
data = lst[0]['body']
|
||||
value = self.decode_argument(data, name=name).strip()
|
||||
else:
|
||||
# urlencoded form
|
||||
value = self.get_argument(name, u'')
|
||||
filename = ''
|
||||
|
||||
if len(value) > KEY_MAX_SIZE:
|
||||
raise InvalidValueError(
|
||||
'Invalid private key: {}'.format(self.privatekey_filename)
|
||||
)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def get_specific_pkey(cls, pkeycls, privatekey, password):
|
||||
logging.info('Trying {}'.format(pkeycls.__name__))
|
||||
try:
|
||||
pkey = pkeycls.from_private_key(io.StringIO(privatekey),
|
||||
password=password)
|
||||
except paramiko.PasswordRequiredException:
|
||||
raise InvalidValueError(
|
||||
'Need a password to decrypt the private key.'
|
||||
)
|
||||
except paramiko.SSHException:
|
||||
pass
|
||||
else:
|
||||
return pkey
|
||||
|
||||
@classmethod
|
||||
def get_pkey_obj(cls, privatekey, password, filename):
|
||||
bpass = to_bytes(password) if password else None
|
||||
|
||||
pkey = cls.get_specific_pkey(paramiko.RSAKey, privatekey, bpass)\
|
||||
or cls.get_specific_pkey(paramiko.DSSKey, privatekey, bpass)\
|
||||
or cls.get_specific_pkey(paramiko.ECDSAKey, privatekey, bpass)\
|
||||
or cls.get_specific_pkey(paramiko.Ed25519Key, privatekey, bpass)
|
||||
|
||||
if not pkey:
|
||||
if not password:
|
||||
error = 'Invalid private key: {}'.format(filename)
|
||||
else:
|
||||
error = (
|
||||
'Wrong password {!r} for decrypting the private key.'
|
||||
) .format(password)
|
||||
raise InvalidValueError(error)
|
||||
|
||||
return pkey
|
||||
return value, filename
|
||||
|
||||
def get_hostname(self):
|
||||
value = self.get_value('hostname')
|
||||
|
@ -287,11 +294,9 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
self.lookup_hostname(hostname, port)
|
||||
username = self.get_value('username')
|
||||
password = self.get_argument('password', u'')
|
||||
privatekey = self.get_privatekey()
|
||||
privatekey, filename = self.get_privatekey()
|
||||
if privatekey:
|
||||
pkey = self.get_pkey_obj(
|
||||
privatekey, password, self.privatekey_filename
|
||||
)
|
||||
pkey = PrivateKey(privatekey, password, filename).get_pkey_obj()
|
||||
password = None
|
||||
else:
|
||||
pkey = None
|
||||
|
|
Loading…
Reference in New Issue