Added PrivateKey class

pull/75/head
Sheng 2019-06-27 12:52:19 +08:00
parent 6d62642c7f
commit 2b8b978ca2
3 changed files with 96 additions and 105 deletions

View File

@ -358,7 +358,7 @@ class TestAppBasic(TestAppBase):
if swallow_http_errors:
response = yield self.async_post(url, body, headers=headers)
self.assertIn(b'Invalid private key', response.body)
self.assertIn(b'Invalid key', response.body)
else:
with self.assertRaises(HTTPError) as ctx:
yield self.async_post(url, body, headers=headers)
@ -367,7 +367,7 @@ class TestAppBasic(TestAppBase):
@tornado.testing.gen_test
def test_app_auth_with_pubkey_exceeds_key_max_size(self):
url = self.get_url('/')
privatekey = 'h' * (handler.KEY_MAX_SIZE * 2)
privatekey = 'h' * (handler.PrivateKey.max_length + 1)
files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(),
files)
@ -376,7 +376,7 @@ class TestAppBasic(TestAppBase):
}
if swallow_http_errors:
response = yield self.async_post(url, body, headers=headers)
self.assertIn(b'Invalid private key', response.body)
self.assertIn(b'Invalid key', response.body)
else:
with self.assertRaises(HTTPError) as ctx:
yield self.async_post(url, body, headers=headers)

View File

@ -6,7 +6,7 @@ from tornado.options import options
from tests.utils import read_file, make_tests_data_path
from webssh import handler
from webssh.handler import (
MixinHandler, IndexHandler, WsockHandler, InvalidValueError
MixinHandler, WsockHandler, PrivateKey, InvalidValueError
)
try:
@ -142,73 +142,59 @@ class TestMixinHandler(unittest.TestCase):
(x_real_ip, x_real_port))
class TestIndexHandler(unittest.TestCase):
class TestPrivateKey(unittest.TestCase):
def test_get_specific_pkey_with_plain_key(self):
fname = 'test_rsa.key'
cls = paramiko.RSAKey
def get_pk_obj(self, fname, password=None):
key = read_file(make_tests_data_path(fname))
return PrivateKey(key, password=password, filename=fname)
pkey = IndexHandler.get_specific_pkey(cls, key, None)
self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, key, 'iginored')
self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
self.assertIsNone(pkey)
def test_get_specific_pkey_with_encrypted_key(self):
fname = 'test_rsa_password.key'
cls = paramiko.RSAKey
password = 'television'
key = read_file(make_tests_data_path(fname))
pkey = IndexHandler.get_specific_pkey(cls, key, password)
self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
self.assertIsNone(pkey)
def _test_with_encrypted_key(self, fname, password, klass):
pk = self.get_pk_obj(fname, password='')
with self.assertRaises(InvalidValueError) as ctx:
pkey = IndexHandler.get_specific_pkey(cls, key, None)
pk.get_pkey_obj()
self.assertIn('Need a password', str(ctx.exception))
def test_get_pkey_obj_with_plain_key(self):
fname = 'test_ed25519.key'
cls = paramiko.Ed25519Key
key = read_file(make_tests_data_path(fname))
pk = self.get_pk_obj(fname, password='wrongpass')
with self.assertRaises(InvalidValueError) as ctx:
pk.get_pkey_obj()
self.assertIn('wrong password', str(ctx.exception))
pkey = IndexHandler.get_pkey_obj(key, None, fname)
self.assertIsInstance(pkey, cls)
pk = self.get_pk_obj(fname, password=password)
self.assertIsInstance(pk.get_pkey_obj(), klass)
pkey = IndexHandler.get_pkey_obj(key, 'iginored', fname)
self.assertIsInstance(pkey, cls)
def test_class_with_invalid_key_length(self):
key = u'a' * (PrivateKey.max_length + 1)
with self.assertRaises(InvalidValueError) as ctx:
pkey = IndexHandler.get_pkey_obj('x'+key, None, fname)
self.assertIn('Invalid private key', str(ctx.exception))
PrivateKey(key)
self.assertIn('Invalid key length', str(ctx.exception))
def test_get_pkey_obj_with_encrypted_key(self):
def test_get_pkey_obj_with_invalid_key(self):
key = u'a b c'
fname = 'abc'
pk = PrivateKey(key, filename=fname)
with self.assertRaises(InvalidValueError) as ctx:
pk.get_pkey_obj()
self.assertIn('Invalid key {}'.format(fname), str(ctx.exception))
def test_get_pkey_obj_with_plain_rsa_key(self):
pk = self.get_pk_obj('test_rsa.key')
self.assertIsInstance(pk.get_pkey_obj(), paramiko.RSAKey)
def test_get_pkey_obj_with_plain_ed25519_key(self):
pk = self.get_pk_obj('test_ed25519.key')
self.assertIsInstance(pk.get_pkey_obj(), paramiko.Ed25519Key)
def test_get_pkey_obj_with_encrypted_rsa_key(self):
fname = 'test_rsa_password.key'
password = 'television'
self._test_with_encrypted_key(fname, password, paramiko.RSAKey)
def test_get_pkey_obj_with_encrypted_ed25519_key(self):
fname = 'test_ed25519_password.key'
password = 'abc123'
cls = paramiko.Ed25519Key
key = read_file(make_tests_data_path(fname))
pkey = IndexHandler.get_pkey_obj(key, password, fname)
self.assertIsInstance(pkey, cls)
with self.assertRaises(InvalidValueError) as ctx:
pkey = IndexHandler.get_pkey_obj(key, 'wrongpass', fname)
self.assertIn('Wrong password', str(ctx.exception))
with self.assertRaises(InvalidValueError) as ctx:
pkey = IndexHandler.get_pkey_obj('x'+key, '', fname)
self.assertIn('Invalid private key', str(ctx.exception))
with self.assertRaises(InvalidValueError) as ctx:
pkey = IndexHandler.get_specific_pkey(cls, key, None)
self.assertIn('Need a password', str(ctx.exception))
self._test_with_encrypted_key(fname, password, paramiko.Ed25519Key)
class TestWsockHandler(unittest.TestCase):

