mirror of https://github.com/huashengdun/webssh
Refactored handler.py
parent
27c587745c
commit
0775c0c3ae
|
@ -40,12 +40,12 @@ class TestAppBase(AsyncHTTPTestCase):
|
|||
self.assertEqual(response.code, 400)
|
||||
self.assertIn(b'Bad Request', response.body)
|
||||
|
||||
def assert_status_in(self, data, status):
|
||||
def assert_status_in(self, status, data):
|
||||
self.assertIsNone(data['encoding'])
|
||||
self.assertIsNone(data['id'])
|
||||
self.assertIn(status, data['status'])
|
||||
|
||||
def assert_status_equal(self, data, status):
|
||||
def assert_status_equal(self, status, data):
|
||||
self.assertIsNone(data['encoding'])
|
||||
self.assertIsNone(data['id'])
|
||||
self.assertEqual(status, data['status'])
|
||||
|
@ -172,7 +172,7 @@ class TestAppBasic(TestAppBase):
|
|||
|
||||
def test_app_with_wrong_credentials(self):
|
||||
response = self.sync_post('/', self.body + 's')
|
||||
self.assert_status_in(json.loads(to_str(response.body)), 'Authentication failed.') # noqa
|
||||
self.assert_status_in('Authentication failed.', json.loads(to_str(response.body))) # noqa
|
||||
|
||||
def test_app_with_correct_credentials(self):
|
||||
response = self.sync_post('/', self.body)
|
||||
|
@ -442,10 +442,10 @@ class TestAppBasic(TestAppBase):
|
|||
self.body_dict.update(username='keyonly', password='foo')
|
||||
response = yield self.async_post('/', self.body_dict)
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assert_status_in(json.loads(to_str(response.body)), 'Bad authentication type') # noqa
|
||||
self.assert_status_in('Bad authentication type', json.loads(to_str(response.body))) # noqa
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_user_pass2fa_with_correct_password_and_passcode(self):
|
||||
def test_app_with_user_pass2fa_with_correct_passwords(self):
|
||||
self.body_dict.update(username='pass2fa', password='password',
|
||||
totp='passcode')
|
||||
response = yield self.async_post('/', self.body_dict)
|
||||
|
@ -454,25 +454,7 @@ class TestAppBasic(TestAppBase):
|
|||
self.assert_status_none(data)
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_user_pass2fa_with_wrong_password(self):
|
||||
self.body_dict.update(username='pass2fa', password='wrongpassword',
|
||||
totp='passcode')
|
||||
response = yield self.async_post('/', self.body_dict)
|
||||
self.assertEqual(response.code, 200)
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIn('Authentication failed', data['status'])
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_user_pass2fa_with_wrong_passcode(self):
|
||||
self.body_dict.update(username='pass2fa', password='password',
|
||||
totp='wrongpasscode')
|
||||
response = yield self.async_post('/', self.body_dict)
|
||||
self.assertEqual(response.code, 200)
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIn('Authentication failed', data['status'])
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_user_pass2fa_with_wrong_pkey_correct_passwords(self): # noqa
|
||||
def test_app_with_user_pass2fa_with_wrong_pkey_correct_passwords(self):
|
||||
url = self.get_url('/')
|
||||
privatekey = read_file(make_tests_data_path('user_rsa_key'))
|
||||
self.body_dict.update(username='pass2fa', password='password',
|
||||
|
@ -482,7 +464,7 @@ class TestAppBasic(TestAppBase):
|
|||
self.assert_status_none(data)
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_user_pkey2fa_with_correct_password_and_passcode(self):
|
||||
def test_app_with_user_pkey2fa_with_correct_passwords(self):
|
||||
url = self.get_url('/')
|
||||
privatekey = read_file(make_tests_data_path('user_rsa_key'))
|
||||
self.body_dict.update(username='pkey2fa', password='password',
|
||||
|
@ -499,7 +481,7 @@ class TestAppBasic(TestAppBase):
|
|||
privatekey=privatekey, totp='passcode')
|
||||
response = yield self.async_post(url, self.body_dict)
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIn('Authentication failed', data['status'])
|
||||
self.assert_status_in('Authentication failed', data)
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_user_pkey2fa_with_wrong_passcode(self):
|
||||
|
@ -509,7 +491,17 @@ class TestAppBasic(TestAppBase):
|
|||
privatekey=privatekey, totp='wrongpasscode')
|
||||
response = yield self.async_post(url, self.body_dict)
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIn('Authentication failed', data['status'])
|
||||
self.assert_status_in('Authentication failed', data)
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_user_pkey2fa_with_empty_passcode(self):
|
||||
url = self.get_url('/')
|
||||
privatekey = read_file(make_tests_data_path('user_rsa_key'))
|
||||
self.body_dict.update(username='pkey2fa', password='password',
|
||||
privatekey=privatekey, totp='')
|
||||
response = yield self.async_post(url, self.body_dict)
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assert_status_in('Need a verification code', data)
|
||||
|
||||
|
||||
class OtherTestBase(TestAppBase):
|
||||
|
@ -747,13 +739,13 @@ class TestAppWithCrossOriginOperation(OtherTestBase):
|
|||
def test_app_with_wrong_event_origin(self):
|
||||
body = dict(self.body, _origin='localhost')
|
||||
response = yield self.async_post('/', body)
|
||||
self.assert_status_equal(json.loads(to_str(response.body)), 'Cross origin operation is not allowed.') # noqa
|
||||
self.assert_status_equal('Cross origin operation is not allowed.', json.loads(to_str(response.body))) # noqa
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_wrong_header_origin(self):
|
||||
headers = dict(Origin='localhost')
|
||||
response = yield self.async_post('/', self.body, headers=headers)
|
||||
self.assert_status_equal(json.loads(to_str(response.body)), 'Cross origin operation is not allowed.') # noqa
|
||||
self.assert_status_equal('Cross origin operation is not allowed.', json.loads(to_str(response.body)), ) # noqa
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_correct_event_origin(self):
|
||||
|
|
|
@ -36,80 +36,68 @@ swallow_http_errors = True
|
|||
redirecting = None
|
||||
|
||||
|
||||
def make_handler(password, totp):
|
||||
class InvalidValueError(Exception):
|
||||
pass
|
||||
|
||||
def handler(title, instructions, prompt_list):
|
||||
|
||||
class SSHClient(paramiko.SSHClient):
|
||||
|
||||
def handler(self, title, instructions, prompt_list):
|
||||
answers = []
|
||||
for prompt_, _ in prompt_list:
|
||||
prompt = prompt_.strip().lower()
|
||||
if prompt.startswith('password'):
|
||||
answers.append(password)
|
||||
answers.append(self.password)
|
||||
elif prompt.startswith('verification'):
|
||||
answers.append(totp)
|
||||
answers.append(self.totp)
|
||||
else:
|
||||
raise ValueError('Unknown prompt: {}'.format(prompt_))
|
||||
return answers
|
||||
|
||||
return handler
|
||||
def auth_interactive(self, username, handler):
|
||||
if not self.totp:
|
||||
raise ValueError('Need a verification code for 2fa.')
|
||||
self._transport.auth_interactive(username, handler)
|
||||
|
||||
def _auth(self, username, password, pkey, *args):
|
||||
self.password = password
|
||||
saved_exception = None
|
||||
two_factor = False
|
||||
allowed_types = set()
|
||||
two_factor_types = {'keyboard-interactive', 'password'}
|
||||
|
||||
def auth_interactive(transport, username, handler):
|
||||
if not handler:
|
||||
raise ValueError('Need a verification code for 2fa.')
|
||||
transport.auth_interactive(username, handler)
|
||||
if pkey is not None:
|
||||
logging.info('Trying publickey authentication')
|
||||
try:
|
||||
allowed_types = set(
|
||||
self._transport.auth_publickey(username, pkey)
|
||||
)
|
||||
two_factor = allowed_types & two_factor_types
|
||||
if not two_factor:
|
||||
return
|
||||
except paramiko.SSHException as e:
|
||||
saved_exception = e
|
||||
|
||||
if two_factor:
|
||||
logging.info('Trying publickey 2fa')
|
||||
return self.auth_interactive(username, self.handler)
|
||||
|
||||
def auth(self, username, password, pkey, *args):
|
||||
handler = None
|
||||
saved_exception = None
|
||||
two_factor = False
|
||||
allowed_types = set()
|
||||
two_factor_types = {"keyboard-interactive", "password"}
|
||||
|
||||
if self._totp:
|
||||
handler = make_handler(password, self._totp)
|
||||
|
||||
if pkey is not None:
|
||||
logging.info('Trying public key authentication')
|
||||
try:
|
||||
allowed_types = set(
|
||||
self._transport.auth_publickey(username, pkey)
|
||||
)
|
||||
two_factor = allowed_types & two_factor_types
|
||||
if not two_factor:
|
||||
if password is not None:
|
||||
logging.info('Trying password authentication')
|
||||
try:
|
||||
self._transport.auth_password(username, password)
|
||||
return
|
||||
except paramiko.SSHException as e:
|
||||
saved_exception = e
|
||||
except paramiko.SSHException as e:
|
||||
saved_exception = e
|
||||
allowed_types = set(getattr(e, 'allowed_types', []))
|
||||
two_factor = allowed_types & two_factor_types
|
||||
|
||||
if two_factor:
|
||||
logging.info('Trying publickey 2fa')
|
||||
return auth_interactive(self._transport, username, handler)
|
||||
if two_factor:
|
||||
logging.info('Trying password 2fa')
|
||||
return self.auth_interactive(username, self.handler)
|
||||
|
||||
if password is not None:
|
||||
logging.info('Trying password authentication')
|
||||
try:
|
||||
self._transport.auth_password(username, password)
|
||||
return
|
||||
except paramiko.SSHException as e:
|
||||
saved_exception = e
|
||||
allowed_types = set(getattr(e, 'allowed_types', []))
|
||||
two_factor = allowed_types & two_factor_types
|
||||
|
||||
if two_factor:
|
||||
logging.info('Trying password 2fa')
|
||||
return auth_interactive(self._transport, username, handler)
|
||||
|
||||
# if we got an auth-failed exception earlier, re-raise it
|
||||
if saved_exception is not None:
|
||||
assert saved_exception is not None
|
||||
raise saved_exception
|
||||
raise paramiko.SSHException("No authentication methods available")
|
||||
|
||||
|
||||
paramiko.client.SSHClient._auth = auth
|
||||
|
||||
|
||||
class InvalidValueError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PrivateKey(object):
|
||||
|
@ -327,7 +315,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
super(IndexHandler, self).write_error(status_code, **kwargs)
|
||||
|
||||
def get_ssh_client(self):
|
||||
ssh = paramiko.SSHClient()
|
||||
ssh = SSHClient()
|
||||
ssh._system_host_keys = self.host_keys_settings['system_host_keys']
|
||||
ssh._host_keys = self.host_keys_settings['host_keys']
|
||||
ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
|
||||
|
@ -392,7 +380,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
else:
|
||||
pkey = None
|
||||
|
||||
self.ssh_client._totp = totp
|
||||
self.ssh_client.totp = totp
|
||||
args = (hostname, port, username, password, pkey)
|
||||
logging.debug(args)
|
||||
|
||||
|
|
Loading…
Reference in New Issue