You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
jumpserver/jms

479 lines
12 KiB

#!/usr/bin/env python3
# coding: utf-8
import os
import subprocess
import threading
import logging
import logging.handlers
import time
import argparse
import sys
import signal
from collections import defaultdict
import daemon
from daemon import pidfile
# Put the directory containing this script on sys.path so the `apps` package
# can be imported regardless of the caller's working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BASE_DIR)

# Resolve the version string. A failure here is not fatal: diagnostics are
# printed and the version falls back to 'Unknown'.
try:
    from apps.jumpserver import const
    __version__ = const.VERSION
except ImportError as e:
    print("Not found __version__: {}".format(e))
    print("Sys path: {}".format(sys.path))
    print("Python is: ")
    # `which` writes the interpreter path to stdout itself; the value printed
    # by print() is only its exit status.
    print(subprocess.call('which python', shell=True))
    __version__ = 'Unknown'
    try:
        import apps
        print("List apps: {}".format(os.listdir('apps')))
        print('apps is: {}'.format(apps))
    except Exception:
        # Best-effort diagnostics only. `Exception` (rather than a bare
        # except) keeps Ctrl-C and SystemExit working.
        pass

# Load the user configuration. Unlike the version lookup, this one is fatal.
try:
    from apps.jumpserver.conf import load_user_config
    CONFIG = load_user_config()
except ImportError as e:
    print("Import error: {}".format(e))
    print("Could not find config file, `cp config_example.yml config.yml`")
    sys.exit(1)
# Exported to the environment so that it is inherited by the child processes
# spawned further down in this file.
os.environ["PYTHONIOENCODING"] = "UTF-8"

# Directories derived from the location of this script.
APPS_DIR = os.path.join(BASE_DIR, 'apps')
LOG_DIR = os.path.join(BASE_DIR, 'logs')
TMP_DIR = os.path.join(BASE_DIR, 'tmp')

# Settings read from the user configuration, each with a fallback that is
# used when the configured value is falsy.
HTTP_HOST = CONFIG.HTTP_BIND_HOST or '127.0.0.1'
HTTP_PORT = CONFIG.HTTP_LISTEN_PORT or 8080
DEBUG = CONFIG.DEBUG or False
LOG_LEVEL = CONFIG.LOG_LEVEL or 'INFO'

START_TIMEOUT = 40
# Default gunicorn worker count; the command line can override it (-w).
WORKERS = 4
# Switched to True by the command line (-d) or by the `restart` action.
DAEMON = False
# Set once the watcher should stop supervising and shut everything down.
EXIT_EVENT = threading.Event()
# Guards access to the module-level `processes` registry.
LOCK = threading.Lock()
daemon_pid_file = ''

# Create the data directories independently of each other. With a shared
# try block an already existing "static" directory raised before "media" was
# ever attempted, so "media" could silently stay missing.
for _data_dir in ("static", "media"):
    try:
        os.makedirs(os.path.join(BASE_DIR, "data", _data_dir), exist_ok=True)
    except OSError as e:
        # Still best-effort (startup continues), but no longer silent.
        print("Could not create data directory {}: {}".format(_data_dir, e))
class LogPipe(threading.Thread):
    """Thread that copies every line written to a pipe into a logger.

    The instance owns an OS pipe. ``fileno()`` returns the write end, so the
    object can be handed to ``subprocess.Popen`` as ``stdout``/``stderr``.
    The thread reads the other end line by line and logs each line to a
    rotating file and, optionally, to the console.
    """

    def __init__(self, name, file_path, to_stdout=False):
        """Create the pipe and the logger, then start the reader thread.

        :param name: thread name, also used as the logger name
        :param file_path: path of the rotating log file
        :param to_stdout: when true, lines are additionally sent to a
            ``logging.StreamHandler``
        """
        super().__init__()
        self.daemon = False
        self.name = name
        self.file_path = file_path
        self.to_stdout = to_stdout
        self.fdRead, self.fdWrite = os.pipe()
        self.pipeReader = os.fdopen(self.fdRead)
        self.logger = self.init_logger()
        self.start()

    def init_logger(self):
        """Return the logger named after this pipe, adding handlers once.

        ``logging.getLogger`` returns the same object for the same name, so a
        second ``LogPipe`` with this name (e.g. after a service restart)
        would otherwise stack another set of handlers on it and every line
        would be written twice. Handlers are therefore only added when an
        equivalent one is not attached yet.
        """
        _logger = logging.getLogger(self.name)
        _logger.setLevel(logging.INFO)
        _formatter = logging.Formatter('%(message)s')

        # RotatingFileHandler stores the absolute path in `baseFilename`.
        log_path = os.path.abspath(self.file_path)
        has_file_handler = any(
            isinstance(h, logging.handlers.RotatingFileHandler)
            and h.baseFilename == log_path
            for h in _logger.handlers
        )
        if not has_file_handler:
            _handler = logging.handlers.RotatingFileHandler(
                self.file_path, mode='a', maxBytes=5*1024*1024, backupCount=5
            )
            _handler.setFormatter(_formatter)
            _handler.setLevel(logging.INFO)
            _logger.addHandler(_handler)

        if self.to_stdout:
            # File handlers subclass StreamHandler, hence the exact type test.
            has_console = any(
                type(h) is logging.StreamHandler for h in _logger.handlers
            )
            if not has_console:
                _console = logging.StreamHandler()
                _console.setLevel(logging.INFO)
                _console.setFormatter(_formatter)
                _logger.addHandler(_console)
        return _logger

    def fileno(self):
        """Return the write file descriptor of the pipe."""
        return self.fdWrite

    def run(self):
        """Log every line read from the pipe until end-of-file.

        ``readline`` returns ``''`` only once every copy of the write end
        has been closed; that ends the loop and closes the reader.
        """
        with self.pipeReader:
            for line in iter(self.pipeReader.readline, ''):
                self.logger.info(line.strip('\n'))

    def close(self):
        """Close this process's copy of the write end of the pipe."""
        os.close(self.fdWrite)
def check_database_connection():
    """Wait until ``manage.py showmigrations users`` succeeds.

    The command is attempted up to 60 times, one second apart. The function
    returns on the first exit status of 0; if every attempt fails the process
    exits with status 10.

    Side effect: the working directory is changed to ``<BASE_DIR>/apps``.
    """
    os.chdir(os.path.join(BASE_DIR, 'apps'))
    # `python3` matches the interpreter name used by the other manage.py
    # invocations in this file.
    command = ['python3', 'manage.py', 'showmigrations', 'users']
    for _ in range(60):
        print("Check database connection ...")
        code = subprocess.call(command)
        if code == 0:
            print("Database connect success")
            return
        time.sleep(1)
    print("Connection database failed, exit")
    sys.exit(10)
def make_migrations():
    """Run ``manage.py migrate`` from the apps directory.

    Side effect: the working directory is changed to ``<BASE_DIR>/apps``.
    The exit status of the migrate command is not inspected.
    """
    print("Check database structure change ...")
    apps_path = os.path.join(BASE_DIR, 'apps')
    os.chdir(apps_path)
    print("Migrate model change to database ...")
    migrate_command = 'python3 manage.py migrate'
    subprocess.call(migrate_command, shell=True)
def collect_static():
    """Run ``manage.py collectstatic --no-input -c`` with its output discarded.

    "Collect static file done" is printed only when the command exits with
    status 0.

    The output is discarded through ``subprocess.DEVNULL`` instead of the
    shell redirection ``&>``: that operator is a bash extension, and a POSIX
    ``sh`` reads ``cmd &> file`` as "run cmd in the background", which
    reported success before the command had even finished.

    Side effect: the working directory is changed to ``<BASE_DIR>/apps``.
    """
    print("Collect static files")
    os.chdir(os.path.join(BASE_DIR, 'apps'))
    command = ['python3', 'manage.py', 'collectstatic', '--no-input', '-c']
    code = subprocess.call(
        command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
    )
    if code == 0:
        print("Collect static file done")
def prepare():
    """Run the pre-start steps in order: database check, migrate, static files."""
    steps = (check_database_connection, make_migrations, collect_static)
    for step in steps:
        step()
def check_pid(pid):
    """Check for the existence of a unix pid.

    :param pid: process id to probe
    :return: True when a process with that id exists, False otherwise

    Signal 0 is used, which performs the existence/permission check without
    delivering a signal.
    """
    if pid <= 0:
        # 0 and negative values address process groups in os.kill(), not a
        # single process, so they can never name "the" service process.
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # The process exists but belongs to another user.
        return True
    except OSError:
        return False
    return True
def get_pid_file_path(s):
    """Return the pid file path for service ``s``: ``/tmp/<s>.pid``."""
    pid_file_name = '{}.pid'.format(s)
    return os.path.join('/tmp', pid_file_name)
def get_log_file_path(s):
    """Return the log file path for service ``s``: ``<LOG_DIR>/<s>.log``."""
    log_file_name = '{}.log'.format(s)
    return os.path.join(LOG_DIR, log_file_name)
def get_pid_from_file(path):
    """Read a pid from ``path``.

    :param path: location of the pid file
    :return: the integer stored in the file, or 0 when ``path`` is not a
        regular file or its content cannot be parsed as an integer
    """
    if not os.path.isfile(path):
        return 0
    with open(path) as pid_fh:
        try:
            pid = int(pid_fh.read().strip())
        except ValueError:
            pid = 0
    return pid
def get_pid(s):
    """Return the pid recorded in the pid file of service ``s`` (0 if none)."""
    return get_pid_from_file(get_pid_file_path(s))
def is_running(s, unlink=True):
    """Tell whether service ``s`` is running according to its pid file.

    :param s: service name
    :param unlink: when true, a pid file that points at a process which no
        longer exists is deleted
    :return: True only when the pid file exists, holds a non-zero pid and
        that pid passes ``check_pid``
    """
    pid_file = get_pid_file_path(s)
    if not os.path.isfile(pid_file):
        return False

    pid = get_pid(s)
    if pid == 0:
        # Empty or unparsable pid file: reported as not running, file kept.
        return False
    if check_pid(pid):
        return True

    # Stale pid file: the recorded process is gone.
    if unlink:
        os.unlink(pid_file)
    return False
def parse_service(s):
    """Expand a service selector into a list of concrete service names.

    :param s: ``'all'``, the ``'celery'`` alias, a single service name, or a
        comma separated combination of those
    :return: list of service names, without duplicates, in the order in
        which they were requested

    A list is returned in every case (the comma form used to yield a set,
    whose iteration order is arbitrary, so services could start in a
    different order on every run). Surrounding whitespace and empty items in
    a comma list are ignored.
    """
    all_services = ['gunicorn', 'celery_ansible', 'celery_default', 'beat']
    if s == 'all':
        return all_services
    if s == "celery":
        return ["celery_ansible", "celery_default"]
    if "," in s:
        services = []
        for item in s.split(','):
            item = item.strip()
            if not item:
                continue
            for name in parse_service(item):
                if name not in services:
                    services.append(name)
        return services
    return [s]
def get_start_gunicorn_kwargs():
    """Run the pre-start steps and build the Popen arguments for gunicorn.

    :return: dict with ``cmd`` (argument list) and ``cwd`` (APPS_DIR)
    """
    print("\n- Start Gunicorn WSGI HTTP Server")
    prepare()

    listen_address = '{}:{}'.format(HTTP_HOST, HTTP_PORT)
    access_log_format = '%(h)s %(t)s "%(r)s" %(s)s %(b)s '
    gunicorn_pid_file = get_pid_file_path('gunicorn')

    cmd = ['gunicorn', 'jumpserver.wsgi']
    cmd += ['-b', listen_address]
    cmd += ['-k', 'gthread']
    cmd += ['--threads', '10']
    cmd += ['-w', str(WORKERS)]
    cmd += ['--max-requests', '4096']
    cmd += ['--access-logformat', access_log_format]
    cmd += ['-p', gunicorn_pid_file]
    # '-' sends the access log to gunicorn's standard output.
    cmd += ['--access-logfile', '-']
    if DEBUG:
        cmd += ['--reload']
    return dict(cmd=cmd, cwd=APPS_DIR)
def get_start_celery_ansible_kwargs():
    """Build the Popen arguments for the worker on the ``ansible`` queue (4 processes)."""
    print("\n- Start Celery as Distributed Task Queue")
    worker_kwargs = get_start_worker_kwargs('ansible', 4)
    return worker_kwargs
def get_start_celery_default_kwargs():
    """Build the Popen arguments for the worker on the ``celery`` queue (2 processes)."""
    worker_kwargs = get_start_worker_kwargs('celery', 2)
    return worker_kwargs
def get_start_worker_kwargs(queue, num):
    """Build the Popen arguments for a ``celery worker`` process.

    :param queue: value passed to ``-Q``
    :param num: value passed to ``-c``
    :return: dict with ``cmd`` (argument list) and ``cwd`` (APPS_DIR)

    Side effect: PYTHONOPTIMIZE, and C_FORCE_ROOT when running as uid 0, are
    set in this process's environment unless they are already defined.
    """
    # TODO (from the original author): this variable must be set, otherwise
    # no ansible result is returned.
    os.environ.setdefault('PYTHONOPTIMIZE', '1')
    running_as_root = os.getuid() == 0
    if running_as_root:
        os.environ.setdefault('C_FORCE_ROOT', '1')

    cmd = ['celery', 'worker']
    cmd.extend(['-A', 'ops'])
    cmd.extend(['-l', 'INFO'])
    cmd.extend(['-c', str(num)])
    cmd.extend(['-Q', queue])
    return dict(cmd=cmd, cwd=APPS_DIR)
def get_start_beat_kwargs():
    """Build the Popen arguments for the ``celery beat`` process.

    :return: dict with ``cmd`` (argument list) and ``cwd`` (APPS_DIR)

    Side effect: PYTHONOPTIMIZE, and C_FORCE_ROOT when running as uid 0, are
    set in this process's environment unless they are already defined.
    """
    print("\n- Start Beat as Periodic Task Scheduler")
    os.environ.setdefault('PYTHONOPTIMIZE', '1')
    running_as_root = os.getuid() == 0
    if running_as_root:
        os.environ.setdefault('C_FORCE_ROOT', '1')

    cmd = ['celery', 'beat']
    cmd.extend(['-A', 'ops'])
    cmd.extend(['-l', 'INFO'])
    cmd.extend(['--scheduler', "django_celery_beat.schedulers:DatabaseScheduler"])
    cmd.extend(['--max-interval', '60'])
    return dict(cmd=cmd, cwd=APPS_DIR)
# Registry of the services started by this process: name -> Popen object.
processes = {}


def watch_services():
    """Supervise the registered services until shutdown is requested.

    Roughly every ten seconds each process in ``processes`` is polled and
    services that have exited are started again. A service that is seen
    running has its retry counter reset; a stopped service whose counter
    already exceeds ``max_retry`` makes the watcher give up. On exit (exit
    flag set, or Ctrl-C) ``clean_up()`` stops everything that is left.
    """
    max_retry = 3
    # The handler only raises the exit flag. Signal handlers run in the main
    # thread, i.e. possibly while the loop below holds LOCK and iterates
    # `processes`; doing the clean-up inside the handler would then block on
    # the non-reentrant LOCK or mutate the dict during iteration. The actual
    # clean-up happens after the loop instead.
    signal.signal(signal.SIGTERM, lambda signum, frame: EXIT_EVENT.set())
    services_retry = defaultdict(int)
    stopped_services = {}

    def check_services():
        """Record which services have exited and which are alive."""
        for s, p in processes.items():
            try:
                p.wait(timeout=1)
            except subprocess.TimeoutExpired:
                # Still running: forget earlier failures.
                stopped_services.pop(s, None)
                services_retry.pop(s, None)
            else:
                stopped_services[s] = ''

    def retry_start_stopped_services():
        """Restart every stopped service, or request exit after too many tries."""
        for s in stopped_services:
            if EXIT_EVENT.is_set():
                # Shutting down: do not bring anything back up.
                break
            if services_retry[s] > max_retry:
                print("\nService start failed, exit: ", s)
                EXIT_EVENT.set()
                break
            print("\n> Find {} stopped, retry {}".format(
                s, services_retry[s] + 1)
            )
            processes[s] = start_service(s)
            services_retry[s] += 1

    while not EXIT_EVENT.is_set():
        try:
            with LOCK:
                check_services()
                retry_start_stopped_services()
            # Ten seconds between rounds, slept in short slices so that a
            # shutdown request is honoured within about a second.
            for _ in range(10):
                if EXIT_EVENT.is_set():
                    break
                time.sleep(1)
        except KeyboardInterrupt:
            time.sleep(1)
            break
    clean_up()
def start_service(s):
    """Start service ``s`` as a child process and record its pid.

    :param s: one of ``gunicorn``, ``celery_ansible``, ``celery_default``,
        ``beat``
    :return: the ``subprocess.Popen`` object of the started process
    :raises ValueError: if ``s`` is not a known service name

    The child's stdout and stderr are routed through a ``LogPipe`` into the
    service's log file (and to the console when not running as a daemon).
    """
    services_kwargs = {
        "gunicorn": get_start_gunicorn_kwargs,
        "celery_ansible": get_start_celery_ansible_kwargs,
        "celery_default": get_start_celery_default_kwargs,
        "beat": get_start_beat_kwargs,
    }
    build_kwargs = services_kwargs.get(s)
    if build_kwargs is None:
        raise ValueError("Unknown service: {}".format(s))
    kwargs = build_kwargs()

    # Drop a pid file left over from a previous run.
    pid_file = get_pid_file_path(s)
    if os.path.isfile(pid_file):
        os.unlink(pid_file)

    cmd = kwargs.pop('cmd')
    to_stdout = not DAEMON
    log_file = get_log_file_path(s)
    _logger = LogPipe(s, log_file, to_stdout=to_stdout)
    kwargs.update({"stderr": _logger, "stdout": _logger})
    try:
        p = subprocess.Popen(cmd, **kwargs)
    finally:
        # The child holds its own copies of the write end. Closing ours lets
        # the reader thread see end-of-file when the child exits, instead of
        # leaking one thread and two descriptors per (re)start; it also
        # unblocks the reader if Popen itself failed.
        _logger.close()

    with open(pid_file, 'w') as f:
        f.write(str(p.pid))
    return p
def start_services_and_watch(s):
    """Start the selected services, then supervise them.

    :param s: service selector understood by ``parse_service``

    Services that are already running are only reported. With DAEMON unset
    the watcher runs in the foreground; otherwise the process daemonizes
    first, using ``/tmp/jms.pid`` as its pid lock file.
    """
    global daemon_pid_file

    print(time.ctime())
    banner = 'Jumpserver version {}, more see https://www.jumpserver.org'
    print(banner.format(__version__))

    for name in parse_service(s):
        if is_running(name):
            show_service_status(name)
            continue
        proc = start_service(name)
        time.sleep(2)
        processes[name] = proc

    if not DAEMON:
        watch_services()
        return

    show_service_status(s)
    daemon_pid_file = get_pid_file_path('jms')
    # watch_services() installs its own SIGTERM handler as soon as it starts.
    # NOTE(review): DaemonContext is documented to close open file
    # descriptors that are not listed in `files_preserve` - confirm that the
    # log pipes of the services started above still work after daemonizing.
    handlers = {
        signal.SIGTERM: clean_up,
        signal.SIGHUP: 'terminate',
    }
    lock_file = pidfile.TimeoutPIDLockFile(daemon_pid_file)
    context = daemon.DaemonContext(pidfile=lock_file, signal_map=handlers)
    with context:
        watch_services()
def stop_service(s, sig=15):
    """Send ``sig`` to every running service selected by ``s``.

    :param s: service selector understood by ``parse_service``
    :param sig: signal number to send (default 15, SIGTERM)

    Services that are not running are only reported. When the selector is
    ``'all'`` the ``jms`` watcher recorded in ``/tmp/jms.pid`` is signalled
    as well, so that it does not restart what was just stopped.
    """
    # A separate loop variable keeps the `s` parameter intact; reusing `s`
    # made the "all" test below compare against the last service name, so
    # the watcher was never stopped.
    for name in parse_service(s):
        if not is_running(name):
            show_service_status(name)
            continue
        print("Stop service: {}".format(name))
        pid = get_pid(name)
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            # Exited between the check and the kill: nothing left to stop.
            pass
        with LOCK:
            processes.pop(name, None)

    # is_running() guarantees a non-zero pid of a live process; without it a
    # missing pid file would yield pid 0 and os.kill(0, sig) would signal
    # this process's whole process group.
    if s == "all" and is_running('jms'):
        pid = get_pid('jms')
        if pid != os.getpid():
            try:
                os.kill(pid, sig)
            except ProcessLookupError:
                pass
def stop_multi_services(services):
    """Stop each selector in ``services`` by sending signal 9 via ``stop_service``."""
    kill_signal = 9
    for selector in services:
        stop_service(selector, sig=kill_signal)
def stop_service_force(s):
    """Stop the services selected by ``s`` by sending signal 9 via ``stop_service``."""
    kill_signal = 9
    stop_service(s, sig=kill_signal)
def clean_up(signum=None, frame=None):
    """Request shutdown and stop every service started by this process.

    :param signum: ignored
    :param frame: ignored

    The two optional parameters exist because this function is registered as
    a signal handler elsewhere in this file, and signal handlers are called
    with ``(signum, frame)``; without them that call raises ``TypeError``.
    Calling ``clean_up()`` without arguments works as before.

    Each service is sent the default stop signal and then waited for.
    """
    EXIT_EVENT.set()
    # Iterate over a snapshot: stop_service() removes entries from
    # `processes` while this loop runs.
    for name, proc in list(processes.items()):
        stop_service(name)
        proc.wait()
def show_service_status(s):
    """Print, for every service selected by ``s``, whether it is running."""
    for name in parse_service(s):
        if not is_running(name):
            print("{} is stopped".format(name))
            continue
        running_pid = get_pid(name)
        print("{} is running: {}".format(name, running_pid))
if __name__ == '__main__':
    # Command line: <action> [service] [-d] [-w N]
    action_choices = ("start", "stop", "restart", "status")
    service_choices = ("all", "gunicorn", "celery", "beat", "celery,beat")

    parser = argparse.ArgumentParser(
        description="""
Jumpserver service control tools;
Example: \r\n
%(prog)s start all -d;
"""
    )
    parser.add_argument(
        'action', type=str, choices=action_choices, help="Action to run"
    )
    parser.add_argument(
        "service", type=str, default="all", nargs="?",
        choices=service_choices, help="The service to start",
    )
    parser.add_argument('-d', '--daemon', nargs="?", const=1)
    parser.add_argument('-w', '--worker', type=int, nargs="?", const=4)
    args = parser.parse_args()

    # Command-line overrides of the module-level defaults.
    if args.daemon:
        DAEMON = True
    if args.worker:
        WORKERS = args.worker

    action, srv = args.action, args.service

    if action == "stop":
        stop_service(srv)
    elif action == "restart":
        # A restart always comes back up as a daemon.
        DAEMON = True
        stop_service(srv)
        time.sleep(5)
        start_services_and_watch(srv)
    elif action == "start":
        start_services_and_watch(srv)
        # Hard exit: skips normal interpreter shutdown once watching ends.
        os._exit(0)
    else:
        show_service_status(srv)