#!/usr/bin/env python3
# coding: utf-8
"""Jumpserver service control script.

Starts, stops, restarts and reports the status of the gunicorn, celery
worker and celery beat processes, optionally supervising them from a
background daemon.
"""

import os
import subprocess
import threading
import logging
import logging.handlers
import time
import argparse
import sys
import signal
from collections import defaultdict
import daemon
from daemon import pidfile

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BASE_DIR)

try:
    from apps.jumpserver import const
    __version__ = const.VERSION
except ImportError as e:
    # Print some diagnostics, then carry on with an unknown version.
    print("Not found __version__: {}".format(e))
    print("Sys path: {}".format(sys.path))
    print("Python is: ")
    print(subprocess.call('which python', shell=True))
    __version__ = 'Unknown'
    try:
        import apps
        print("List apps: {}".format(os.listdir('apps')))
        print('apps is: {}'.format(apps))
    except Exception:
        # Diagnostics only: never let them abort start-up.
        pass

try:
    from apps.jumpserver.conf import load_user_config
    CONFIG = load_user_config()
except ImportError as e:
    print("Import error: {}".format(e))
    print("Could not find config file, `cp config_example.yml config.yml`")
    sys.exit(1)

os.environ["PYTHONIOENCODING"] = "UTF-8"
APPS_DIR = os.path.join(BASE_DIR, 'apps')
LOG_DIR = os.path.join(BASE_DIR, 'logs')
TMP_DIR = os.path.join(BASE_DIR, 'tmp')
HTTP_HOST = CONFIG.HTTP_BIND_HOST or '127.0.0.1'
HTTP_PORT = CONFIG.HTTP_LISTEN_PORT or 8080
DEBUG = CONFIG.DEBUG or False
LOG_LEVEL = CONFIG.LOG_LEVEL or 'INFO'

START_TIMEOUT = 40
WORKERS = 4
DAEMON = False

EXIT_EVENT = threading.Event()
LOCK = threading.Lock()
daemon_pid_file = ''

# Interpreter used for the ``manage.py`` helper commands: the one running
# this script, falling back to ``python3`` from PATH.
_PYTHON = sys.executable or 'python3'

# Best effort: create each data directory independently, so that an
# already existing "static" directory does not prevent "media" from being
# created.
for _sub_dir in ("static", "media"):
    try:
        os.makedirs(os.path.join(BASE_DIR, "data", _sub_dir), exist_ok=True)
    except OSError:
        pass


class LogPipe(threading.Thread):
    """Thread that copies lines written to a pipe into a rotating log file.

    The object exposes ``fileno()`` so it can be handed to
    ``subprocess.Popen`` as ``stdout``/``stderr``; everything the child
    writes is read line by line and logged.
    """

    def __init__(self, name, file_path, to_stdout=False):
        """Create the pipe and the logger, then start the reader thread.

        :param name: logger name (also used as the thread name)
        :param file_path: path of the rotating log file
        :param to_stdout: also echo every line to the console
        """
        threading.Thread.__init__(self)
        self.daemon = False
        self.name = name
        self.file_path = file_path
        self.to_stdout = to_stdout
        self.fdRead, self.fdWrite = os.pipe()
        self.pipeReader = os.fdopen(self.fdRead)
        self.logger = self.init_logger()
        self.start()

    def init_logger(self):
        """Return the logger for ``self.name``, configuring it on first use.

        ``logging.getLogger`` returns the same object for the same name, so
        handlers are only attached when the logger has none yet; otherwise
        every restart of a service would add another handler and each line
        would be written several times.
        """
        _logger = logging.getLogger(self.name)
        _logger.setLevel(logging.INFO)
        if _logger.handlers:
            return _logger

        _formatter = logging.Formatter('%(message)s')
        _handler = logging.handlers.RotatingFileHandler(
            self.file_path, mode='a', maxBytes=5*1024*1024, backupCount=5
        )
        _handler.setFormatter(_formatter)
        _handler.setLevel(logging.INFO)
        _logger.addHandler(_handler)
        if self.to_stdout:
            _console = logging.StreamHandler()
            _console.setLevel(logging.INFO)
            _console.setFormatter(_formatter)
            _logger.addHandler(_console)
        return _logger

    def fileno(self):
        """Return the write file descriptor of the pipe."""
        return self.fdWrite

    def run(self):
        """Log every line read from the pipe until EOF, then close it."""
        for line in iter(self.pipeReader.readline, ''):
            self.logger.info(line.strip('\n'))

        self.pipeReader.close()

    def close(self):
        """Close the write end of the pipe."""
        os.close(self.fdWrite)


