mirror of https://github.com/huashengdun/webssh
Made PrivateKey more robust for parsing keys
parent
931f8bf261
commit
fb6617cc1f
|
@ -43,6 +43,7 @@ class InvalidValueError(Exception):
|
||||||
class PrivateKey(object):
|
class PrivateKey(object):
|
||||||
|
|
||||||
max_length = 16384 # rough number
|
max_length = 16384 # rough number
|
||||||
|
name = None
|
||||||
|
|
||||||
tag_to_name = {
|
tag_to_name = {
|
||||||
'RSA': 'RSA',
|
'RSA': 'RSA',
|
||||||
|
@ -52,31 +53,43 @@ class PrivateKey(object):
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, privatekey, password=None, filename=''):
|
def __init__(self, privatekey, password=None, filename=''):
|
||||||
self.privatekey = privatekey.strip()
|
self.privatekey = privatekey
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.password = password
|
self.password = password
|
||||||
self.check_length()
|
self.check_length()
|
||||||
|
self.iostr = io.StringIO(privatekey)
|
||||||
|
|
||||||
def check_length(self):
|
def check_length(self):
|
||||||
if len(self.privatekey) > self.max_length:
|
if len(self.privatekey) > self.max_length:
|
||||||
raise InvalidValueError('Invalid key length.')
|
raise InvalidValueError('Invalid key length.')
|
||||||
|
|
||||||
def get_name(self):
|
def parse_name(self):
|
||||||
lst = self.privatekey.split(' ', 2)
|
for line_orig in self.iostr:
|
||||||
if len(lst) > 1:
|
line = line_orig.strip()
|
||||||
return self.tag_to_name.get(lst[1])
|
if line and line.startswith('-----BEGIN ') and \
|
||||||
|
line.endswith(' PRIVATE KEY-----'):
|
||||||
|
tag = line.split(' ', 2)[1]
|
||||||
|
if tag:
|
||||||
|
name = self.tag_to_name.get(tag)
|
||||||
|
if name:
|
||||||
|
self.name = name
|
||||||
|
break
|
||||||
|
|
||||||
def get_pkey_obj(self):
|
if not self.name:
|
||||||
name = self.get_name()
|
|
||||||
if not name:
|
|
||||||
raise InvalidValueError('Invalid key {}.'.format(self.filename))
|
raise InvalidValueError('Invalid key {}.'.format(self.filename))
|
||||||
|
|
||||||
logging.info('Parsing {} key'.format(name))
|
offset = self.iostr.tell() - len(line_orig)
|
||||||
pkeycls = getattr(paramiko, name+'Key')
|
self.iostr.seek(offset)
|
||||||
|
logging.debug('Reset offset to {}.'.format(offset))
|
||||||
|
|
||||||
|
def get_pkey_obj(self):
|
||||||
|
self.parse_name()
|
||||||
|
logging.info('Parsing {} key'.format(self.name))
|
||||||
|
pkeycls = getattr(paramiko, self.name+'Key')
|
||||||
password = to_bytes(self.password) if self.password else None
|
password = to_bytes(self.password) if self.password else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return pkeycls.from_private_key(io.StringIO(self.privatekey),
|
return pkeycls.from_private_key(self.iostr, password=password)
|
||||||
password=password)
|
|
||||||
except paramiko.PasswordRequiredException:
|
except paramiko.PasswordRequiredException:
|
||||||
raise InvalidValueError('Need a password to decrypt the key.')
|
raise InvalidValueError('Need a password to decrypt the key.')
|
||||||
except paramiko.SSHException as exc:
|
except paramiko.SSHException as exc:
|
||||||
|
|
Loading…
Reference in New Issue