mirror of https://github.com/huashengdun/webssh
				
				
				
			Made PrivateKey more robust for parsing keys
							parent
							
								
									931f8bf261
								
							
						
					
					
						commit
						fb6617cc1f
					
				| 
						 | 
				
			
			@ -43,6 +43,7 @@ class InvalidValueError(Exception):
 | 
			
		|||
class PrivateKey(object):
 | 
			
		||||
 | 
			
		||||
    max_length = 16384  # rough number
 | 
			
		||||
    name = None
 | 
			
		||||
 | 
			
		||||
    tag_to_name = {
 | 
			
		||||
        'RSA': 'RSA',
 | 
			
		||||
| 
						 | 
				
			
			@ -52,31 +53,43 @@ class PrivateKey(object):
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(self, privatekey, password=None, filename=''):
 | 
			
		||||
        self.privatekey = privatekey.strip()
 | 
			
		||||
        self.privatekey = privatekey
 | 
			
		||||
        self.filename = filename
 | 
			
		||||
        self.password = password
 | 
			
		||||
        self.check_length()
 | 
			
		||||
        self.iostr = io.StringIO(privatekey)
 | 
			
		||||
 | 
			
		||||
    def check_length(self):
 | 
			
		||||
        if len(self.privatekey) > self.max_length:
 | 
			
		||||
            raise InvalidValueError('Invalid key length.')
 | 
			
		||||
 | 
			
		||||
    def get_name(self):
 | 
			
		||||
        lst = self.privatekey.split(' ', 2)
 | 
			
		||||
        if len(lst) > 1:
 | 
			
		||||
            return self.tag_to_name.get(lst[1])
 | 
			
		||||
    def parse_name(self):
 | 
			
		||||
        for line_orig in self.iostr:
 | 
			
		||||
            line = line_orig.strip()
 | 
			
		||||
            if line and line.startswith('-----BEGIN ') and \
 | 
			
		||||
                    line.endswith(' PRIVATE KEY-----'):
 | 
			
		||||
                tag = line.split(' ', 2)[1]
 | 
			
		||||
                if tag:
 | 
			
		||||
                    name = self.tag_to_name.get(tag)
 | 
			
		||||
                    if name:
 | 
			
		||||
                        self.name = name
 | 
			
		||||
                        break
 | 
			
		||||
 | 
			
		||||
    def get_pkey_obj(self):
 | 
			
		||||
        name = self.get_name()
 | 
			
		||||
        if not name:
 | 
			
		||||
        if not self.name:
 | 
			
		||||
            raise InvalidValueError('Invalid key {}.'.format(self.filename))
 | 
			
		||||
 | 
			
		||||
        logging.info('Parsing {} key'.format(name))
 | 
			
		||||
        pkeycls = getattr(paramiko, name+'Key')
 | 
			
		||||
        offset = self.iostr.tell() - len(line_orig)
 | 
			
		||||
        self.iostr.seek(offset)
 | 
			
		||||
        logging.debug('Reset offset to {}.'.format(offset))
 | 
			
		||||
 | 
			
		||||
    def get_pkey_obj(self):
 | 
			
		||||
        self.parse_name()
 | 
			
		||||
        logging.info('Parsing {} key'.format(self.name))
 | 
			
		||||
        pkeycls = getattr(paramiko, self.name+'Key')
 | 
			
		||||
        password = to_bytes(self.password) if self.password else None
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            return pkeycls.from_private_key(io.StringIO(self.privatekey),
 | 
			
		||||
                                            password=password)
 | 
			
		||||
            return pkeycls.from_private_key(self.iostr, password=password)
 | 
			
		||||
        except paramiko.PasswordRequiredException:
 | 
			
		||||
            raise InvalidValueError('Need a password to decrypt the key.')
 | 
			
		||||
        except paramiko.SSHException as exc:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue