webssh/main.py

401 lines
12 KiB
Python
Raw Normal View History

2017-11-09 02:58:28 +00:00
import io
2017-11-08 14:33:05 +00:00
import logging
import os.path
import socket
import threading
2017-11-09 03:23:19 +00:00
import traceback
2017-11-08 14:33:05 +00:00
import uuid
2017-11-09 02:58:28 +00:00
import weakref
2017-11-08 14:33:05 +00:00
import paramiko
import tornado.gen
2017-11-08 14:33:05 +00:00
import tornado.web
import tornado.websocket
from tornado.ioloop import IOLoop
2017-11-11 14:43:33 +00:00
from tornado.iostream import _ERRNO_CONNRESET
2017-11-08 14:33:05 +00:00
from tornado.options import define, options, parse_command_line
2017-11-11 14:43:33 +00:00
from tornado.util import errno_from_exception
2017-11-08 14:33:05 +00:00
2018-04-05 05:50:04 +00:00
try:
from concurrent.futures import Future
2018-04-08 13:52:54 +00:00
except ImportError:
2018-04-05 05:50:04 +00:00
from tornado.concurrent import Future
2017-11-08 14:33:05 +00:00
# Command-line options; their values are read through ``options`` after
# ``parse_command_line()`` has run in ``main()``.
define('address', default='127.0.0.1', help='listen address')
define('port', type=int, default=8888, help='listen port')
define('debug', type=bool, default=False, help='debug mode')
define(
    'policy',
    default='warning',
    help='missing host key policy, reject|autoadd|warning'
)

# Maximum number of bytes read from an SSH channel per read event.
BUF_SIZE = 1024
# Seconds a freshly created worker may wait for its websocket before
# ``recycle`` discards it.
DELAY = 3
# Worker id -> Worker, for workers not yet claimed by a websocket.
workers = {}
class AutoAddPolicy(paramiko.client.MissingHostKeyPolicy):
    """
    thread-safe AutoAddPolicy

    Unknown host keys are accepted, added to the client's in-memory host
    keys and appended to the client's host keys file.  A single class-level
    lock serializes these updates across all instances.
    """

    lock = threading.Lock()

    def missing_host_key(self, client, hostname, key):
        """Record ``key`` for ``hostname`` in memory and on disk.

        :param client: the SSH client; its ``_host_keys`` and
            ``_host_keys_filename`` attributes are used directly.
        :param hostname: name the key is stored under.
        :param key: the host key object presented by the server.
        """
        with self.lock:
            key_name = key.get_name()
            message = 'Adding {} host key for {}'.format(key_name, hostname)
            logging.info(message)
            client._host_keys.add(hostname, key_name, key)
            # Append one "hostname keytype base64" line to the file.
            with open(client._host_keys_filename, 'a') as known_hosts:
                entry = '{} {} {}\n'.format(
                    hostname, key_name, key.get_base64()
                )
                known_hosts.write(entry)


