Merge pull request #5952 from jumpserver/dev

v2.9.0 rc2
pull/5980/head
Jiangjie.Bai 2021-04-13 19:19:43 +08:00 committed by GitHub
commit 4bf2371cf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 193 additions and 43 deletions

View File

@ -307,6 +307,15 @@ class NodeAllAssetsMappingMixin:
org_id = str(org_id)
cls.orgid_nodekey_assetsid_mapping.pop(org_id, None)
@classmethod
def expire_all_orgs_node_all_asset_ids_mapping_from_memory(cls):
orgs = Organization.objects.all()
org_ids = [str(org.id) for org in orgs]
org_ids.append(Organization.ROOT_ID)
for id in org_ids:
cls.expire_node_all_asset_ids_mapping_from_memory(id)
# get order: from memory -> (from cache -> to generate)
@classmethod
def get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(cls, org_id):

View File

@ -13,6 +13,7 @@ from common.signals import django_ready
from common.utils.connection import RedisPubSub
from common.utils import get_logger
from assets.models import Asset, Node
from orgs.models import Organization
logger = get_logger(__file__)
@ -36,13 +37,18 @@ node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub()
def expire_node_assets_mapping_for_memory(org_id):
# 所有进程清除(自己的 memory 数据)
org_id = str(org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id)
root_org_id = Organization.ROOT_ID
# 当前进程清除(cache 数据)
logger.debug(
"Expire node assets id mapping from cache of org={}, pid={}"
"".format(org_id, os.getpid())
)
Node.expire_node_all_asset_ids_mapping_from_cache(org_id)
Node.expire_node_all_asset_ids_mapping_from_cache(root_org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id)
node_assets_mapping_for_memory_pub_sub.publish(root_org_id)
@receiver(post_save, sender=Node)
@ -73,16 +79,22 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs):
logger.debug("Start subscribe for expire node assets id mapping from memory")
def keep_subscribe():
subscribe = node_assets_mapping_for_memory_pub_sub.subscribe()
for message in subscribe.listen():
if message["type"] != "message":
continue
org_id = message['data'].decode()
Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
logger.debug(
"Expire node assets id mapping from memory of org={}, pid={}"
"".format(str(org_id), os.getpid())
)
while True:
try:
subscribe = node_assets_mapping_for_memory_pub_sub.subscribe()
for message in subscribe.listen():
if message["type"] != "message":
continue
org_id = message['data'].decode()
Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
logger.debug(
"Expire node assets id mapping from memory of org={}, pid={}"
"".format(str(org_id), os.getpid())
)
except Exception as e:
logger.exception(f'subscribe_node_assets_mapping_expire: {e}')
Node.expire_all_orgs_node_all_asset_ids_mapping_from_memory()
t = threading.Thread(target=keep_subscribe)
t.daemon = True
t.start()

View File

@ -44,6 +44,7 @@ class AuthBackendLabelMapping(LazyObject):
backend_label_mapping[backend] = source.label
backend_label_mapping[settings.AUTH_BACKEND_PUBKEY] = _('SSH Key')
backend_label_mapping[settings.AUTH_BACKEND_MODEL] = _('Password')
backend_label_mapping[settings.AUTH_BACKEND_SSO] = _('SSO')
return backend_label_mapping
def _setup(self):

View File

@ -16,6 +16,7 @@ reason_user_not_exist = 'user_not_exist'
reason_password_expired = 'password_expired'
reason_user_invalid = 'user_invalid'
reason_user_inactive = 'user_inactive'
reason_user_expired = 'user_expired'
reason_backend_not_match = 'backend_not_match'
reason_acl_not_allow = 'acl_not_allow'
@ -28,6 +29,7 @@ reason_choices = {
reason_password_expired: _("Password expired"),
reason_user_invalid: _('Disabled or expired'),
reason_user_inactive: _("This account is inactive."),
reason_user_expired: _("This account is expired"),
reason_backend_not_match: _("Auth backend not match"),
reason_acl_not_allow: _("ACL is not allowed"),
}