def check_database_connection():
    """Wait until ``manage.py showmigrations users`` succeeds.

    Retries once per second, up to 60 times, and exits the process with
    status 10 when the command never succeeds.
    """
    os.chdir(os.path.join(BASE_DIR, 'apps'))
    for _ in range(60):
        print("Check database connection ...")
        code = subprocess.call(
            [_PYTHON, 'manage.py', 'showmigrations', 'users']
        )
        if code == 0:
            print("Database connect success")
            return
        time.sleep(1)
    print("Connection database failed, exist")
    sys.exit(10)


def make_migrations():
    """Run ``manage.py migrate`` from the apps directory."""
    print("Check database structure change ...")
    os.chdir(os.path.join(BASE_DIR, 'apps'))
    print("Migrate model change to database ...")
    subprocess.call([_PYTHON, 'manage.py', 'migrate'])


def collect_static():
    """Run ``manage.py collectstatic`` quietly from the apps directory.

    The command output is discarded through ``subprocess.DEVNULL`` rather
    than the bash-only ``&>`` redirect, so the behaviour does not depend on
    which shell ``/bin/sh`` is.
    """
    print("Collect static files")
    os.chdir(os.path.join(BASE_DIR, 'apps'))
    code = subprocess.call(
        [_PYTHON, 'manage.py', 'collectstatic', '--no-input', '-c'],
        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
    )
    if code == 0:
        print("Collect static file done")


def prepare():
    """Run every preparation step required before gunicorn starts."""
    check_database_connection()
    make_migrations()
    collect_static()


def check_pid(pid):
    """Check for the existence of a unix pid.

    :return: True when signal 0 can be delivered to ``pid``, else False
    """
    try:
        os.kill(pid, 0)
    except OSError:
        return False
    else:
        return True


def get_pid_file_path(s):
    """Return the pid file path used for service ``s``."""
    return os.path.join('/tmp', '{}.pid'.format(s))


def get_log_file_path(s):
    """Return the log file path used for service ``s``."""
    return os.path.join(LOG_DIR, '{}.log'.format(s))


def get_pid_from_file(path):
    """Read a pid from ``path``.

    :return: the pid, or 0 when the file is missing or not an integer
    """
    if os.path.isfile(path):
        with open(path) as f:
            try:
                return int(f.read().strip())
            except ValueError:
                return 0
    return 0


def get_pid(s):
    """Return the pid recorded for service ``s`` (0 when unknown)."""
    pid_file = get_pid_file_path(s)
    return get_pid_from_file(pid_file)


def is_running(s, unlink=True):
    """Tell whether service ``s`` has a pid file pointing at a live process.

    :param s: service name
    :param unlink: remove the pid file when its process no longer exists
    """
    pid_file = get_pid_file_path(s)
    if not os.path.isfile(pid_file):
        return False

    pid = get_pid(s)
    # A non-positive pid must never reach os.kill(): 0 and negative values
    # address process groups rather than a single process.
    if pid <= 0:
        return False
    if check_pid(pid):
        return True

    if unlink:
        try:
            os.unlink(pid_file)
        except FileNotFoundError:
            # Somebody else removed the stale file in the meantime.
            pass
    return False


def parse_service(s):
    """Expand a service argument into a list of concrete service names.

    ``all`` and ``celery`` are aliases; a comma separated value is expanded
    item by item.  The result is always a list, without duplicates, in the
    order the services were requested.
    """
    all_services = ['gunicorn', 'celery_ansible', 'celery_default', 'beat']
    if s == 'all':
        return all_services
    if s == "celery":
        return ["celery_ansible", "celery_default"]
    if "," in s:
        services = []
        for part in s.split(','):
            part = part.strip()
            if not part:
                continue
            for name in parse_service(part):
                if name not in services:
                    services.append(name)
        return services
    return [s]


def get_start_gunicorn_kwargs():
    """Run the preparation steps and build the gunicorn Popen arguments."""
    print("\n- Start Gunicorn WSGI HTTP Server")
    prepare()
    service = 'gunicorn'
    bind = '{}:{}'.format(HTTP_HOST, HTTP_PORT)
    log_format = '%(h)s %(t)s "%(r)s" %(s)s %(b)s '
    pid_file = get_pid_file_path(service)

    cmd = [
        'gunicorn', 'jumpserver.wsgi',
        '-b', bind,
        '-k', 'gthread',
        '--threads', '10',
        '-w', str(WORKERS),
        '--max-requests', '4096',
        '--access-logformat', log_format,
        '-p', pid_file,
        '--access-logfile', '-'
    ]

    if DEBUG:
        cmd.append('--reload')
    return {'cmd': cmd, 'cwd': APPS_DIR}