# Replace paramiko's AutoAddPolicy so that a lookup through
# ``vars(paramiko.client)`` (see get_policy_class) yields this class.
paramiko.client.AutoAddPolicy = AutoAddPolicy
2017-11-08 14:33:05 +00:00
class Worker(object):
    """Relay between one SSH channel and one websocket handler.

    An instance is registered on the IOLoop as the event callback for the
    channel's file descriptor (see ``__call__``).  Data read from the channel
    is forwarded to the websocket handler; data queued in ``data_to_dst`` is
    sent to the channel.
    """

    # Set to True by the first ``close()`` call so that later calls are
    # no-ops.  Kept as a class-level default so every instance has it.
    closed = False

    def __init__(self, loop, ssh, chan, dst_addr):
        """
        :param loop: IOLoop the channel's fd is (un)registered on.
        :param ssh: SSH client owning ``chan``; closed together with it.
        :param chan: SSH channel; must provide ``fileno``, ``recv``,
            ``send`` and ``close``.
        :param dst_addr: ``(host, port)`` of the SSH server, used in logs.
        """
        self.loop = loop
        self.ssh = ssh
        self.chan = chan
        self.dst_addr = dst_addr
        self.fd = chan.fileno()
        self.id = str(id(self))
        self.data_to_dst = []
        self.handler = None
        self.mode = IOLoop.READ

    def __call__(self, fd, events):
        """IOLoop callback: dispatch read, write and error events."""
        if events & IOLoop.READ:
            self.on_read()
        # on_read() may have closed the worker; do not touch the closed
        # channel or the unregistered fd afterwards.
        if events & IOLoop.WRITE and not self.closed:
            self.on_write()
        if events & IOLoop.ERROR:
            self.close()

    def set_handler(self, handler):
        """Attach the websocket handler; only the first one is kept."""
        if not self.handler:
            self.handler = handler

    def update_handler(self, mode):
        """Switch the IOLoop event mask for the fd if it differs."""
        if self.mode != mode:
            self.loop.update_handler(self.fd, mode)
            self.mode = mode

    def on_read(self):
        """Read up to BUF_SIZE bytes from the channel and forward them.

        An empty read, a connection-reset error or a closed websocket
        closes the worker.  Other read errors are only logged.
        """
        logging.debug('worker {} on read'.format(self.id))
        try:
            data = self.chan.recv(BUF_SIZE)
        except (OSError, IOError) as e:
            logging.error(e)
            if errno_from_exception(e) in _ERRNO_CONNRESET:
                self.close()
        else:
            logging.debug('{!r} from {}:{}'.format(data, *self.dst_addr))
            if not data:
                self.close()
                return

            logging.debug(
                '{!r} to {}:{}'.format(data, *self.handler.src_addr)
            )
            try:
                self.handler.write_message(data)
            except tornado.websocket.WebSocketClosedError:
                self.close()

    def on_write(self):
        """Send the queued data to the channel.

        After a partial send the remainder is re-queued and the fd is
        watched for writability; once everything is sent the fd goes back
        to read mode.  On a send error the queue is left untouched: a
        connection reset closes the worker, any other error switches to
        write mode so the send is retried later.
        """
        logging.debug('worker {} on write'.format(self.id))
        if not self.data_to_dst:
            return

        data = ''.join(self.data_to_dst)
        logging.debug('{!r} to {}:{}'.format(data, *self.dst_addr))

        try:
            sent = self.chan.send(data)
        except (OSError, IOError) as e:
            logging.error(e)
            if errno_from_exception(e) in _ERRNO_CONNRESET:
                self.close()
            else:
                self.update_handler(IOLoop.WRITE)
        else:
            self.data_to_dst = []
            data = data[sent:]
            if data:
                self.data_to_dst.append(data)
                self.update_handler(IOLoop.WRITE)
            else:
                self.update_handler(IOLoop.READ)

    def close(self):
        """Tear down the relay; safe to call more than once.

        Closing the websocket handler makes it call back into ``close()``
        (through its ``on_close``), so the ``closed`` flag guards against
        unregistering the fd and logging the disconnect twice.
        """
        if self.closed:
            return
        self.closed = True

        logging.debug('Closing worker {}'.format(self.id))
        if self.handler:
            # The fd is only registered once a websocket handler is set.
            self.loop.remove_handler(self.fd)
            self.handler.close()
        self.chan.close()
        self.ssh.close()
        logging.info('Connection to {}:{} lost'.format(*self.dst_addr))
