From ebb30424fa88fc3b9a7c4be5790322ed279e7471 Mon Sep 17 00:00:00 2001 From: ibuler Date: Sun, 25 Sep 2016 19:53:55 +0800 Subject: [PATCH] Use process except thread --- terminal/ssh_config.py | 105 ++++++++++++++ ...onfig_example.py => ssh_config_example.py} | 55 +++---- terminal/ssh_server.py | 134 +++++++++++------- terminal/utils.py | 9 ++ terminal/{web_server.py => web_ssh_server.py} | 0 5 files changed, 224 insertions(+), 79 deletions(-) create mode 100644 terminal/ssh_config.py rename terminal/{config_example.py => ssh_config_example.py} (66%) rename terminal/{web_server.py => web_ssh_server.py} (100%) diff --git a/terminal/ssh_config.py b/terminal/ssh_config.py new file mode 100644 index 000000000..52a0424c1 --- /dev/null +++ b/terminal/ssh_config.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# + +import logging +import os + + +BASE_DIR = os.path.dirname(os.path.abspath(__name__)) + + +class Config: + SSH_HOST = '' + SSH_PORT = 2200 + LOG_LEVEL = 'INFO' + LOG_DIR = os.path.join(BASE_DIR, 'logs') + LOG_FILENAME = 'ssh_server.log' + LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'verbose': { + 'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s' + }, + 'main': { + 'datefmt': '%Y-%m-%d %H:%M:%S', + 'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s', + }, + 'simple': { + 'format': '%(levelname)s %(message)s' + }, + }, + 'handlers': { + 'null': { + 'level': 'DEBUG', + 'class': 'logging.NullHandler', + }, + 'console': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + 'formatter': 'main', + 'stream': 'ext://sys.stdout', + }, + 'file': { + 'level': 'DEBUG', + 'class': 'logging.FileHandler', + 'formatter': 'main', + 'mode': 'a', + 'filename': os.path.join(LOG_DIR, LOG_FILENAME), + }, + }, + 'loggers': { + 'jumpserver': { + 'handlers': ['console', 'file'], + # 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info') + 'level': LOG_LEVEL, + 'propagate': True, + }, + 'jumpserver.web_ssh_server': { + 'handlers': ['console', 'file'], + # 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info') + 'level': LOG_LEVEL, + 'propagate': True, + }, + 'jumpserver.ssh_server': { + 'handlers': ['console', 'file'], + # 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info') + 'level': LOG_LEVEL, + 'propagate': True, + } + } + } + + def __init__(self): + pass + + def __getattr__(self, item): + return None + + +class DevelopmentConfig(Config): + pass + + +class ProductionConfig(Config): + pass + + +class TestingConfig(Config): + pass + + +config = { + 'development': DevelopmentConfig, + 'production': ProductionConfig, + 'testing': TestingConfig, + 'default': DevelopmentConfig, +} + +env = 'default' + + +if __name__ == '__main__': + pass + diff --git a/terminal/config_example.py b/terminal/ssh_config_example.py similarity index 66% rename from terminal/config_example.py rename to terminal/ssh_config_example.py index 9cde69b99..238304fcd 100644 --- a/terminal/config_example.py +++ b/terminal/ssh_config_example.py @@ -7,20 +7,14 @@ import os BASE_DIR = os.path.dirname(os.path.abspath(__name__)) -LOG_LEVEL_CHOICES = { - 'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL -} class Config: - LOG_LEVEL = '' + LOG_LEVEL = 'INFO' LOG_DIR = os.path.join(BASE_DIR, 'logs') LOGGING = { 'version': 1, + 'disable_existing_loggers': False, 'formatters': { 'verbose': { 'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s' @@ -47,35 +41,23 @@ class Config: 'level': 'DEBUG', 'class': 'logging.FileHandler', 'formatter': 'main', - 'filename': os.path.join(PROJECT_DIR, 'logs', 'jumpserver.log') + 'filename': LOG_DIR, }, }, 'loggers': { - 'django': { - 'handlers': ['null'], - 'propagate': False, - 'level': LOG_LEVEL, - }, - 'django.request': { - 'handlers': ['console', 'file'], - 'level': LOG_LEVEL, - 'propagate': False, - }, - 'django.server': { - 'handlers': ['console', 'file'], - 'level': LOG_LEVEL, - 'propagate': False, - }, 'jumpserver': { 'handlers': ['console', 'file'], + # 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info') 'level': LOG_LEVEL, }, - 'jumpserver.users.api': { + 'jumpserver.web_ssh_server': { 'handlers': ['console', 'file'], + # 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info') 'level': LOG_LEVEL, }, - 'jumpserver.users.view': { + 'jumpserver.ssh_server': { 'handlers': ['console', 'file'], + # 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info') 'level': LOG_LEVEL, } } @@ -88,6 +70,27 @@ class Config: return None +class DevelopmentConfig(Config): + pass + + +class ProductionConfig(Config): + pass + + +class TestingConfig(Config): + pass + + +config = { + 'development': DevelopmentConfig, + 'production': ProductionConfig, + 'testing': TestingConfig, + 'default': DevelopmentConfig, +} + +env = 'default' + if __name__ == '__main__': pass diff --git a/terminal/ssh_server.py b/terminal/ssh_server.py index 7d6a6dee3..94a0b70e8 100644 --- a/terminal/ssh_server.py +++ b/terminal/ssh_server.py @@ -7,6 +7,7 @@ import base64 from binascii import hexlify import sys import threading +from multiprocessing import process import traceback import tty import termios @@ -31,17 +32,21 @@ except IndexError: pass from django.conf import settings -from common.utils import get_logger from users.utils import ssh_key_gen, check_user_is_valid +from utils import get_logger + logger = get_logger(__name__) class SSHServerInterface(paramiko.ServerInterface): host_key_path = os.path.join(BASE_DIR, 'host_rsa_key') + channel_pools = [] - def __init__(self): + def __init__(self, client, addr): self.event = threading.Event() + self.client = client + self.addr = addr self.user = None @classmethod @@ -70,19 +75,35 @@ class SSHServerInterface(paramiko.ServerInterface): def check_auth_password(self, username, password): self.user = check_user_is_valid(username=username, password=password) if self.user: - logger.info('User: %s password auth passed' % username) + logger.info('Accepted password for %(user)s from %(host)s port %(port)s ' % { + 'user': username, + 'host': self.addr[0], + 'port': self.addr[1], + }) return paramiko.AUTH_SUCCESSFUL else: - logger.warning('User: %s password auth failed' % username) + logger.info('Authentication password failed for %(user)s from %(host)s port %(port)s ' % { + 'user': username, + 'host': self.addr[0], + 'port': self.addr[1], + }) return paramiko.AUTH_FAILED def check_auth_publickey(self, username, public_key): self.user = check_user_is_valid(username=username, public_key=public_key) if self.user: - logger.info('User: %s public key auth passed' % username) + logger.info('Accepted public key for %(user)s from %(host)s port %(port)s ' % { + 'user': username, + 'host': self.addr[0], + 'port': self.addr[1], + }) return paramiko.AUTH_SUCCESSFUL else: - logger.warning('User: %s public key auth failed' % username) + logger.info('Authentication public key failed for %(user)s from %(host)s port %(port)s ' % { + 'user': username, + 'host': self.addr[0], + 'port': self.addr[1], + }) return paramiko.AUTH_FAILED def get_allowed_auths(self, username): @@ -95,12 +116,20 @@ class SSHServerInterface(paramiko.ServerInterface): def check_channel_shell_request(self, channel): self.event.set() + self.__class__.channel_pools.append(channel) return True def check_channel_pty_request(self, channel, term, width, height, pixelwidth, pixelheight, modes): return True + def check_channel_window_change_request(self, channel, width, height, pixelwidth, pixelheight): + logger.info('Change window size %s * %s' % (width, height)) + logger.info('Change length %s ' % len(self.__class__.channel_pools)) + # for channel in self.__class__.channel_pools: + # channel.send("Hello world") + return True + class SSHServer: def __init__(self, host='127.0.0.1', port=2200): @@ -110,18 +139,22 @@ class SSHServer: self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.sock.bind((self.host, self.port)) self.server_ssh = None - self.server_chan = None + self.server_channel = None + self.client_channel = None def connect(self): ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.connect(hostname='127.0.0.1', port=22, username='root', password='redhat') self.server_ssh = ssh - self.server_chan = channel = ssh.invoke_shell(term='xterm') + self.server_channel = channel = ssh.invoke_shell(term='xterm') return channel def handle_ssh_request(self, client, addr): - logger.info("Get connection from " + str(addr)) + logger.info("Get connection from %(host)s:%(port)s" % { + 'host': addr[0], + 'port': addr[1], + }) try: transport = paramiko.Transport(client, gss_kex=False) transport.set_gss_host(socket.getfqdn("")) @@ -132,70 +165,63 @@ class SSHServer: raise transport.add_server_key(SSHServerInterface.get_host_key()) - ssh_interface = SSHServerInterface() + ssh_interface = SSHServerInterface(client, addr) try: transport.start_server(server=ssh_interface) except paramiko.SSHException: print('*** SSH negotiation failed.') return - channel = transport.accept(20) - if channel is None: + self.client_channel = client_channel = transport.accept(20) + if client_channel is None: print('*** No channel.') return print('Authenticated!') - channel.settimeout(100) + client_channel.settimeout(100) - channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n') - channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n') - channel.send('Happy birthday to Robot Dave!\r\n\r\n') + client_channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n') + client_channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n') + client_channel.send('Happy birthday to Robot Dave!\r\n\r\n') server_channel = self.connect() if not ssh_interface.event.is_set(): print('*** Client never asked for a shell.') return - server_data = [] - input_mode = True + while True: - r, w, e = select.select([server_channel, channel], [], []) + r, w, x = select.select([client_channel, server_channel], [], []) - if channel in r: - recv_data = channel.recv(1024).decode('utf8') - # print("From client: " + repr(recv_data)) - if len(recv_data) == 0: + if client_channel in r: + data_client = client_channel.recv(1024) + logger.info(data_client) + if len(data_client) == 0: break - server_channel.send(recv_data) + # client_channel.send(data_client) + server_channel.send(data_client) if server_channel in r: - recv_data = server_channel.recv(1024).decode('utf8') - # print("From server: " + repr(recv_data)) - if len(recv_data) == 0: + data_server = server_channel.recv(1024) + if len(data_server) == 0: break - channel.send(recv_data) - if len(recv_data) > 20: - server_data.append('...') - else: - server_data.append(recv_data) - try: - if repr(server_data[-2]) == u'\r\n': - result = server_data.pop() - server_data.pop() - command = ''.join(server_data) - server_data = [] - print(">>> Command: %s" % command) - print(result) - except IndexError: - pass - print(server_data) - - except Exception as e: - print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e)) - traceback.print_exc() - try: - transport.close() - except: - pass - sys.exit(1) + client_channel.send(data_server) + + # if len(recv_data) > 20: + # server_data.append('...') + # else: + # server_data.append(recv_data) + # try: + # if repr(server_data[-2]) == u'\r\n': + # result = server_data.pop() + # server_data.pop() + # command = ''.join(server_data) + # server_data = [] + # except IndexError: + # pass + + except Exception: + client_channel.close() + server_channel.close() + logger.info('Close with server %s from %s' % ('127.0.0.1', '127.0.0.1')) def listen(self): self.sock.listen(5) @@ -204,7 +230,9 @@ class SSHServer: try: client, addr = self.sock.accept() print('Listening for connection ...') - t = threading.Thread(target=self.handle_ssh_request, args=(client, addr)) + # t = threading.Thread(target=self.handle_ssh_request, args=(client, addr)) + t = process.Process(target=self.handle_ssh_request, args=(client, addr)) + t.daemon = True t.start() except Exception as e: diff --git a/terminal/utils.py b/terminal/utils.py index f5c5d234f..1aba6ba02 100644 --- a/terminal/utils.py +++ b/terminal/utils.py @@ -3,6 +3,15 @@ # import logging +from logging.config import dictConfig +from ssh_config import config, env +CONFIG_SSH_SERVER = config.get(env) + + +def get_logger(name): + dictConfig(CONFIG_SSH_SERVER.LOGGING) + return logging.getLogger('jumpserver.%s' % name) + diff --git a/terminal/web_server.py b/terminal/web_ssh_server.py similarity index 100% rename from terminal/web_server.py rename to terminal/web_ssh_server.py