Use ThreadPoolExecutor instead of Thread

pull/58/head
Sheng 2019-02-21 17:45:38 +08:00
parent 9fbd5d325f
commit d22b0cdfd8
1 changed files with 5 additions and 21 deletions

View File

@ -3,12 +3,12 @@ import json
import logging import logging
import socket import socket
import struct import struct
import threading
import traceback import traceback
import weakref import weakref
import paramiko import paramiko
import tornado.web import tornado.web
from concurrent.futures import ThreadPoolExecutor
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado.options import options from tornado.options import options
from webssh.utils import ( from webssh.utils import (
@ -17,11 +17,6 @@ from webssh.utils import (
) )
from webssh.worker import Worker, recycle_worker, clients from webssh.worker import Worker, recycle_worker, clients
try:
from concurrent.futures import Future
except ImportError:
from tornado.concurrent import Future
try: try:
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
except ImportError: except ImportError:
@ -173,6 +168,8 @@ class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
class IndexHandler(MixinHandler, tornado.web.RequestHandler): class IndexHandler(MixinHandler, tornado.web.RequestHandler):
executor = ThreadPoolExecutor()
def initialize(self, loop, policy, host_keys_settings): def initialize(self, loop, policy, host_keys_settings):
super(IndexHandler, self).initialize(loop) super(IndexHandler, self).initialize(loop)
self.policy = policy self.policy = policy
@ -331,15 +328,6 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
worker.encoding = self.get_default_encoding(ssh) worker.encoding = self.get_default_encoding(ssh)
return worker return worker
def ssh_connect_wrapped(self, future, args):
try:
worker = self.ssh_connect(args)
except Exception as exc:
logging.error(traceback.format_exc())
future.set_exception(exc)
else:
future.set_result(worker)
def check_origin(self): def check_origin(self):
event_origin = self.get_argument('_origin', u'') event_origin = self.get_argument('_origin', u'')
header_origin = self.request.headers.get('Origin') header_origin = self.request.headers.get('Origin')
@ -377,16 +365,12 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
except InvalidValueError as exc: except InvalidValueError as exc:
raise tornado.web.HTTPError(400, str(exc)) raise tornado.web.HTTPError(400, str(exc))
future = Future() future = self.executor.submit(self.ssh_connect, args)
t = threading.Thread(
target=self.ssh_connect_wrapped, args=(future, args)
)
t.daemon = True
t.start()
try: try:
worker = yield future worker = yield future
except (ValueError, paramiko.SSHException) as exc: except (ValueError, paramiko.SSHException) as exc:
logging.error(traceback.format_exc())
self.result.update(status=str(exc)) self.result.update(status=str(exc))
else: else:
workers = clients.setdefault(worker.src_addr[0], {}) workers = clients.setdefault(worker.src_addr[0], {})