2017-11-08 14:33:05 +00:00
2018-03-06 01:34:55 +00:00
class MixinHandler(object):
    """Behaviour shared by the HTTP and websocket handlers."""

    def __init__(self, *args, **kwargs):
        # The first positional argument is presumably the tornado
        # application; main() stores the IOLoop on it as ``_loop``.
        self.loop = args[0]._loop
        super(MixinHandler, self).__init__(*args, **kwargs)

    def get_client_addr(self):
        """Return the client address reported by a reverse proxy.

        Reads the ``X-Real-Ip`` and ``X-Real-Port`` request headers.

        :returns: ``(ip, port)`` with ``port`` as int when both headers are
            present, otherwise None.  When only one of them is present a
            warning about the proxy configuration is logged.
        :raises ValueError: if ``X-Real-Port`` is not an integer.
        """
        ip = self.request.headers.get('X-Real-Ip')
        port = self.request.headers.get('X-Real-Port')
        addr = None

        if ip and port:
            addr = (ip, int(port))
        elif ip or port:
            # logging.warn is a deprecated alias and is gone in
            # Python 3.13; logging.warning works everywhere.
            logging.warning('Wrong nginx configuration.')

        return addr
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
    """Serves the index page and opens SSH sessions from the posted form."""

    def get_privatekey(self):
        """Return the uploaded ``privatekey`` file as text, or None.

        None is returned when no file was uploaded under that name.
        """
        uploads = self.request.files.get('privatekey')
        if not uploads:
            return None
        return uploads[0]['body'].decode('utf-8')

    def get_specific_pkey(self, pkeycls, privatekey, password):
        """Try to load ``privatekey`` as a key of type ``pkeycls``.

        :returns: the key object, or None if paramiko rejects the data
            with an SSHException (e.g. the key is of another type).
        :raises ValueError: if the key is encrypted and needs a password.
        """
        logging.info('Trying {}'.format(pkeycls.__name__))
        try:
            pkey = pkeycls.from_private_key(io.StringIO(privatekey),
                                            password=password)
        except paramiko.PasswordRequiredException:
            raise ValueError('Need password to decrypt the private key.')
        except paramiko.SSHException:
            # Not a key of this type (or undecryptable); let the caller
            # try the next key class.
            return None
        return pkey

    def get_pkey(self, privatekey, password):
        """Load ``privatekey`` trying RSA, DSS, ECDSA and Ed25519 in turn.

        :raises ValueError: if no key class accepts the data.
        """
        password = password.encode('utf-8') if password else None

        pkey = self.get_specific_pkey(paramiko.RSAKey, privatekey, password)\
            or self.get_specific_pkey(paramiko.DSSKey, privatekey, password)\
            or self.get_specific_pkey(paramiko.ECDSAKey, privatekey, password)\
            or self.get_specific_pkey(paramiko.Ed25519Key, privatekey,
                                      password)
        if not pkey:
            raise ValueError('Not a valid private key file or '
                             'wrong password for decrypting the private key.')
        return pkey

    def get_port(self):
        """Return the ``port`` argument as an int in the range 1..65535.

        :raises ValueError: if the value is empty, not a number or out of
            range.
        """
        value = self.get_value('port')
        try:
            port = int(value)
        except ValueError:
            port = 0

        if 0 < port < 65536:
            return port

        raise ValueError('Invalid port {}'.format(value))

    def get_value(self, name):
        """Return the request argument ``name``; it must not be empty.

        :raises ValueError: if the argument is empty.
        """
        value = self.get_argument(name)
        if not value:
            raise ValueError('Empty {}'.format(name))
        return value

    def get_args(self):
        """Collect ``(hostname, port, username, password, pkey)``.

        ``pkey`` is None when no private key file was uploaded.
        """
        hostname = self.get_value('hostname')
        port = self.get_port()
        username = self.get_value('username')
        password = self.get_argument('password')
        privatekey = self.get_privatekey()
        pkey = self.get_pkey(privatekey, password) if privatekey else None
        args = (hostname, port, username, password, pkey)
        # Keep credentials (password, private key) out of the log.
        logging.debug((hostname, port, username))
        return args

    def get_client_addr(self):
        """Return the proxy-reported client address, else the socket peer."""
        return super(IndexHandler, self).get_client_addr() or self.request.\
            connection.stream.socket.getpeername()

    def ssh_connect(self):
        """Open the SSH connection and shell and wrap them in a Worker.

        Runs in a helper thread (see ``post``) because connecting blocks.

        :returns: a Worker whose ``src_addr`` is the client address.
        :raises ValueError: for invalid form input, connection failure,
            authentication failure or a bad host key.
        """
        ssh = paramiko.SSHClient()
        ssh._system_host_keys = self.settings['system_host_keys']
        ssh._host_keys = self.settings['host_keys']
        ssh._host_keys_filename = self.settings['host_keys_filename']
        ssh.set_missing_host_key_policy(self.settings['policy'])

        args = self.get_args()
        dst_addr = (args[0], args[1])
        logging.info('Connecting to {}:{}'.format(*dst_addr))

        try:
            try:
                ssh.connect(*args, timeout=6)
            except socket.error:
                raise ValueError(
                    'Unable to connect to {}:{}'.format(*dst_addr)
                )
            except paramiko.AuthenticationException:
                # Parent of BadAuthenticationType, so rejected credentials
                # are covered as well as unsupported auth methods.
                raise ValueError('Authentication failed.')
            except paramiko.BadHostKeyException:
                raise ValueError('Bad host key.')

            chan = ssh.invoke_shell(term='xterm')
            chan.setblocking(0)
            worker = Worker(self.loop, ssh, chan, dst_addr)
            worker.src_addr = self.get_client_addr()
        except Exception:
            # Do not leak the client (and its transport/socket) when the
            # session could not be fully set up.
            ssh.close()
            raise
        return worker

    def ssh_connect_wrapped(self, future):
        """Thread target: run ``ssh_connect`` and resolve ``future``."""
        try:
            worker = self.ssh_connect()
        except Exception as exc:
            logging.error(traceback.format_exc())
            future.set_exception(exc)
        else:
            future.set_result(worker)

    def get(self):
        """Render the index page."""
        self.render('index.html')

    @tornado.gen.coroutine
    def post(self):
        """Connect to the SSH server and reply with ``{id, status}``.

        On success ``id`` is the worker id and ``status`` is None; on
        failure ``id`` is None and ``status`` is the error message.  The
        worker is parked in ``workers`` and recycled after DELAY seconds
        unless a websocket claims it first.
        """
        worker_id = None
        status = None

        future = Future()
        t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))
        t.daemon = True  # setDaemon() is deprecated
        t.start()

        try:
            worker = yield future
        except Exception as exc:
            status = str(exc)
        else:
            worker_id = worker.id
            workers[worker_id] = worker
            self.loop.call_later(DELAY, recycle, worker)

        self.write(dict(id=worker_id, status=status))
2018-03-06 01:34:55 +00:00
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
    """Websocket endpoint that attaches a browser to a waiting Worker."""

    def __init__(self, *args, **kwargs):
        # Weak reference to the attached worker; stays None until open()
        # has matched one.
        self.worker_ref = None
        super(WsockHandler, self).__init__(*args, **kwargs)

    def get_client_addr(self):
        """Return the proxy-reported client address, else the socket peer."""
        return super(WsockHandler, self).get_client_addr() or self.stream.\
            socket.getpeername()

    def open(self):
        """Attach to the worker named by the ``id`` argument.

        The worker must still be waiting in ``workers`` and must have been
        created from the same client IP; otherwise the socket is closed.
        """
        self.src_addr = self.get_client_addr()
        logging.info('Connected from {}:{}'.format(*self.src_addr))
        worker = workers.get(self.get_argument('id'))
        if worker and worker.src_addr[0] == self.src_addr[0]:
            workers.pop(worker.id)
            self.set_nodelay(True)
            worker.set_handler(self)
            self.worker_ref = weakref.ref(worker)
            self.loop.add_handler(worker.fd, worker, IOLoop.READ)
        else:
            self.close()

    def on_message(self, message):
        """Queue ``message`` for the SSH server and try to send it.

        If no worker is attached (open() rejected the connection) or the
        worker has already been garbage collected, the websocket is closed
        instead of failing on the missing worker.
        """
        logging.debug('{!r} from {}:{}'.format(message, *self.src_addr))
        worker = self.worker_ref() if self.worker_ref else None
        if worker is None:
            self.close()
            return
        worker.data_to_dst.append(message)
        worker.on_write()

    def on_close(self):
        """Close the attached worker, if it is still alive."""
        logging.info('Disconnected from {}:{}'.format(*self.src_addr))
        worker = self.worker_ref() if self.worker_ref else None
        if worker:
            worker.close()
2018-03-14 14:09:17 +00:00
def recycle(worker):
    """Discard ``worker`` unless a websocket handler has claimed it.

    An unclaimed worker is removed from ``workers`` (if still present) and
    closed; a claimed worker is left alone.
    """
    if not worker.handler:
        logging.debug('Recycling worker {}'.format(worker.id))
        workers.pop(worker.id, None)
        worker.close()
def get_host_keys(path):
    """Return the host keys stored in ``path``.

    An empty ``HostKeys`` object is returned when ``path`` is not an
    existing regular file.
    """
    host_keys_cls = paramiko.hostkeys.HostKeys
    # isfile() is already False for a missing path.
    if os.path.isfile(path):
        return host_keys_cls(filename=path)
    return host_keys_cls()
2018-03-14 14:09:17 +00:00
def get_policy_class(policy):
    """Map a policy name to a missing-host-key policy class.

    The name is matched case-insensitively against the classes found in
    ``paramiko.client`` that derive from ``MissingHostKeyPolicy``; the
    ``policy`` suffix may be omitted (e.g. ``'autoadd'``).

    :raises ValueError: if no such class exists.
    """
    base_cls = paramiko.client.MissingHostKeyPolicy
    wanted = policy.lower()
    if not wanted.endswith('policy'):
        wanted = wanted + 'policy'

    # Lower-cased class name -> class, for every policy class in the module.
    available = {}
    for attr_name, attr in vars(paramiko.client).items():
        if type(attr) is type and issubclass(attr, base_cls):
            available[attr_name.lower()] = attr

    try:
        return available[wanted]
    except KeyError:
        raise ValueError('Unknown policy {!r}'.format(policy))
2018-03-20 23:38:48 +00:00
def get_application_settings():
    """Build the keyword settings for the tornado application.

    Loads the ``known_hosts`` file next to this module and the user's
    ``~/.ssh/known_hosts``, and instantiates the host key policy selected
    by the ``policy`` option.

    :raises ValueError: if the reject policy is selected while both host
        key stores are empty, or if the policy name is unknown.
    """
    here = os.path.dirname(__file__)
    known_hosts = os.path.join(here, 'known_hosts')
    app_keys = get_host_keys(known_hosts)
    user_keys = get_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
    policy_cls = get_policy_class(options.policy)
    logging.info(policy_cls.__name__)

    if policy_cls is paramiko.client.AutoAddPolicy:
        app_keys.save(known_hosts)  # for permission test
    elif policy_cls is paramiko.client.RejectPolicy and \
            not (app_keys or user_keys):
        raise ValueError('Empty known_hosts with reject policy?')

    return {
        'template_path': os.path.join(here, 'templates'),
        'static_path': os.path.join(here, 'static'),
        'cookie_secret': uuid.uuid4().hex,
        'xsrf_cookies': True,
        'host_keys': app_keys,
        'host_keys_filename': known_hosts,
        'system_host_keys': user_keys,
        'policy': policy_cls(),
        'debug': options.debug,
    }
def main():
    """Parse the command line, build the application and run the IOLoop."""
    parse_command_line()
    app_settings = get_application_settings()
    routes = [(r'/', IndexHandler), (r'/ws', WsockHandler)]

    loop = IOLoop.current()
    app = tornado.web.Application(routes, **app_settings)
    # Handlers pick the loop up from the application (MixinHandler.__init__).
    app._loop = loop

    address, port = options.address, options.port
    app.listen(port, address)
    logging.info('Listening on {}:{}'.format(address, port))
    loop.start()


if __name__ == '__main__':
    main()