def get_start_celery_ansible_kwargs():
    """Build the Popen arguments of the ``ansible`` queue worker."""
    print("\n- Start Celery as Distributed Task Queue")
    return get_start_worker_kwargs('ansible', 4)


def get_start_celery_default_kwargs():
    """Build the Popen arguments of the ``celery`` queue worker."""
    return get_start_worker_kwargs('celery', 2)


def get_start_worker_kwargs(queue, num):
    """Build the Popen arguments of a celery worker.

    :param queue: name of the queue the worker consumes
    :param num: worker concurrency
    """
    # TODO: PYTHONOPTIMIZE must be set, otherwise no ansible result is
    # returned.
    os.environ.setdefault('PYTHONOPTIMIZE', '1')

    if os.getuid() == 0:
        os.environ.setdefault('C_FORCE_ROOT', '1')

    cmd = [
        'celery', 'worker',
        '-A', 'ops',
        '-l', 'INFO',
        '-c', str(num),
        '-Q', queue,
    ]
    return {"cmd": cmd, "cwd": APPS_DIR}


def get_start_beat_kwargs():
    """Build the Popen arguments of the celery beat scheduler."""
    print("\n- Start Beat as Periodic Task Scheduler")
    os.environ.setdefault('PYTHONOPTIMIZE', '1')
    if os.getuid() == 0:
        os.environ.setdefault('C_FORCE_ROOT', '1')

    scheduler = "django_celery_beat.schedulers:DatabaseScheduler"
    cmd = [
        'celery', 'beat',
        '-A', 'ops',
        '-l', 'INFO',
        '--scheduler', scheduler,
        '--max-interval', '60'
    ]
    return {"cmd": cmd, 'cwd': APPS_DIR}


# Service name -> subprocess.Popen of every service started by this process.
processes = {}


def _request_exit(signum=None, frame=None):
    """Signal handler: ask the watch loop to stop.

    It only sets ``EXIT_EVENT``.  The actual clean-up runs in the normal
    control flow of ``watch_services``, never inside the handler, because
    the clean-up acquires ``LOCK`` and the handler may interrupt code that
    already holds it.
    """
    EXIT_EVENT.set()


def watch_services():
    """Supervise the started services until an exit is requested.

    Every cycle checks which processes have exited and restarts them.  A
    service that keeps failing more than ``max_retry`` times makes the loop
    stop.  All services are stopped before returning.
    """
    max_retry = 3
    signal.signal(signal.SIGTERM, _request_exit)
    services_retry = defaultdict(int)
    stopped_services = {}

    def check_services():
        # Record which services have exited; forget the ones that are
        # (again) alive together with their retry counter.
        for s, p in processes.items():
            try:
                p.wait(timeout=1)
                stopped_services[s] = ''
            except subprocess.TimeoutExpired:
                stopped_services.pop(s, None)
                services_retry.pop(s, None)
                continue

    def retry_start_stopped_services():
        # Restart every stopped service, giving up on the whole watch loop
        # once one of them exceeded its retry budget.
        for s in stopped_services:
            if services_retry[s] > max_retry:
                print("\nService start failed, exit: ", s)
                EXIT_EVENT.set()
                break

            print("\n> Find {} stopped, retry {}".format(
                s, services_retry[s] + 1)
            )
            p = start_service(s)
            processes[s] = p
            services_retry[s] += 1

    while not EXIT_EVENT.is_set():
        try:
            with LOCK:
                check_services()
            # Do not restart anything once a shutdown has been requested.
            if not EXIT_EVENT.is_set():
                retry_start_stopped_services()
            # Wait 10 seconds in short steps so that a shutdown request is
            # noticed promptly.
            for _ in range(10):
                if EXIT_EVENT.is_set():
                    break
                time.sleep(1)
        except KeyboardInterrupt:
            time.sleep(1)
            break
    clean_up()


def start_service(s):
    """Start service ``s`` and record its pid in the service pid file.

    The child's stdout and stderr are routed through a ``LogPipe``.

    :return: the ``subprocess.Popen`` object of the started service
    """
    services_kwargs = {
        "gunicorn": get_start_gunicorn_kwargs,
        "celery_ansible": get_start_celery_ansible_kwargs,
        "celery_default": get_start_celery_default_kwargs,
        "beat": get_start_beat_kwargs,
    }

    kwargs = services_kwargs.get(s)()
    pid_file = get_pid_file_path(s)

    if os.path.isfile(pid_file):
        os.unlink(pid_file)
    cmd = kwargs.pop('cmd')

    log_file = get_log_file_path(s)
    # Echo the service output to the console only in foreground mode.
    _logger = LogPipe(s, log_file, to_stdout=not DAEMON)
    kwargs.update({"stderr": _logger, "stdout": _logger})
    try:
        p = subprocess.Popen(cmd, **kwargs)
    finally:
        # The child owns its own copy of the write end.  Closing ours lets
        # the reader thread see EOF once the child is gone (or right away
        # when the start failed) instead of leaking the descriptor.
        _logger.close()
    with open(pid_file, 'w') as f:
        f.write(str(p.pid))
    return p


