Refactored PrivateKey

pull/75/head
Sheng 2019-07-01 22:10:49 +08:00
parent e956c44c1e
commit ec545ec463
2 changed files with 51 additions and 16 deletions

View File

@ -196,6 +196,39 @@ class TestPrivateKey(unittest.TestCase):
password = 'abc123' password = 'abc123'
self._test_with_encrypted_key(fname, password, paramiko.Ed25519Key) self._test_with_encrypted_key(fname, password, paramiko.Ed25519Key)
def test_parse_name(self):
key = u'-----BEGIN PRIVATE KEY-----'
pk = PrivateKey(key)
name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
self.assertIsNone(name)
key = u'-----BEGIN xxx PRIVATE KEY-----'
pk = PrivateKey(key)
name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
self.assertIsNone(name)
key = u'-----BEGIN RSA PRIVATE KEY-----'
pk = PrivateKey(key)
name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
self.assertIsNone(name)
key = u'-----BEGIN RSA PRIVATE KEY-----'
pk = PrivateKey(key)
name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
self.assertIsNone(name)
key = u'-----BEGIN RSA PRIVATE KEY-----'
pk = PrivateKey(key)
name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
self.assertIsNone(name)
for tag, to_name in PrivateKey.tag_to_name.items():
key = u'-----BEGIN {} PRIVATE KEY----- \r\n'.format(tag)
pk = PrivateKey(key)
name, length = pk.parse_name(pk.iostr, pk.tag_to_name)
self.assertEqual(name, to_name)
self.assertEqual(length, len(key))
class TestWsockHandler(unittest.TestCase): class TestWsockHandler(unittest.TestCase):

View File

@ -43,7 +43,6 @@ class InvalidValueError(Exception):
class PrivateKey(object): class PrivateKey(object):
max_length = 16384 # rough number max_length = 16384 # rough number
name = None
tag_to_name = { tag_to_name = {
'RSA': 'RSA', 'RSA': 'RSA',
@ -63,29 +62,32 @@ class PrivateKey(object):
if len(self.privatekey) > self.max_length: if len(self.privatekey) > self.max_length:
raise InvalidValueError('Invalid key length.') raise InvalidValueError('Invalid key length.')
def parse_name(self): def parse_name(self, iostr, tag_to_name):
for line_orig in self.iostr: name = None
line = line_orig.strip() for line_ in iostr:
line = line_.strip()
if line and line.startswith('-----BEGIN ') and \ if line and line.startswith('-----BEGIN ') and \
line.endswith(' PRIVATE KEY-----'): line.endswith(' PRIVATE KEY-----'):
tag = line.split(' ', 2)[1] lst = line.split(' ')
if tag: if len(lst) == 4:
name = self.tag_to_name.get(tag) tag = lst[1]
if name: if tag:
self.name = name name = tag_to_name.get(tag)
break if name:
break
return name, len(line_)
if not self.name: def get_pkey_obj(self):
name, length = self.parse_name(self.iostr, self.tag_to_name)
if not name:
raise InvalidValueError('Invalid key {}.'.format(self.filename)) raise InvalidValueError('Invalid key {}.'.format(self.filename))
offset = self.iostr.tell() - len(line_orig) offset = self.iostr.tell() - length
self.iostr.seek(offset) self.iostr.seek(offset)
logging.debug('Reset offset to {}.'.format(offset)) logging.debug('Reset offset to {}.'.format(offset))
def get_pkey_obj(self): logging.info('Parsing {} key'.format(name))
self.parse_name() pkeycls = getattr(paramiko, name+'Key')
logging.info('Parsing {} key'.format(self.name))
pkeycls = getattr(paramiko, self.name+'Key')
password = to_bytes(self.password) if self.password else None password = to_bytes(self.password) if self.password else None
try: try: