From 4b7419559c7b387bddf79ee540ec32dc0dd9feb5 Mon Sep 17 00:00:00 2001 From: ibuler Date: Sun, 25 Sep 2016 23:11:09 +0800 Subject: [PATCH] Update ssh_server to some class --- terminal/ssh_server.py | 196 ++++++++++++++++++++++++----------------- terminal/utils.py | 2 + 2 files changed, 119 insertions(+), 79 deletions(-) diff --git a/terminal/ssh_server.py b/terminal/ssh_server.py index c70152173..0b8261963 100644 --- a/terminal/ssh_server.py +++ b/terminal/ssh_server.py @@ -1,9 +1,13 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# +# + +__version__ = '0.3.3' + import sys import os import base64 +import time from binascii import hexlify import sys import threading @@ -19,7 +23,6 @@ import select import errno import paramiko import django -from paramiko.py3compat import b, u, decodebytes BASE_DIR = os.path.abspath(os.path.dirname(__file__)) APP_DIR = os.path.join(os.path.dirname(BASE_DIR), 'apps') @@ -33,13 +36,13 @@ except IndexError: from django.conf import settings from users.utils import ssh_key_gen, check_user_is_valid -from utils import get_logger +from utils import get_logger, SSHServerException logger = get_logger(__name__) -class SSHServerInterface(paramiko.ServerInterface): +class SSHServer(paramiko.ServerInterface): host_key_path = os.path.join(BASE_DIR, 'host_rsa_key') channel_pools = [] @@ -47,7 +50,10 @@ class SSHServerInterface(paramiko.ServerInterface): self.event = threading.Event() self.client = client self.addr = addr + self.username = None self.user = None + self.channel_width = None + self.channel_height = None @classmethod def host_key(cls): @@ -73,34 +79,37 @@ class SSHServerInterface(paramiko.ServerInterface): return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED def check_auth_password(self, username, password): - self.user = check_user_is_valid(username=username, password=password) + self.user = user = check_user_is_valid(username=username, password=password) + self.username = username = user.username if self.user: - logger.info('Accepted password for %(user)s from %(host)s port %(port)s ' % { - 'user': username, + logger.info('Accepted password for %(username)s from %(host)s port %(port)s ' % { + 'username': username, 'host': self.addr[0], 'port': self.addr[1], }) return paramiko.AUTH_SUCCESSFUL else: - logger.info('Authentication password failed for %(user)s from %(host)s port %(port)s ' % { - 'user': username, + logger.info('Authentication password failed for %(username)s from %(host)s port %(port)s ' % { + 'username': username, 'host': self.addr[0], 'port': self.addr[1], }) return paramiko.AUTH_FAILED def check_auth_publickey(self, username, public_key): - self.user = check_user_is_valid(username=username, public_key=public_key) + self.user = user = check_user_is_valid(username=username, public_key=public_key) + self.username = username = user.username + if self.user: - logger.info('Accepted public key for %(user)s from %(host)s port %(port)s ' % { - 'user': username, + logger.info('Accepted public key for %(username)s from %(host)s port %(port)s ' % { + 'username': username, 'host': self.addr[0], 'port': self.addr[1], }) return paramiko.AUTH_SUCCESSFUL else: - logger.info('Authentication public key failed for %(user)s from %(host)s port %(port)s ' % { - 'user': username, + logger.info('Authentication public key failed for %(username)s from %(host)s port %(port)s ' % { + 'username': username, 'host': self.addr[0], 'port': self.addr[1], }) @@ -135,12 +144,18 @@ class BackendServer: self.ssh = None self.channel = None - def connect(self, term='xterm', width=80, height=24): + def connect(self, term='xterm', width=80, height=24, timeout=10): self.ssh = ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.connect(hostname=self.host, port=self.port, username=self.username, password=self.host_password, - pkey=self.host_private_key, look_for_keys=False, allow_agent=True, compress=True) + pkey=self.host_private_key, look_for_keys=False, allow_agent=True, compress=True, timeout=timeout) self.channel = channel = ssh.invoke_shell(term=term, width=width, height=height) + logger.info('Connect %(username)s@%(host)s:%(port)s successfully' % { + 'username': self.username, + 'host': self.host, + 'port': self.port, + }) + channel.settimeout(100) return channel @property @@ -149,90 +164,108 @@ class BackendServer: @property def host_private_key(self): - return 'redhat' + return None class Navigation: - def __init__(self, username): + def __init__(self, username, client_channel): self.username = username + self.client_channel = client_channel + + def display_banner(self): + client_channel = self.client_channel + client_channel.send('\r\n\r\n\t\tWelcome to use Jumpserver open source system !\r\n\r\n') + client_channel.send('If use find some bug please contact us \r\n') + # client_channel.send(self.username) def display(self): + self.display_banner() + + def return_to_connect(self): pass -class SSHServer: - def __init__(self, host='127.0.0.1', port=2200): - self.host = host - self.port = port - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.sock.bind((self.host, self.port)) - self.server_ssh = None - self.server_channel = None +class JumpServer: + def __init__(self): + self.listen_host = '0.0.0.0' + self.listen_port = 2222 + self.username = None + self.backend_host = None + self.backend_port = None + self.backend_username = None + self.backend_channel = None self.client_channel = None + self.sock = None - def invoke_with_backend(self): - pass + def display_navigation(self, username, client_channel): + nav = Navigation(username, client_channel) + nav.display() + return '127.0.0.1', 22, 'root' - def display_navigation(self): - pass + def get_client_channel(self, client, addr): + transport = paramiko.Transport(client, gss_kex=False) + transport.set_gss_host(socket.getfqdn("")) + try: + transport.load_server_moduli() + except: + logger.warning('Failed to load moduli -- gex will be unsupported.') + raise - def make_client_channel(self): - pass + transport.add_server_key(SSHServer.get_host_key()) + ssh_server = SSHServer(client, addr) + self.username = ssh_server.username + + try: + transport.start_server(server=ssh_server) + except paramiko.SSHException: + logger.warning('SSH negotiation failed.') + + self.client_channel = client_channel = transport.accept(20) + if client_channel is None: + logger.warning('No channel get.') + raise SSHServerException('No channel get.') + + if not ssh_server.event.is_set(): + logger.warning('Client never asked for a shell.') + raise SSHServerException('Client never asked for a shell.') + return client_channel + + def get_backend_channel(self, host, port, username): + backend_server = BackendServer(host, port, username) + self.backend_channel = backend_channel = backend_server.connect() + if not backend_channel: + logger.warning('Connect %(username)s@%(host)s:%(port)s failed' % { + 'username': username, + 'host': host, + 'port': port, + }) + + return backend_channel def handle_ssh_request(self, client, addr): - logger.info("Get connection from %(host)s:%(port)s" % { + logger.info("Get ssh request from %(host)s:%(port)s" % { 'host': addr[0], 'port': addr[1], }) try: - transport = paramiko.Transport(client, gss_kex=False) - transport.set_gss_host(socket.getfqdn("")) - try: - transport.load_server_moduli() - except: - logger.warning('(Failed to load moduli -- gex will be unsupported.)') - raise + client_channel = self.get_client_channel(client, addr) + host, port, username = self.display_navigation(self.username, client_channel) + backend_channel = self.get_backend_channel(host, port, username) - transport.add_server_key(SSHServerInterface.get_host_key()) - ssh_interface = SSHServerInterface(client, addr) - try: - transport.start_server(server=ssh_interface) - except paramiko.SSHException: - print('*** SSH negotiation failed.') - return - - self.client_channel = client_channel = transport.accept(20) - # self.client_channel = client_channel = transport.open_session() - # client_channel.get_pty(term='xterm') - if client_channel is None: - print('*** No channel.') - return - print('Authenticated!') - - client_channel.settimeout(100) - - client_channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n') - client_channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n') - client_channel.send('Happy birthday to Robot Dave!\r\n\r\n') - server_channel = self.connect() - if not ssh_interface.event.is_set(): - print('*** Client never asked for a shell.') - return + print(client_channel.get_id(), backend_channel.get_id()) while True: - r, w, x = select.select([client_channel, server_channel], [], []) + r, w, x = select.select([client_channel, backend_channel], [], []) if client_channel in r: data_client = client_channel.recv(1024) logger.info(data_client) if len(data_client) == 0: break - # client_channel.send(data_client) - server_channel.send(data_client) + backend_channel.send(data_client) - if server_channel in r: - data_server = server_channel.recv(1024) + if backend_channel in r: + data_server = backend_channel.recv(1024) if len(data_server) == 0: break client_channel.send(data_server) @@ -250,30 +283,35 @@ class SSHServer: # except IndexError: # pass - except Exception: + except IndexError: logger.info('Close with server %s from %s' % ('127.0.0.1', '127.0.0.1')) sys.exit(100) def listen(self): - self.sock.listen(5) - print('Start ssh server %(host)s:%(port)s' % {'host': self.host, 'port': self.port}) + self.sock = sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.listen_host, self.listen_port)) + sock.listen(5) + + print(time.ctime()) + print('Jumpserver version %s, more see https://www.jumpserver.org' % __version__) + print('Starting ssh server at %(host)s:%(port)s' % {'host': self.listen_host, 'port': self.listen_port}) + print('Quit the server with CONTROL-C.') + while True: try: client, addr = self.sock.accept() - print('Listening for connection ...') - # t = threading.Thread(target=self.handle_ssh_request, args=(client, addr)) t = process.Process(target=self.handle_ssh_request, args=(client, addr)) - t.daemon = True t.start() except Exception as e: - print('*** Bind failed: ' + str(e)) + logger.error('Bind failed: ' + str(e)) traceback.print_exc() sys.exit(1) if __name__ == '__main__': - server = SSHServer(host='', port=2200) + server = JumpServer() try: server.listen() except KeyboardInterrupt: diff --git a/terminal/utils.py b/terminal/utils.py index 1aba6ba02..da648e05f 100644 --- a/terminal/utils.py +++ b/terminal/utils.py @@ -15,3 +15,5 @@ def get_logger(name): return logging.getLogger('jumpserver.%s' % name) +class SSHServerException(Exception): + pass