def start_services_and_watch(s):
    """Start the requested services, then supervise them.

    In foreground mode the watch loop runs in this process; in daemon mode
    it runs inside a ``daemon.DaemonContext`` whose pid file is the ``jms``
    pid file.
    """
    global daemon_pid_file

    print(time.ctime())
    print('Jumpserver version {}, more see https://www.jumpserver.org'.format(
        __version__)
    )

    services_set = parse_service(s)
    for i in services_set:
        if is_running(i):
            show_service_status(i)
            continue
        p = start_service(i)
        time.sleep(2)
        processes[i] = p

    if not DAEMON:
        watch_services()
    else:
        show_service_status(s)
        daemon_pid_file = get_pid_file_path('jms')
        context = daemon.DaemonContext(
            pidfile=pidfile.TimeoutPIDLockFile(daemon_pid_file),
            signal_map={
                # Signal handlers are called with (signum, frame), so the
                # zero-argument clean_up() cannot be registered directly.
                signal.SIGTERM: _request_exit,
                signal.SIGHUP: 'terminate',
            },
        )
        with context:
            watch_services()


def stop_service(s, sig=15):
    """Send ``sig`` to every service selected by ``s``.

    When ``s`` is ``all`` the ``jms`` watcher daemon is signalled as well,
    so that it does not restart the services that were just stopped.

    :param s: service argument, as accepted by ``parse_service``
    :param sig: signal number to send (SIGTERM by default)
    """
    for name in parse_service(s):
        if not is_running(name):
            show_service_status(name)
            continue
        print("Stop service: {}".format(name))
        pid = get_pid(name)
        os.kill(pid, sig)
        with LOCK:
            processes.pop(name, None)

    if s == "all":
        pid = get_pid('jms')
        # get_pid() returns 0 when there is no daemon pid file, and
        # os.kill(0, ...) would signal our whole process group.
        if pid > 0 and pid != os.getpid() and check_pid(pid):
            os.kill(pid, sig)


def stop_multi_services(services):
    """Force-stop (signal 9) every service argument in ``services``."""
    for s in services:
        stop_service(s, sig=9)


def stop_service_force(s):
    """Force-stop (signal 9) the services selected by ``s``."""
    stop_service(s, sig=9)


def clean_up():
    """Request the watch loop to exit, then stop and reap every service."""
    if not EXIT_EVENT.is_set():
        EXIT_EVENT.set()
    # Work on a copy: stop_service() removes entries from ``processes``.
    processes_dump = dict(processes)
    for s1, p1 in processes_dump.items():
        stop_service(s1)
        p1.wait()


def show_service_status(s):
    """Print whether each service selected by ``s`` is running."""
    services_set = parse_service(s)
    for ns in services_set:
        if is_running(ns):
            pid = get_pid(ns)
            print("{} is running: {}".format(ns, pid))
        else:
            print("{} is stopped".format(ns))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="""
        Jumpserver service control tools;

        Example: \r\n

        %(prog)s start all -d;
        """
    )
    parser.add_argument(
        'action', type=str,
        choices=("start", "stop", "restart", "status"),
        help="Action to run"
    )
    parser.add_argument(
        "service", type=str, default="all", nargs="?",
        choices=("all", "gunicorn", "celery", "beat", "celery,beat"),
        help="The service to start",
    )
    parser.add_argument('-d', '--daemon', nargs="?", const=1)
    parser.add_argument('-w', '--worker', type=int, nargs="?", const=4)
    args = parser.parse_args()
    if args.daemon:
        DAEMON = True

    if args.worker:
        WORKERS = args.worker

    action = args.action
    srv = args.service

    if action == "start":
        start_services_and_watch(srv)
        os._exit(0)
    elif action == "stop":
        stop_service(srv)
    elif action == "restart":
        # A restart always brings the services back up in daemon mode.
        DAEMON = True
        stop_service(srv)
        time.sleep(5)
        start_services_and_watch(srv)
    else:
        show_service_status(srv)