View File

@ -171,7 +171,7 @@ class AuthMixin:
if not user:
self.raise_credential_error(errors.reason_password_failed)
elif user.is_expired:
self.raise_credential_error(errors.reason_user_inactive)
self.raise_credential_error(errors.reason_user_expired)
elif not user.is_active:
self.raise_credential_error(errors.reason_user_inactive)
return user

View File

@ -69,6 +69,7 @@ class UserLoginView(mixins.AuthMixin, FormView):
new_form = form_cls(data=form.data)
new_form._errors = form.errors
context = self.get_context_data(form=new_form)
self.request.session.set_test_cookie()
return self.render_to_response(context)
except (errors.PasswdTooSimple, errors.PasswordRequireResetError) as e:
return redirect(e.url)

View File

@ -1,3 +1,5 @@
import time
from django.core.cache import cache
from django.utils import timezone
from django.utils.timesince import timesince
@ -6,6 +8,8 @@ from django.http.response import JsonResponse, HttpResponse
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny
from collections import Counter
from django.conf import settings
from rest_framework.response import Response
from users.models import User
from assets.models import Asset
@ -307,7 +311,68 @@ class IndexApi(TotalCountMixin, DatesLoginMetricMixin, APIView):
return JsonResponse(data, status=200)
class PrometheusMetricsApi(APIView):
class HealthApiMixin(APIView):
def is_token_right(self):
token = self.request.query_params.get('token')
ok_token = settings.HEALTH_CHECK_TOKEN
if ok_token and token != ok_token:
return False
return True
def check_permissions(self, request):
if not self.is_token_right():
msg = 'Health check token error, ' \
'Please set query param in url and same with setting HEALTH_CHECK_TOKEN. ' \
'eg: $PATH/?token=$HEALTH_CHECK_TOKEN'
self.permission_denied(request, message={'error': msg}, code=403)
class HealthCheckView(HealthApiMixin):
permission_classes = (AllowAny,)
@staticmethod
def get_db_status():
t1 = time.time()
try:
User.objects.first()
t2 = time.time()
return True, t2 - t1
except:
t2 = time.time()
return False, t2 - t1
def get_redis_status(self):
key = 'HEALTH_CHECK'
t1 = time.time()
try:
value = '1'
cache.set(key, '1', 10)
got = cache.get(key)
t2 = time.time()
if value == got:
return True, t2 -t1
return False, t2 -t1
except:
t2 = time.time()
return False, t2 - t1
def get(self, request):
redis_status, redis_time = self.get_redis_status()
db_status, db_time = self.get_db_status()
status = all([redis_status, db_status])
data = {
'status': status,
'db_status': db_status,
'db_time': db_time,
'redis_status': redis_status,
'redis_time': redis_time,
'time': int(time.time())
}
return Response(data)
class PrometheusMetricsApi(HealthApiMixin):
permission_classes = (AllowAny,)
def get(self, request, *args, **kwargs):

View File

@ -289,6 +289,7 @@ class Config(dict):
'SESSION_SAVE_EVERY_REQUEST': True,
'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
'FORGOT_PASSWORD_URL': '',
'HEALTH_CHECK_TOKEN': ''
}
def compatible_auth_openid_of_key(self):

View File

@ -123,3 +123,4 @@ FORGOT_PASSWORD_URL = CONFIG.FORGOT_PASSWORD_URL
# 自定义默认组织名
GLOBAL_ORG_DISPLAY_NAME = CONFIG.GLOBAL_ORG_DISPLAY_NAME
HEALTH_CHECK_TOKEN = CONFIG.HEALTH_CHECK_TOKEN

View File

