mirror of https://github.com/huashengdun/webssh
Lookup hostname before connection under reject policy
parent
90e7ea0327
commit
a576a41ea4
|
@ -428,6 +428,9 @@ class OtherTestBase(AsyncHTTPTestCase):
|
||||||
sshserver_port = 3300
|
sshserver_port = 3300
|
||||||
headers = {'Cookie': '_xsrf=yummy'}
|
headers = {'Cookie': '_xsrf=yummy'}
|
||||||
debug = False
|
debug = False
|
||||||
|
policy = None
|
||||||
|
hostFile = None
|
||||||
|
sysHostFile = None
|
||||||
body = {
|
body = {
|
||||||
'hostname': '127.0.0.1',
|
'hostname': '127.0.0.1',
|
||||||
'port': '',
|
'port': '',
|
||||||
|
@ -440,9 +443,9 @@ class OtherTestBase(AsyncHTTPTestCase):
|
||||||
self.body.update(port=str(self.sshserver_port))
|
self.body.update(port=str(self.sshserver_port))
|
||||||
loop = self.io_loop
|
loop = self.io_loop
|
||||||
options.debug = self.debug
|
options.debug = self.debug
|
||||||
options.policy = random.choice(['warning', 'autoadd'])
|
options.policy = self.policy if self.policy else random.choice(['warning', 'autoadd']) # noqa
|
||||||
options.hostFile = ''
|
options.hostFile = self.hostFile if self.hostFile else ''
|
||||||
options.sysHostFile = ''
|
options.sysHostFile = self.sysHostFile if self.sysHostFile else ''
|
||||||
app = make_app(make_handlers(loop, options), get_app_settings(options))
|
app = make_app(make_handlers(loop, options), get_app_settings(options))
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
@ -516,3 +519,21 @@ class TestAppMiscell(OtherTestBase):
|
||||||
recv = b''.join(lst).decode(data['encoding'])
|
recv = b''.join(lst).decode(data['encoding'])
|
||||||
self.assertEqual(send, recv)
|
self.assertEqual(send, recv)
|
||||||
ws.close()
|
ws.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppWithRejectPolicy(OtherTestBase):
|
||||||
|
|
||||||
|
policy = 'reject'
|
||||||
|
hostFile = make_tests_data_path('known_hosts_example')
|
||||||
|
|
||||||
|
@tornado.testing.gen_test
|
||||||
|
def test_app_with_hostname_not_in_hostkeys(self):
|
||||||
|
url = self.get_url('/')
|
||||||
|
client = self.get_http_client()
|
||||||
|
body = urlencode(dict(self.body, username='foo'))
|
||||||
|
response = yield client.fetch(url, method='POST', body=body,
|
||||||
|
headers=self.headers)
|
||||||
|
data = json.loads(to_str(response.body))
|
||||||
|
self.assertIsNone(data['id'])
|
||||||
|
self.assertIsNone(data['encoding'])
|
||||||
|
self.assertEqual('Connection to 127.0.0.1 is not allowed.', data['status']) # noqa
|
||||||
|
|
|
@ -70,6 +70,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
self.host_keys_settings = host_keys_settings
|
self.host_keys_settings = host_keys_settings
|
||||||
|
self.ssh_client = self.get_ssh_client()
|
||||||
self.filename = None
|
self.filename = None
|
||||||
self.result = dict(id=None, status=None, encoding=None)
|
self.result = dict(id=None, status=None, encoding=None)
|
||||||
|
|
||||||
|
@ -87,6 +88,14 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
self.set_status(200)
|
self.set_status(200)
|
||||||
self.finish(self.result)
|
self.finish(self.result)
|
||||||
|
|
||||||
|
def get_ssh_client(self):
|
||||||
|
ssh = paramiko.SSHClient()
|
||||||
|
ssh._system_host_keys = self.host_keys_settings['system_host_keys']
|
||||||
|
ssh._host_keys = self.host_keys_settings['host_keys']
|
||||||
|
ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
|
||||||
|
ssh.set_missing_host_key_policy(self.policy)
|
||||||
|
return ssh
|
||||||
|
|
||||||
def get_privatekey(self):
|
def get_privatekey(self):
|
||||||
name = 'privatekey'
|
name = 'privatekey'
|
||||||
lst = self.request.files.get(name) # multipart form
|
lst = self.request.files.get(name) # multipart form
|
||||||
|
@ -143,6 +152,14 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
raise InvalidValueError('Invalid hostname: {}'.format(value))
|
raise InvalidValueError('Invalid hostname: {}'.format(value))
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
def lookup_hostname(self, hostname):
|
||||||
|
if isinstance(self.policy, paramiko.RejectPolicy):
|
||||||
|
if self.ssh_client._system_host_keys.lookup(hostname) is None:
|
||||||
|
if self.ssh_client._host_keys.lookup(hostname) is None:
|
||||||
|
raise ValueError(
|
||||||
|
'Connection to {} is not allowed.'.format(hostname)
|
||||||
|
)
|
||||||
|
|
||||||
def get_port(self):
|
def get_port(self):
|
||||||
value = self.get_value('port')
|
value = self.get_value('port')
|
||||||
try:
|
try:
|
||||||
|
@ -157,6 +174,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
|
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
hostname = self.get_hostname()
|
hostname = self.get_hostname()
|
||||||
|
self.lookup_hostname(hostname)
|
||||||
port = self.get_port()
|
port = self.get_port()
|
||||||
username = self.get_value('username')
|
username = self.get_value('username')
|
||||||
password = self.get_argument('password', u'')
|
password = self.get_argument('password', u'')
|
||||||
|
@ -182,11 +200,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
return result if result else 'utf-8'
|
return result if result else 'utf-8'
|
||||||
|
|
||||||
def ssh_connect(self):
|
def ssh_connect(self):
|
||||||
ssh = paramiko.SSHClient()
|
ssh = self.ssh_client
|
||||||
ssh._system_host_keys = self.host_keys_settings['system_host_keys']
|
|
||||||
ssh._host_keys = self.host_keys_settings['host_keys']
|
|
||||||
ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
|
|
||||||
ssh.set_missing_host_key_policy(self.policy)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
args = self.get_args()
|
args = self.get_args()
|
||||||
|
|
Loading…
Reference in New Issue