mirror of https://github.com/huashengdun/webssh
				
				
				
			Lookup hostname before connection under reject policy
							parent
							
								
									90e7ea0327
								
							
						
					
					
						commit
						a576a41ea4
					
				| 
						 | 
				
			
			@ -428,6 +428,9 @@ class OtherTestBase(AsyncHTTPTestCase):
 | 
			
		|||
    sshserver_port = 3300
 | 
			
		||||
    headers = {'Cookie': '_xsrf=yummy'}
 | 
			
		||||
    debug = False
 | 
			
		||||
    policy = None
 | 
			
		||||
    hostFile = None
 | 
			
		||||
    sysHostFile = None
 | 
			
		||||
    body = {
 | 
			
		||||
        'hostname': '127.0.0.1',
 | 
			
		||||
        'port': '',
 | 
			
		||||
| 
						 | 
				
			
			@ -440,9 +443,9 @@ class OtherTestBase(AsyncHTTPTestCase):
 | 
			
		|||
        self.body.update(port=str(self.sshserver_port))
 | 
			
		||||
        loop = self.io_loop
 | 
			
		||||
        options.debug = self.debug
 | 
			
		||||
        options.policy = random.choice(['warning', 'autoadd'])
 | 
			
		||||
        options.hostFile = ''
 | 
			
		||||
        options.sysHostFile = ''
 | 
			
		||||
        options.policy = self.policy if self.policy else random.choice(['warning', 'autoadd'])  # noqa
 | 
			
		||||
        options.hostFile = self.hostFile if self.hostFile else ''
 | 
			
		||||
        options.sysHostFile = self.sysHostFile if self.sysHostFile else ''
 | 
			
		||||
        app = make_app(make_handlers(loop, options), get_app_settings(options))
 | 
			
		||||
        return app
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -516,3 +519,21 @@ class TestAppMiscell(OtherTestBase):
 | 
			
		|||
        recv = b''.join(lst).decode(data['encoding'])
 | 
			
		||||
        self.assertEqual(send, recv)
 | 
			
		||||
        ws.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestAppWithRejectPolicy(OtherTestBase):
 | 
			
		||||
 | 
			
		||||
    policy = 'reject'
 | 
			
		||||
    hostFile = make_tests_data_path('known_hosts_example')
 | 
			
		||||
 | 
			
		||||
    @tornado.testing.gen_test
 | 
			
		||||
    def test_app_with_hostname_not_in_hostkeys(self):
 | 
			
		||||
        url = self.get_url('/')
 | 
			
		||||
        client = self.get_http_client()
 | 
			
		||||
        body = urlencode(dict(self.body, username='foo'))
 | 
			
		||||
        response = yield client.fetch(url, method='POST', body=body,
 | 
			
		||||
                                      headers=self.headers)
 | 
			
		||||
        data = json.loads(to_str(response.body))
 | 
			
		||||
        self.assertIsNone(data['id'])
 | 
			
		||||
        self.assertIsNone(data['encoding'])
 | 
			
		||||
        self.assertEqual('Connection to 127.0.0.1 is not allowed.', data['status'])  # noqa
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -70,6 +70,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
 | 
			
		|||
        self.loop = loop
 | 
			
		||||
        self.policy = policy
 | 
			
		||||
        self.host_keys_settings = host_keys_settings
 | 
			
		||||
        self.ssh_client = self.get_ssh_client()
 | 
			
		||||
        self.filename = None
 | 
			
		||||
        self.result = dict(id=None, status=None, encoding=None)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -87,6 +88,14 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
 | 
			
		|||
            self.set_status(200)
 | 
			
		||||
            self.finish(self.result)
 | 
			
		||||
 | 
			
		||||
    def get_ssh_client(self):
 | 
			
		||||
        ssh = paramiko.SSHClient()
 | 
			
		||||
        ssh._system_host_keys = self.host_keys_settings['system_host_keys']
 | 
			
		||||
        ssh._host_keys = self.host_keys_settings['host_keys']
 | 
			
		||||
        ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
 | 
			
		||||
        ssh.set_missing_host_key_policy(self.policy)
 | 
			
		||||
        return ssh
 | 
			
		||||
 | 
			
		||||
    def get_privatekey(self):
 | 
			
		||||
        name = 'privatekey'
 | 
			
		||||
        lst = self.request.files.get(name)  # multipart form
 | 
			
		||||
| 
						 | 
				
			
			@ -143,6 +152,14 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
 | 
			
		|||
            raise InvalidValueError('Invalid hostname: {}'.format(value))
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def lookup_hostname(self, hostname):
 | 
			
		||||
        if isinstance(self.policy, paramiko.RejectPolicy):
 | 
			
		||||
            if self.ssh_client._system_host_keys.lookup(hostname) is None:
 | 
			
		||||
                if self.ssh_client._host_keys.lookup(hostname) is None:
 | 
			
		||||
                    raise ValueError(
 | 
			
		||||
                        'Connection to {} is not allowed.'.format(hostname)
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
    def get_port(self):
 | 
			
		||||
        value = self.get_value('port')
 | 
			
		||||
        try:
 | 
			
		||||
| 
						 | 
				
			
			@ -157,6 +174,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
 | 
			
		|||
 | 
			
		||||
    def get_args(self):
 | 
			
		||||
        hostname = self.get_hostname()
 | 
			
		||||
        self.lookup_hostname(hostname)
 | 
			
		||||
        port = self.get_port()
 | 
			
		||||
        username = self.get_value('username')
 | 
			
		||||
        password = self.get_argument('password', u'')
 | 
			
		||||
| 
						 | 
				
			
			@ -182,11 +200,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
 | 
			
		|||
        return result if result else 'utf-8'
 | 
			
		||||
 | 
			
		||||
    def ssh_connect(self):
 | 
			
		||||
        ssh = paramiko.SSHClient()
 | 
			
		||||
        ssh._system_host_keys = self.host_keys_settings['system_host_keys']
 | 
			
		||||
        ssh._host_keys = self.host_keys_settings['host_keys']
 | 
			
		||||
        ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
 | 
			
		||||
        ssh.set_missing_host_key_policy(self.policy)
 | 
			
		||||
        ssh = self.ssh_client
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            args = self.get_args()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue