Added a MixinHandler

pull/12/head
Sheng 2018-03-06 09:34:55 +08:00
parent d07eb5b910
commit 311fcfedc9
1 changed files with 36 additions and 21 deletions

57
main.py
View File

@ -71,12 +71,12 @@ class Worker(object):
if errno_from_exception(e) in _ERRNO_CONNRESET: if errno_from_exception(e) in _ERRNO_CONNRESET:
self.close() self.close()
else: else:
logging.debug('"{}" from {}'.format(data, self.dst_addr)) logging.debug('"{}" from {}:{}'.format(data, *self.dst_addr))
if not data: if not data:
self.close() self.close()
return return
logging.debug('"{}" to {}'.format(data, self.handler.src_addr)) logging.debug('"{}" to {}:{}'.format(data, *self.handler.src_addr))
try: try:
self.handler.write_message(data) self.handler.write_message(data)
except tornado.websocket.WebSocketClosedError: except tornado.websocket.WebSocketClosedError:
@ -88,7 +88,7 @@ class Worker(object):
return return
data = ''.join(self.data_to_dst) data = ''.join(self.data_to_dst)
logging.debug('"{}" to {}'.format(data, self.dst_addr)) logging.debug('"{}" to {}:{}'.format(data, *self.dst_addr))
try: try:
sent = self.chan.send(data) sent = self.chan.send(data)
@ -114,10 +114,28 @@ class Worker(object):
self.handler.close() self.handler.close()
self.chan.close() self.chan.close()
self.ssh.close() self.ssh.close()
logging.info('Connection to {} lost'.format(self.dst_addr)) logging.info('Connection to {}:{} lost'.format(*self.dst_addr))
class IndexHandler(tornado.web.RequestHandler): class MixinHandler(object):
def get_addr(self):
ip = self.request.headers.get('X-Real-Ip')
port = self.request.headers.get('X-Real-Port')
if ip and port:
addr = (ip, int(port))
elif not ip and not port:
if not getattr(self, 'stream', None):
self.stream = self.request.connection.stream
addr = self.stream.socket.getpeername()
else:
raise ValueError('Wrong nginx configuration.')
return addr
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def get_privatekey(self): def get_privatekey(self):
try: try:
data = self.request.files.get('privatekey')[0]['body'] data = self.request.files.get('privatekey')[0]['body']
@ -184,12 +202,12 @@ class IndexHandler(tornado.web.RequestHandler):
ssh.load_system_host_keys() ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
args = self.get_args() args = self.get_args()
dst_addr = '{}:{}'.format(*args[:2]) dst_addr = (args[0], args[1])
logging.info('Connecting to {}'.format(dst_addr)) logging.info('Connecting to {}:{}'.format(*dst_addr))
try: try:
ssh.connect(*args, timeout=6) ssh.connect(*args, timeout=6)
except socket.error: except socket.error:
raise ValueError('Unable to connect to {}'.format(dst_addr)) raise ValueError('Unable to connect to {}:{}'.format(*dst_addr))
except paramiko.BadAuthenticationType: except paramiko.BadAuthenticationType:
raise ValueError('Authentication failed.') raise ValueError('Authentication failed.')
chan = ssh.invoke_shell(term='xterm') chan = ssh.invoke_shell(term='xterm')
@ -211,47 +229,44 @@ class IndexHandler(tornado.web.RequestHandler):
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
status = str(e) status = str(e)
else: else:
worker.src_addr = self.get_addr()
worker_id = worker.id worker_id = worker.id
workers[worker_id] = worker workers[worker_id] = worker
self.write(dict(id=worker_id, status=status)) self.write(dict(id=worker_id, status=status))
class WsockHandler(tornado.websocket.WebSocketHandler): class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.loop = IOLoop.current() self.loop = IOLoop.current()
self.worker_ref = None self.worker_ref = None
super(self.__class__, self).__init__(*args, **kwargs) super(self.__class__, self).__init__(*args, **kwargs)
def get_addr(self):
ip = self.request.headers.get_list('X-Real-Ip')
port = self.request.headers.get_list('X-Real-Port')
addr = ':'.join(ip + port)
if not addr:
addr = '{}:{}'.format(*self.stream.socket.getpeername())
return addr
def open(self): def open(self):
self.src_addr = self.get_addr() self.src_addr = self.get_addr()
logging.info('Connected from {}'.format(self.src_addr)) logging.info('Connected from {}:{}'.format(*self.src_addr))
worker = workers.pop(self.get_argument('id'), None) worker = workers.pop(self.get_argument('id'), None)
if not worker: if not worker:
self.close(reason='Invalid worker id') self.close(reason='Invalid worker id.')
return return
if self.src_addr[0] != worker.src_addr[0]:
self.close(reason='Invalid client addr.')
return
self.set_nodelay(True) self.set_nodelay(True)
worker.set_handler(self) worker.set_handler(self)
self.worker_ref = weakref.ref(worker) self.worker_ref = weakref.ref(worker)
self.loop.add_handler(worker.fd, worker, IOLoop.READ) self.loop.add_handler(worker.fd, worker, IOLoop.READ)
def on_message(self, message): def on_message(self, message):
logging.debug('"{}" from {}'.format(message, self.src_addr)) logging.debug('"{}" from {}:{}'.format(message, *self.src_addr))
worker = self.worker_ref() worker = self.worker_ref()
worker.data_to_dst.append(message) worker.data_to_dst.append(message)
worker.on_write() worker.on_write()
def on_close(self): def on_close(self):
logging.info('Disconnected from {}'.format(self.src_addr)) logging.info('Disconnected from {}:{}'.format(*self.src_addr))
worker = self.worker_ref() if self.worker_ref else None worker = self.worker_ref() if self.worker_ref else None
if worker: if worker:
worker.close() worker.close()