# -*- coding: utf-8 -*-
#
"""Websocket consumers for network diagnostic tools and LDAP operations.

``ToolsWebsocket`` runs a network tool (ping, telnet, nmap, tcpdump,
traceroute) and streams its output line by line to the client.
``LdapWebsocket`` runs an LDAP test / sync / import job and replies with a
single ``{'ok': ..., 'msg': ...}`` JSON message.
"""
import json
import asyncio

from channels.generic.websocket import AsyncJsonWebsocketConsumer
from django.core.cache import cache
from django.conf import settings
from django.utils.translation import gettext_lazy as _

from common.db.utils import close_old_connections
from common.utils import get_logger
from settings.serializers import (
    LDAPTestConfigSerializer, LDAPTestLoginSerializer
)
from orgs.models import Organization
from orgs.utils import current_org
from settings.tasks import sync_ldap_user
from settings.utils import (
    LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil,
    LDAP_USE_CACHE_FLAGS, LDAPTestUtil
)
from .tools import (
    verbose_ping, verbose_telnet, verbose_nmap,
    verbose_tcpdump, verbose_traceroute
)

logger = get_logger(__name__)

CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS = 'CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS'
TASK_STATUS_IS_OVER = 'OVER'


class ToolsWebsocket(AsyncJsonWebsocketConsumer):
    """Run one network diagnostic tool and stream its output.

    The client sends a JSON object: ``tool_type`` (default ``'Ping'``)
    selects the ``imitate_<tool_type>`` method and every other key is passed
    to it as a keyword argument. Output is pushed through ``send_msg``; the
    socket is closed once the tool has finished (or failed).
    """

    async def connect(self):
        # Only authenticated users may open the socket.
        user = self.scope["user"]
        if user.is_authenticated:
            await self.accept()
        else:
            await self.close()

    async def send_msg(self, msg=''):
        """Send one line of output (``msg`` followed by CRLF) to the client."""
        await self.send_json({'msg': f'{msg}\r\n'})

    async def imitate_ping(self, dest_ips, timeout=3, count=5, psize=64):
        params = {
            'dest_ips': dest_ips, 'timeout': timeout,
            'count': count, 'psize': psize
        }
        logger.info(f'Receive request ping: {params}')
        await verbose_ping(display=self.send_msg, **params)

    async def imitate_telnet(self, dest_ips, dest_port=23, timeout=10):
        params = {
            'dest_ips': dest_ips, 'dest_port': dest_port, 'timeout': timeout,
        }
        logger.info(f'Receive request telnet: {params}')
        await verbose_telnet(display=self.send_msg, **params)

    async def imitate_nmap(self, dest_ips, dest_ports=None, timeout=None):
        params = {
            'dest_ips': dest_ips, 'dest_ports': dest_ports, 'timeout': timeout,
        }
        logger.info(f'Receive request nmap: {params}')
        await verbose_nmap(display=self.send_msg, **params)

    async def imitate_tcpdump(
            self, interfaces=None, src_ips='',
            src_ports='', dest_ips='', dest_ports=''
    ):
        params = {
            'interfaces': interfaces, 'src_ips': src_ips,
            'src_ports': src_ports, 'dest_ips': dest_ips,
            'dest_ports': dest_ports
        }
        logger.info(f'Receive request tcpdump: {params}')
        await verbose_tcpdump(display=self.send_msg, **params)

    async def imitate_traceroute(self, dest_ips):
        params = {'dest_ips': dest_ips}
        # Log like the other imitate_* handlers do.
        logger.info(f'Receive request traceroute: {params}')
        await verbose_traceroute(display=self.send_msg, **params)

    async def receive(self, text_data=None, bytes_data=None, **kwargs):
        """Dispatch the request to ``imitate_<tool_type>`` and close the socket.

        Any failure -- malformed JSON, unknown tool, bad arguments or an error
        raised by the tool -- is reported to the client as
        ``Exception: ...`` instead of escaping the consumer.
        """
        try:
            # Parsing is inside the try so malformed input is reported too.
            data = json.loads(text_data)
            tool_type = data.pop('tool_type', 'Ping')
            # The fixed 'imitate_' prefix restricts dispatch to tool handlers.
            tool_func = getattr(self, f'imitate_{tool_type.lower()}')
            await tool_func(**data)
        except Exception as error:
            await self.send_msg('Exception: %s' % error)
        await self.send_msg()
        await self.close()

    async def disconnect(self, code):
        await self.close()
        close_old_connections()