View File

@ -30,7 +30,6 @@ except ImportError:
DELAY = 3
KEY_MAX_SIZE = 16384
DEFAULT_PORT = 22
swallow_http_errors = True
@ -41,6 +40,53 @@ class InvalidValueError(Exception):
pass
class PrivateKey(object):
max_length = 16384 # rough number
tag_to_name = {
'RSA': 'RSA',
'DSA': 'DSS',
'EC': 'ECDSA',
'OPENSSH': 'Ed25519'
}
def __init__(self, privatekey, password=None, filename=''):
self.privatekey = privatekey.strip()
self.filename = filename
self.password = password
self.check_length()
def check_length(self):
if len(self.privatekey) > self.max_length:
raise InvalidValueError('Invalid key length.')
def get_name(self):
lst = self.privatekey.split(' ', 2)
if len(lst) > 1:
return self.tag_to_name.get(lst[1])
def get_pkey_obj(self):
name = self.get_name()
if not name:
raise InvalidValueError('Invalid key {}.'.format(self.filename))
logging.info('Parsing {} key'.format(name))
pkeycls = getattr(paramiko, name+'Key')
password = to_bytes(self.password) if self.password else None
try:
return pkeycls.from_private_key(io.StringIO(self.privatekey),
password=password)
except paramiko.PasswordRequiredException:
raise InvalidValueError('Need a password to decrypt the key.')
except paramiko.SSHException as exc:
logging.error(str(exc))
raise InvalidValueError(
'Invalid key or wrong password "{}" for decrypting it.'
.format(self.password)
)
class MixinHandler(object):
custom_headers = {
@ -176,7 +222,6 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
self.policy = policy
self.host_keys_settings = host_keys_settings
self.ssh_client = self.get_ssh_client()
self.privatekey_filename = None
self.debug = self.settings.get('debug', False)
self.result = dict(id=None, status=None, encoding=None)
@ -206,53 +251,15 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
lst = self.request.files.get(name)
if lst:
# multipart form
self.privatekey_filename = lst[0]['filename']
filename = lst[0]['filename']
data = lst[0]['body']
value = self.decode_argument(data, name=name).strip()
else:
# urlencoded form
value = self.get_argument(name, u'')
filename = ''
if len(value) > KEY_MAX_SIZE:
raise InvalidValueError(
'Invalid private key: {}'.format(self.privatekey_filename)
)
return value
@classmethod
def get_specific_pkey(cls, pkeycls, privatekey, password):
logging.info('Trying {}'.format(pkeycls.__name__))
try:
pkey = pkeycls.from_private_key(io.StringIO(privatekey),
password=password)
except paramiko.PasswordRequiredException:
raise InvalidValueError(
'Need a password to decrypt the private key.'
)
except paramiko.SSHException:
pass
else:
return pkey
@classmethod
def get_pkey_obj(cls, privatekey, password, filename):
bpass = to_bytes(password) if password else None
pkey = cls.get_specific_pkey(paramiko.RSAKey, privatekey, bpass)\
or cls.get_specific_pkey(paramiko.DSSKey, privatekey, bpass)\
or cls.get_specific_pkey(paramiko.ECDSAKey, privatekey, bpass)\
or cls.get_specific_pkey(paramiko.Ed25519Key, privatekey, bpass)
if not pkey:
if not password:
error = 'Invalid private key: {}'.format(filename)
else:
error = (
'Wrong password {!r} for decrypting the private key.'
) .format(password)
raise InvalidValueError(error)
return pkey
return value, filename
def get_hostname(self):
value = self.get_value('hostname')
@ -287,11 +294,9 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
self.lookup_hostname(hostname, port)
username = self.get_value('username')
password = self.get_argument('password', u'')
privatekey = self.get_privatekey()
privatekey, filename = self.get_privatekey()
if privatekey:
pkey = self.get_pkey_obj(
privatekey, password, self.privatekey_filename
)
pkey = PrivateKey(privatekey, password, filename).get_pkey_obj()
password = None
else:
pkey = None