diff --git a/tests/test_app.py b/tests/test_app.py index 4334070..85c99bf 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -536,4 +536,5 @@ class TestAppWithRejectPolicy(OtherTestBase): data = json.loads(to_str(response.body)) self.assertIsNone(data['id']) self.assertIsNone(data['encoding']) - self.assertEqual('Connection to 127.0.0.1 is not allowed.', data['status']) # noqa + message = 'Connection to {}:{} is not allowed.'.format(self.body['hostname'], self.sshserver_port) # noqa + self.assertEqual(message, data['status']) diff --git a/webssh/handler.py b/webssh/handler.py index 4516638..7e227aa 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -152,12 +152,14 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): raise InvalidValueError('Invalid hostname: {}'.format(value)) return value - def lookup_hostname(self, hostname): + def lookup_hostname(self, hostname, port): if isinstance(self.policy, paramiko.RejectPolicy): - if self.ssh_client._system_host_keys.lookup(hostname) is None: - if self.ssh_client._host_keys.lookup(hostname) is None: + key = hostname if port == 22 else '[{}]:{}'.format(hostname, port) + if self.ssh_client._system_host_keys.lookup(key) is None: + if self.ssh_client._host_keys.lookup(key) is None: raise ValueError( - 'Connection to {} is not allowed.'.format(hostname) + 'Connection to {}:{} is not allowed.'.format( + hostname, port) ) def get_port(self): @@ -174,8 +176,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): def get_args(self): hostname = self.get_hostname() - self.lookup_hostname(hostname) port = self.get_port() + self.lookup_hostname(hostname, port) username = self.get_value('username') password = self.get_argument('password', u'') privatekey = self.get_privatekey()