class LdapWebsocket(AsyncJsonWebsocketConsumer):
    """Run one LDAP job and reply with a single ``{'ok', 'msg'}`` message.

    The client sends a JSON object: ``msg_type`` (default
    ``'testing_config'``) selects the ``run_<msg_type>`` method, which
    receives the remaining keys as a dict and returns ``(ok, msg)``.
    """

    async def connect(self):
        # Only authenticated users may open the socket.
        user = self.scope["user"]
        if user.is_authenticated:
            await self.accept()
        else:
            await self.close()

    async def receive(self, text_data=None, bytes_data=None, **kwargs):
        """Dispatch to ``run_<msg_type>`` and send its ``(ok, msg)`` result.

        Errors (malformed JSON, unknown ``msg_type``, exceptions raised by the
        handler) are reported with ``ok=False``.
        """
        try:
            data = json.loads(text_data)
            msg_type = data.pop('msg_type', 'testing_config')
            tool_func = getattr(self, f'run_{msg_type.lower()}')
            # run_* handlers are synchronous; run them in a worker thread so
            # they do not block the event loop. Because of this they must not
            # call the async send_msg themselves -- they return (ok, msg).
            ok, msg = await asyncio.to_thread(tool_func, data)
            await self.send_msg(ok, msg)
        except Exception as error:
            # Previously sent with the default ok=True, reporting failure
            # as success.
            await self.send_msg(ok=False, msg='Exception: %s' % error)

    async def send_msg(self, ok=True, msg=''):
        """Send ``{'ok': ok, 'msg': str(msg)}`` to the client."""
        await self.send_json({'ok': ok, 'msg': f'{msg}'})

    async def disconnect(self, code):
        await self.close()
        close_old_connections()

    @staticmethod
    def get_ldap_config(serializer):
        """Build the LDAPTestUtil config dict from a validated serializer.

        An empty bind password falls back to ``settings.AUTH_LDAP_BIND_PASSWORD``.
        """
        server_uri = serializer.validated_data["AUTH_LDAP_SERVER_URI"]
        bind_dn = serializer.validated_data["AUTH_LDAP_BIND_DN"]
        password = serializer.validated_data["AUTH_LDAP_BIND_PASSWORD"]
        use_ssl = serializer.validated_data.get("AUTH_LDAP_START_TLS", False)
        search_ou = serializer.validated_data["AUTH_LDAP_SEARCH_OU"]
        search_filter = serializer.validated_data["AUTH_LDAP_SEARCH_FILTER"]
        attr_map = serializer.validated_data["AUTH_LDAP_USER_ATTR_MAP"]
        auth_ldap = serializer.validated_data.get('AUTH_LDAP', False)

        if not password:
            password = settings.AUTH_LDAP_BIND_PASSWORD

        config = {
            'server_uri': server_uri,
            'bind_dn': bind_dn,
            'password': password,
            'use_ssl': use_ssl,
            'search_ou': search_ou,
            'search_filter': search_filter,
            'attr_map': attr_map,
            'auth_ldap': auth_ldap
        }
        return config

    @staticmethod
    def task_is_over(task_key):
        """Return True if ``task_key`` has been marked finished in the cache."""
        return cache.get(task_key) == TASK_STATUS_IS_OVER

    @staticmethod
    def set_task_status_over(task_key, ttl=120):
        """Mark ``task_key`` finished in the cache for ``ttl`` seconds."""
        cache.set(task_key, TASK_STATUS_IS_OVER, ttl)

    def run_testing_config(self, data):
        """Test an LDAP connection configuration; return ``(ok, msg)``.

        On success the config-test task is marked finished in the cache.
        """
        serializer = LDAPTestConfigSerializer(data=data)
        if not serializer.is_valid():
            # Report through the return value: this runs in a worker thread,
            # where the async send_msg cannot be awaited, and continuing would
            # fail on the empty validated_data.
            return False, f'error: {str(serializer.errors)}'
        config = self.get_ldap_config(serializer)
        ok, msg = LDAPTestUtil(config).test_config()
        if ok:
            self.set_task_status_over(CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS)
        return ok, msg

    def run_testing_login(self, data):
        """Test an LDAP login with the given username/password; return ``(ok, msg)``."""
        serializer = LDAPTestLoginSerializer(data=data)
        if not serializer.is_valid():
            # Same reasoning as in run_testing_config.
            return False, f'error: {str(serializer.errors)}'
        username = serializer.validated_data['username']
        password = serializer.validated_data['password']
        ok, msg = LDAPTestUtil().test_login(username, password)
        return ok, msg

    @staticmethod
    def run_sync_user(data):
        """Clear the sync cache, run ``sync_ldap_user`` and report its error message.

        Returns ``(ok, msg)`` where ``ok`` is True when no error message was recorded.
        """
        sync_util = LDAPSyncUtil()
        sync_util.clear_cache()
        sync_ldap_user()
        msg = sync_util.get_task_error_msg()
        ok = False if msg else True
        return ok, msg

    def run_import_user(self, data):
        """Import LDAP users into organizations; return ``(ok, msg)``.

        Reads ``org_ids`` (empty -> current org), ``username_list``
        (contains ``'*'`` -> all users) and ``cache_police`` from ``data``.
        """
        org_ids = data.get('org_ids')
        username_list = data.get('username_list', [])
        cache_police = data.get('cache_police', True)
        try:
            users = self.get_ldap_users(username_list, cache_police)
            if users is None:
                # Previously this message was overwritten and the import
                # continued with users=None.
                return False, _('Get ldap users is None')

            orgs = self.get_orgs(org_ids)
            # Note: do not unpack into `_`, which would shadow gettext here.
            _new_users, error_msg = LDAPImportUtil().perform_import(users, orgs)
            if error_msg:
                # Previously overwritten by the success message below.
                return False, error_msg

            orgs_name = ', '.join(str(org) for org in orgs)
            msg = _('Imported {} users successfully (Organization: {})').format(
                len(users), orgs_name
            )
            return True, msg
        except Exception as e:
            return False, str(e)

    @staticmethod
    def get_orgs(org_ids):
        """Return the Organization objects for ``org_ids``, or ``[current_org]`` if empty."""
        if org_ids:
            orgs = list(Organization.objects.filter(id__in=org_ids))
        else:
            orgs = [current_org]
        return orgs

    @staticmethod
    def get_ldap_users(username_list, cache_police):
        """Look up LDAP users.

        ``'*'`` in ``username_list`` searches the server for all users;
        otherwise the cache is used when ``cache_police`` is in
        ``LDAP_USE_CACHE_FLAGS``, else the server is queried for the listed users.
        """
        if '*' in username_list:
            users = LDAPServerUtil().search()
        elif cache_police in LDAP_USE_CACHE_FLAGS:
            users = LDAPCacheUtil().search(search_users=username_list)
        else:
            users = LDAPServerUtil().search(search_users=username_list)
        return users