@ -48,7 +48,8 @@ urlpatterns = [
path('', views.IndexView.as_view(), name='index'),
path('api/v1/', include(api_v1)),
re_path('api/(?P<app>\w+)/(?P<version>v\d)/.*', views.redirect_format_api),
path('api/health/', views.HealthCheckView.as_view(), name="health"),
path('api/health/', api.HealthCheckView.as_view(), name="health"),
path('api/v1/health/', api.HealthCheckView.as_view(), name="health_v1"),
# External apps url
path('core/auth/captcha/', include('captcha.urls')),
path('core/', include(app_view_patterns)),

View File

@ -17,7 +17,7 @@ from common.http import HttpResponseTemporaryRedirect
__all__ = [
'LunaView', 'I18NView', 'KokoView', 'WsView', 'HealthCheckView',
'LunaView', 'I18NView', 'KokoView', 'WsView',
'redirect_format_api', 'redirect_old_apps_view', 'UIView'
]
@ -64,13 +64,6 @@ def redirect_old_apps_view(request, *args, **kwargs):
return HttpResponseTemporaryRedirect(new_path)
class HealthCheckView(APIView):
permission_classes = (AllowAny,)
def get(self, request):
return JsonResponse({"status": 1, "time": int(time.time())})
class WsView(APIView):
ws_port = settings.HTTP_LISTEN_PORT + 1

View File

@ -46,12 +46,19 @@ def subscribe_orgs_mapping_expire(sender, **kwargs):
logger.debug("Start subscribe for expire orgs mapping from memory")
def keep_subscribe():
subscribe = orgs_mapping_for_memory_pub_sub.subscribe()
for message in subscribe.listen():
if message['type'] != 'message':
continue
Organization.expire_orgs_mapping()
logger.debug('Expire orgs mapping')
while True:
try:
subscribe = orgs_mapping_for_memory_pub_sub.subscribe()
for message in subscribe.listen():
if message['type'] != 'message':
continue
if message['data'] == b'error':
raise ValueError
Organization.expire_orgs_mapping()
logger.debug('Expire orgs mapping')
except Exception as e:
logger.exception(f'subscribe_orgs_mapping_expire: {e}')
Organization.expire_orgs_mapping()
t = threading.Thread(target=keep_subscribe)
t.daemon = True

View File

@ -6,6 +6,7 @@ import threading
from django.dispatch import receiver
from django.db.models.signals import post_save, pre_save
from django.utils.functional import LazyObject
from django.db import close_old_connections
from jumpserver.utils import current_request
from common.decorator import on_transaction_commit
@ -71,13 +72,21 @@ def subscribe_settings_change(sender, **kwargs):
logger.debug("Start subscribe setting change")
def keep_subscribe():
sub = setting_pub_sub.subscribe()
for msg in sub.listen():
if msg["type"] != "message":
continue
item = msg['data'].decode()
logger.debug("Found setting change: {}".format(str(item)))
Setting.refresh_item(item)
while True:
try:
sub = setting_pub_sub.subscribe()
for msg in sub.listen():
close_old_connections()
if msg["type"] != "message":
continue
item = msg['data'].decode()
logger.debug("Found setting change: {}".format(str(item)))
Setting.refresh_item(item)
except Exception as e:
logger.exception(f'subscribe_settings_change: {e}')
close_old_connections()
Setting.refresh_all_settings()
t = threading.Thread(target=keep_subscribe)
t.daemon = True
t.start()

View File

@ -11,6 +11,7 @@ from rest_framework.response import Response
from rest_framework.decorators import action
from django.template import loader
from common.http import is_true
from terminal.models import CommandStorage, Command
from terminal.filters import CommandFilter
from orgs.utils import current_org
@ -140,7 +141,21 @@ class CommandViewSet(viewsets.ModelViewSet):
if session_id and not command_storage_id:
# 会话里的命令列表肯定会提供 session_id这里防止 merge 的时候取全量的数据
return self.merge_all_storage_list(request, *args, **kwargs)
return super().list(request, *args, **kwargs)
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
query_all = self.request.query_params.get('all', False)
if is_true(query_all):
# 适配像 ES 这种没有指定分页只返回少量数据的情况
queryset = queryset[:]
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
def get_queryset(self):
command_storage_id = self.request.query_params.get('command_storage_id')

View File

@ -10,6 +10,7 @@ import inspect
from django.db.models import QuerySet as DJQuerySet
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from elasticsearch.exceptions import RequestError
from common.utils.common import lazyproperty
from common.utils import get_logger
@ -31,6 +32,15 @@ class CommandStore():
kwargs['verify_certs'] = None
self.es = Elasticsearch(hosts=hosts, max_retries=0, **kwargs)
def pre_use_check(self):
self._ensure_index_exists()
def _ensure_index_exists(self):
try:
self.es.indices.create(self.index)
except RequestError:
pass
@staticmethod
def make_data(command):
data = dict(
@ -234,6 +244,7 @@ class QuerySet(DJQuerySet):
uqs = QuerySet(self._command_store_config)
uqs._method_calls = self._method_calls.copy()
uqs._slice = self._slice
uqs.model = self.model
return uqs
def count(self, limit_to_max_result_window=True):

View File

@ -76,6 +76,15 @@ class CommandStorage(CommonModelMixin):
qs.model = Command
return qs
def save(self, force_insert=False, force_update=False, using=None,
update_fields=None):
super().save()
if self.type in TYPE_ENGINE_MAPPING:
engine_mod = import_module(TYPE_ENGINE_MAPPING[self.type])
backend = engine_mod.CommandStore(self.config)
backend.pre_use_check()
class ReplayStorage(CommonModelMixin):
name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)

View File

@ -7,6 +7,7 @@ from django.conf import settings
from common.utils import get_logger
from users.models import User
from orgs.utils import tmp_to_root_org
from .status import Status
from .. import const
from ..const import ComponentStatusChoices as StatusChoice
@ -112,7 +113,6 @@ class Terminal(StorageMixin, TerminalStatusMixin, models.Model):
date_created = models.DateTimeField(auto_now_add=True)
comment = models.TextField(blank=True, verbose_name=_('Comment'))
@property
def is_active(self):
if self.user and self.user.is_active:
@ -126,7 +126,8 @@ class Terminal(StorageMixin, TerminalStatusMixin, models.Model):
self.user.save()
def get_online_sessions(self):
return Session.objects.filter(terminal=self, is_finished=False)
with tmp_to_root_org():
return Session.objects.filter(terminal=self, is_finished=False)
def get_online_session_count(self):
return self.get_online_sessions().count()

View File

@ -1,8 +1,8 @@
# ~*~ coding: utf-8 ~*~
from django.core.cache import cache
from collections import defaultdict
from django.utils.translation import ugettext as _
from rest_framework.decorators import action
from django.conf import settings
from rest_framework import generics
from rest_framework.response import Response
from rest_framework_bulk import BulkModelViewSet
@ -155,10 +155,17 @@ class UserViewSet(CommonApiMixin, UserQuerysetMixin, BulkModelViewSet):
serializer = serializer_cls(data=data, many=True)
serializer.is_valid(raise_exception=True)
validated_data = serializer.validated_data
users_by_role = defaultdict(list)
for i in validated_data:
i['org_id'] = current_org.org_id()
relations = [OrganizationMember(**i) for i in validated_data]
OrganizationMember.objects.bulk_create(relations, ignore_conflicts=True)
users_by_role[i['role']].append(i['user'])
OrganizationMember.objects.add_users_by_role(
current_org,
users=users_by_role[ORG_ROLE.USER],
admins=users_by_role[ORG_ROLE.ADMIN],
auditors=users_by_role[ORG_ROLE.AUDITOR]
)
return Response(serializer.data, status=201)
@action(methods=['post'], detail=True, permission_classes=(IsOrgAdmin,))

View File

@ -667,6 +667,11 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
else:
return user_default
def unblock_login(self):
from users.utils import LoginBlockUtil, MFABlockUtils
LoginBlockUtil.unblock_user(self.username)
MFABlockUtils.unblock_user(self.username)
@property
def login_blocked(self):
from users.utils import LoginBlockUtil, MFABlockUtils