mirror of https://github.com/jumpserver/jumpserver
Use process except thread
parent
216163f436
commit
ebb30424fa
|
@ -0,0 +1,105 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__name__))
|
||||
|
||||
|
||||
class Config:
|
||||
SSH_HOST = ''
|
||||
SSH_PORT = 2200
|
||||
LOG_LEVEL = 'INFO'
|
||||
LOG_DIR = os.path.join(BASE_DIR, 'logs')
|
||||
LOG_FILENAME = 'ssh_server.log'
|
||||
LOGGING = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'formatters': {
|
||||
'verbose': {
|
||||
'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'
|
||||
},
|
||||
'main': {
|
||||
'datefmt': '%Y-%m-%d %H:%M:%S',
|
||||
'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s',
|
||||
},
|
||||
'simple': {
|
||||
'format': '%(levelname)s %(message)s'
|
||||
},
|
||||
},
|
||||
'handlers': {
|
||||
'null': {
|
||||
'level': 'DEBUG',
|
||||
'class': 'logging.NullHandler',
|
||||
},
|
||||
'console': {
|
||||
'level': 'DEBUG',
|
||||
'class': 'logging.StreamHandler',
|
||||
'formatter': 'main',
|
||||
'stream': 'ext://sys.stdout',
|
||||
},
|
||||
'file': {
|
||||
'level': 'DEBUG',
|
||||
'class': 'logging.FileHandler',
|
||||
'formatter': 'main',
|
||||
'mode': 'a',
|
||||
'filename': os.path.join(LOG_DIR, LOG_FILENAME),
|
||||
},
|
||||
},
|
||||
'loggers': {
|
||||
'jumpserver': {
|
||||
'handlers': ['console', 'file'],
|
||||
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
|
||||
'level': LOG_LEVEL,
|
||||
'propagate': True,
|
||||
},
|
||||
'jumpserver.web_ssh_server': {
|
||||
'handlers': ['console', 'file'],
|
||||
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
|
||||
'level': LOG_LEVEL,
|
||||
'propagate': True,
|
||||
},
|
||||
'jumpserver.ssh_server': {
|
||||
'handlers': ['console', 'file'],
|
||||
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
|
||||
'level': LOG_LEVEL,
|
||||
'propagate': True,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __getattr__(self, item):
|
||||
return None
|
||||
|
||||
|
||||
class DevelopmentConfig(Config):
|
||||
pass
|
||||
|
||||
|
||||
class ProductionConfig(Config):
|
||||
pass
|
||||
|
||||
|
||||
class TestingConfig(Config):
|
||||
pass
|
||||
|
||||
|
||||
config = {
|
||||
'development': DevelopmentConfig,
|
||||
'production': ProductionConfig,
|
||||
'testing': TestingConfig,
|
||||
'default': DevelopmentConfig,
|
||||
}
|
||||
|
||||
env = 'default'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
|
|
@ -7,20 +7,14 @@ import os
|
|||
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__name__))
|
||||
LOG_LEVEL_CHOICES = {
|
||||
'debug': logging.DEBUG,
|
||||
'info': logging.INFO,
|
||||
'warning': logging.WARNING,
|
||||
'error': logging.ERROR,
|
||||
'critical': logging.CRITICAL
|
||||
}
|
||||
|
||||
|
||||
class Config:
|
||||
LOG_LEVEL = ''
|
||||
LOG_LEVEL = 'INFO'
|
||||
LOG_DIR = os.path.join(BASE_DIR, 'logs')
|
||||
LOGGING = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'formatters': {
|
||||
'verbose': {
|
||||
'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'
|
||||
|
@ -47,35 +41,23 @@ class Config:
|
|||
'level': 'DEBUG',
|
||||
'class': 'logging.FileHandler',
|
||||
'formatter': 'main',
|
||||
'filename': os.path.join(PROJECT_DIR, 'logs', 'jumpserver.log')
|
||||
'filename': LOG_DIR,
|
||||
},
|
||||
},
|
||||
'loggers': {
|
||||
'django': {
|
||||
'handlers': ['null'],
|
||||
'propagate': False,
|
||||
'level': LOG_LEVEL,
|
||||
},
|
||||
'django.request': {
|
||||
'handlers': ['console', 'file'],
|
||||
'level': LOG_LEVEL,
|
||||
'propagate': False,
|
||||
},
|
||||
'django.server': {
|
||||
'handlers': ['console', 'file'],
|
||||
'level': LOG_LEVEL,
|
||||
'propagate': False,
|
||||
},
|
||||
'jumpserver': {
|
||||
'handlers': ['console', 'file'],
|
||||
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
|
||||
'level': LOG_LEVEL,
|
||||
},
|
||||
'jumpserver.users.api': {
|
||||
'jumpserver.web_ssh_server': {
|
||||
'handlers': ['console', 'file'],
|
||||
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
|
||||
'level': LOG_LEVEL,
|
||||
},
|
||||
'jumpserver.users.view': {
|
||||
'jumpserver.ssh_server': {
|
||||
'handlers': ['console', 'file'],
|
||||
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
|
||||
'level': LOG_LEVEL,
|
||||
}
|
||||
}
|
||||
|
@ -88,6 +70,27 @@ class Config:
|
|||
return None
|
||||
|
||||
|
||||
class DevelopmentConfig(Config):
|
||||
pass
|
||||
|
||||
|
||||
class ProductionConfig(Config):
|
||||
pass
|
||||
|
||||
|
||||
class TestingConfig(Config):
|
||||
pass
|
||||
|
||||
|
||||
config = {
|
||||
'development': DevelopmentConfig,
|
||||
'production': ProductionConfig,
|
||||
'testing': TestingConfig,
|
||||
'default': DevelopmentConfig,
|
||||
}
|
||||
|
||||
env = 'default'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
|
@ -7,6 +7,7 @@ import base64
|
|||
from binascii import hexlify
|
||||
import sys
|
||||
import threading
|
||||
from multiprocessing import process
|
||||
import traceback
|
||||
import tty
|
||||
import termios
|
||||
|
@ -31,17 +32,21 @@ except IndexError:
|
|||
pass
|
||||
|
||||
from django.conf import settings
|
||||
from common.utils import get_logger
|
||||
from users.utils import ssh_key_gen, check_user_is_valid
|
||||
from utils import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SSHServerInterface(paramiko.ServerInterface):
|
||||
host_key_path = os.path.join(BASE_DIR, 'host_rsa_key')
|
||||
channel_pools = []
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, client, addr):
|
||||
self.event = threading.Event()
|
||||
self.client = client
|
||||
self.addr = addr
|
||||
self.user = None
|
||||
|
||||
@classmethod
|
||||
|
@ -70,19 +75,35 @@ class SSHServerInterface(paramiko.ServerInterface):
|
|||
def check_auth_password(self, username, password):
|
||||
self.user = check_user_is_valid(username=username, password=password)
|
||||
if self.user:
|
||||
logger.info('User: %s password auth passed' % username)
|
||||
logger.info('Accepted password for %(user)s from %(host)s port %(port)s ' % {
|
||||
'user': username,
|
||||
'host': self.addr[0],
|
||||
'port': self.addr[1],
|
||||
})
|
||||
return paramiko.AUTH_SUCCESSFUL
|
||||
else:
|
||||
logger.warning('User: %s password auth failed' % username)
|
||||
logger.info('Authentication password failed for %(user)s from %(host)s port %(port)s ' % {
|
||||
'user': username,
|
||||
'host': self.addr[0],
|
||||
'port': self.addr[1],
|
||||
})
|
||||
return paramiko.AUTH_FAILED
|
||||
|
||||
def check_auth_publickey(self, username, public_key):
|
||||
self.user = check_user_is_valid(username=username, public_key=public_key)
|
||||
if self.user:
|
||||
logger.info('User: %s public key auth passed' % username)
|
||||
logger.info('Accepted public key for %(user)s from %(host)s port %(port)s ' % {
|
||||
'user': username,
|
||||
'host': self.addr[0],
|
||||
'port': self.addr[1],
|
||||
})
|
||||
return paramiko.AUTH_SUCCESSFUL
|
||||
else:
|
||||
logger.warning('User: %s public key auth failed' % username)
|
||||
logger.info('Authentication public key failed for %(user)s from %(host)s port %(port)s ' % {
|
||||
'user': username,
|
||||
'host': self.addr[0],
|
||||
'port': self.addr[1],
|
||||
})
|
||||
return paramiko.AUTH_FAILED
|
||||
|
||||
def get_allowed_auths(self, username):
|
||||
|
@ -95,12 +116,20 @@ class SSHServerInterface(paramiko.ServerInterface):
|
|||
|
||||
def check_channel_shell_request(self, channel):
|
||||
self.event.set()
|
||||
self.__class__.channel_pools.append(channel)
|
||||
return True
|
||||
|
||||
def check_channel_pty_request(self, channel, term, width, height, pixelwidth,
|
||||
pixelheight, modes):
|
||||
return True
|
||||
|
||||
def check_channel_window_change_request(self, channel, width, height, pixelwidth, pixelheight):
|
||||
logger.info('Change window size %s * %s' % (width, height))
|
||||
logger.info('Change length %s ' % len(self.__class__.channel_pools))
|
||||
# for channel in self.__class__.channel_pools:
|
||||
# channel.send("Hello world")
|
||||
return True
|
||||
|
||||
|
||||
class SSHServer:
|
||||
def __init__(self, host='127.0.0.1', port=2200):
|
||||
|
@ -110,18 +139,22 @@ class SSHServer:
|
|||
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.sock.bind((self.host, self.port))
|
||||
self.server_ssh = None
|
||||
self.server_chan = None
|
||||
self.server_channel = None
|
||||
self.client_channel = None
|
||||
|
||||
def connect(self):
|
||||
ssh = paramiko.SSHClient()
|
||||
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
ssh.connect(hostname='127.0.0.1', port=22, username='root', password='redhat')
|
||||
self.server_ssh = ssh
|
||||
self.server_chan = channel = ssh.invoke_shell(term='xterm')
|
||||
self.server_channel = channel = ssh.invoke_shell(term='xterm')
|
||||
return channel
|
||||
|
||||
def handle_ssh_request(self, client, addr):
|
||||
logger.info("Get connection from " + str(addr))
|
||||
logger.info("Get connection from %(host)s:%(port)s" % {
|
||||
'host': addr[0],
|
||||
'port': addr[1],
|
||||
})
|
||||
try:
|
||||
transport = paramiko.Transport(client, gss_kex=False)
|
||||
transport.set_gss_host(socket.getfqdn(""))
|
||||
|
@ -132,70 +165,63 @@ class SSHServer:
|
|||
raise
|
||||
|
||||
transport.add_server_key(SSHServerInterface.get_host_key())
|
||||
ssh_interface = SSHServerInterface()
|
||||
ssh_interface = SSHServerInterface(client, addr)
|
||||
try:
|
||||
transport.start_server(server=ssh_interface)
|
||||
except paramiko.SSHException:
|
||||
print('*** SSH negotiation failed.')
|
||||
return
|
||||
|
||||
channel = transport.accept(20)
|
||||
if channel is None:
|
||||
self.client_channel = client_channel = transport.accept(20)
|
||||
if client_channel is None:
|
||||
print('*** No channel.')
|
||||
return
|
||||
print('Authenticated!')
|
||||
|
||||
channel.settimeout(100)
|
||||
client_channel.settimeout(100)
|
||||
|
||||
channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
|
||||
channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n')
|
||||
channel.send('Happy birthday to Robot Dave!\r\n\r\n')
|
||||
client_channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
|
||||
client_channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n')
|
||||
client_channel.send('Happy birthday to Robot Dave!\r\n\r\n')
|
||||
server_channel = self.connect()
|
||||
if not ssh_interface.event.is_set():
|
||||
print('*** Client never asked for a shell.')
|
||||
return
|
||||
server_data = []
|
||||
input_mode = True
|
||||
while True:
|
||||
r, w, e = select.select([server_channel, channel], [], [])
|
||||
|
||||
if channel in r:
|
||||
recv_data = channel.recv(1024).decode('utf8')
|
||||
# print("From client: " + repr(recv_data))
|
||||
if len(recv_data) == 0:
|
||||
while True:
|
||||
r, w, x = select.select([client_channel, server_channel], [], [])
|
||||
|
||||
if client_channel in r:
|
||||
data_client = client_channel.recv(1024)
|
||||
logger.info(data_client)
|
||||
if len(data_client) == 0:
|
||||
break
|
||||
server_channel.send(recv_data)
|
||||
# client_channel.send(data_client)
|
||||
server_channel.send(data_client)
|
||||
|
||||
if server_channel in r:
|
||||
recv_data = server_channel.recv(1024).decode('utf8')
|
||||
# print("From server: " + repr(recv_data))
|
||||
if len(recv_data) == 0:
|
||||
data_server = server_channel.recv(1024)
|
||||
if len(data_server) == 0:
|
||||
break
|
||||
channel.send(recv_data)
|
||||
if len(recv_data) > 20:
|
||||
server_data.append('...')
|
||||
else:
|
||||
server_data.append(recv_data)
|
||||
try:
|
||||
if repr(server_data[-2]) == u'\r\n':
|
||||
result = server_data.pop()
|
||||
server_data.pop()
|
||||
command = ''.join(server_data)
|
||||
server_data = []
|
||||
print(">>> Command: %s" % command)
|
||||
print(result)
|
||||
except IndexError:
|
||||
pass
|
||||
print(server_data)
|
||||
client_channel.send(data_server)
|
||||
|
||||
except Exception as e:
|
||||
print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
|
||||
traceback.print_exc()
|
||||
try:
|
||||
transport.close()
|
||||
except:
|
||||
pass
|
||||
sys.exit(1)
|
||||
# if len(recv_data) > 20:
|
||||
# server_data.append('...')
|
||||
# else:
|
||||
# server_data.append(recv_data)
|
||||
# try:
|
||||
# if repr(server_data[-2]) == u'\r\n':
|
||||
# result = server_data.pop()
|
||||
# server_data.pop()
|
||||
# command = ''.join(server_data)
|
||||
# server_data = []
|
||||
# except IndexError:
|
||||
# pass
|
||||
|
||||
except Exception:
|
||||
client_channel.close()
|
||||
server_channel.close()
|
||||
logger.info('Close with server %s from %s' % ('127.0.0.1', '127.0.0.1'))
|
||||
|
||||
def listen(self):
|
||||
self.sock.listen(5)
|
||||
|
@ -204,7 +230,9 @@ class SSHServer:
|
|||
try:
|
||||
client, addr = self.sock.accept()
|
||||
print('Listening for connection ...')
|
||||
t = threading.Thread(target=self.handle_ssh_request, args=(client, addr))
|
||||
# t = threading.Thread(target=self.handle_ssh_request, args=(client, addr))
|
||||
t = process.Process(target=self.handle_ssh_request, args=(client, addr))
|
||||
|
||||
t.daemon = True
|
||||
t.start()
|
||||
except Exception as e:
|
||||
|
|
|
@ -3,6 +3,15 @@
|
|||
#
|
||||
|
||||
import logging
|
||||
from logging.config import dictConfig
|
||||
from ssh_config import config, env
|
||||
|
||||
|
||||
CONFIG_SSH_SERVER = config.get(env)
|
||||
|
||||
|
||||
def get_logger(name):
|
||||
dictConfig(CONFIG_SSH_SERVER.LOGGING)
|
||||
return logging.getLogger('jumpserver.%s' % name)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue