diff --git a/apps/assets/api/asset/asset.py b/apps/assets/api/asset/asset.py index ad04966e3..f6cc509b3 100644 --- a/apps/assets/api/asset/asset.py +++ b/apps/assets/api/asset/asset.py @@ -6,13 +6,11 @@ from rest_framework.decorators import action from rest_framework.response import Response from assets import serializers +from assets.models import Asset from assets.filters import IpInFilterBackend, LabelFilterBackend, NodeFilterBackend -from assets.models import Asset, Gateway from assets.tasks import ( - push_accounts_to_assets, - test_assets_connectivity_manual, - update_assets_hardware_info_manual, - verify_accounts_connectivity, + push_accounts_to_assets, test_assets_connectivity_manual, + update_assets_hardware_info_manual, verify_accounts_connectivity, ) from common.drf.filters import BaseFilterSet from common.mixins.api import SuggestionMixin @@ -74,7 +72,7 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet): def gateways(self, *args, **kwargs): asset = self.get_object() if not asset.domain: - gateways = Gateway.objects.none() + gateways = Asset.objects.none() else: gateways = asset.domain.gateways.filter(protocol="ssh") return self.get_paginated_response_from_queryset(gateways) diff --git a/apps/assets/api/asset/host.py b/apps/assets/api/asset/host.py index 3094e15a8..fbc2e997c 100644 --- a/apps/assets/api/asset/host.py +++ b/apps/assets/api/asset/host.py @@ -1,4 +1,3 @@ - from assets.models import Host from assets.serializers import HostSerializer from .asset import AssetViewSet diff --git a/apps/assets/api/domain.py b/apps/assets/api/domain.py index 39f2b44f9..bb705322c 100644 --- a/apps/assets/api/domain.py +++ b/apps/assets/api/domain.py @@ -7,21 +7,20 @@ from rest_framework.serializers import ValidationError from common.utils import get_logger from orgs.mixins.api import OrgBulkModelViewSet -from ..models import Domain, Gateway +from ..models import Domain, Host from .. import serializers - logger = get_logger(__file__) __all__ = ['DomainViewSet', 'GatewayViewSet', "GatewayTestConnectionApi"] class DomainViewSet(OrgBulkModelViewSet): model = Domain - filterset_fields = ("name", ) + filterset_fields = ("name",) search_fields = filterset_fields serializer_class = serializers.DomainSerializer ordering_fields = ('name',) - ordering = ('name', ) + ordering = ('name',) def get_serializer_class(self): if self.request.query_params.get('gateway'): @@ -30,21 +29,26 @@ class DomainViewSet(OrgBulkModelViewSet): class GatewayViewSet(OrgBulkModelViewSet): - model = Gateway - filterset_fields = ("domain__name", "name", "username", "domain") - search_fields = ("domain__name", "name", "username", ) + filterset_fields = ("domain__name", "name", "domain") + search_fields = ("domain__name",) serializer_class = serializers.GatewaySerializer + def get_queryset(self): + queryset = Host.get_gateway_queryset() + return queryset + class GatewayTestConnectionApi(SingleObjectMixin, APIView): - queryset = Gateway.objects.all() - object = None rbac_perms = { 'POST': 'assets.test_gateway' } + def get_queryset(self): + queryset = Host.get_gateway_queryset() + return queryset + def post(self, request, *args, **kwargs): - self.object = self.get_object(Gateway.objects.all()) + self.object = self.get_object() local_port = self.request.data.get('port') or self.object.port try: local_port = int(local_port) diff --git a/apps/assets/const/__init__.py b/apps/assets/const/__init__.py index 81115b412..bc30f388d 100644 --- a/apps/assets/const/__init__.py +++ b/apps/assets/const/__init__.py @@ -1,3 +1,5 @@ +from .base import * +from .host import * from .types import * from .account import * from .protocol import * diff --git a/apps/assets/const/host.py b/apps/assets/const/host.py index 371ab3688..8be44db6f 100644 --- a/apps/assets/const/host.py +++ b/apps/assets/const/host.py @@ -1,5 +1,7 @@ from .base import BaseType +GATEWAY_NAME = 'Gateway' + class HostTypes(BaseType): LINUX = 'linux', 'Linux' @@ -67,7 +69,7 @@ class HostTypes(BaseType): return { cls.LINUX: [ {'name': 'Linux'}, - {'name': 'Gateway'} + {'name': GATEWAY_NAME} ], cls.UNIX: [ {'name': 'Unix'}, diff --git a/apps/assets/migrations/0112_gateway_to_asset.py b/apps/assets/migrations/0112_gateway_to_asset.py new file mode 100644 index 000000000..da43b3a84 --- /dev/null +++ b/apps/assets/migrations/0112_gateway_to_asset.py @@ -0,0 +1,73 @@ +# Generated by Django 3.2.13 on 2022-09-29 11:03 + +from django.db import migrations +from assets.const.host import GATEWAY_NAME + + +def _create_account_obj(secret, secret_type, gateway, asset, account_model): + return account_model( + asset=asset, + secret=secret, + org_id=gateway.org_id, + secret_type=secret_type, + username=gateway.username, + name=f'{gateway.name}-{secret_type}-{GATEWAY_NAME.lower()}', + ) + + +def migrate_gateway_to_asset(apps, schema_editor): + db_alias = schema_editor.connection.alias + gateway_model = apps.get_model('assets', 'Gateway') + platform_model = apps.get_model('assets', 'Platform') + gateway_platform = platform_model.objects.using(db_alias).get(name=GATEWAY_NAME) + + print('>>> migrate gateway to asset') + asset_dict = {} + host_model = apps.get_model('assets', 'Host') + asset_model = apps.get_model('assets', 'Asset') + protocol_model = apps.get_model('assets', 'Protocol') + gateways = gateway_model.objects.all() + for gateway in gateways: + comment = gateway.comment if gateway.comment else '' + data = { + 'comment': comment, + 'name': f'{gateway.name}-{GATEWAY_NAME.lower()}', + 'address': gateway.ip, + 'domain': gateway.domain, + 'org_id': gateway.org_id, + 'is_active': gateway.is_active, + 'platform': gateway_platform, + } + asset = asset_model.objects.using(db_alias).create(**data) + asset_dict[gateway.id] = asset + protocol_model.objects.using(db_alias).create(name='ssh', port=gateway.port, asset=asset) + hosts = [host_model(asset_ptr=asset) for asset in asset_dict.values()] + host_model.objects.using(db_alias).bulk_create(hosts, ignore_conflicts=True) + + print('>>> migrate gateway to account') + accounts = [] + account_model = apps.get_model('assets', 'Account') + for gateway in gateways: + password = gateway.password + private_key = gateway.private_key + asset = asset_dict[gateway.id] + if password: + accounts.append(_create_account_obj( + password, 'password', gateway, asset, account_model + )) + + if private_key: + accounts.append(_create_account_obj( + private_key, 'ssh_key', gateway, asset, account_model + )) + account_model.objects.using(db_alias).bulk_create(accounts) + + +class Migration(migrations.Migration): + dependencies = [ + ('assets', '0111_alter_automationexecution_status'), + ] + + operations = [ + migrations.RunPython(migrate_gateway_to_asset), + ] diff --git a/apps/assets/models/asset/common.py b/apps/assets/models/asset/common.py index 8ea75bc2a..c9baf8818 100644 --- a/apps/assets/models/asset/common.py +++ b/apps/assets/models/asset/common.py @@ -2,8 +2,8 @@ # -*- coding: utf-8 -*- # -import logging import uuid +import logging from collections import defaultdict from django.db import models diff --git a/apps/assets/models/asset/host.py b/apps/assets/models/asset/host.py index 4ce4be5c9..46aeed4f3 100644 --- a/apps/assets/models/asset/host.py +++ b/apps/assets/models/asset/host.py @@ -1,6 +1,13 @@ -from assets.const import Category +from assets.const import GATEWAY_NAME from .common import Asset class Host(Asset): pass + + @classmethod + def get_gateway_queryset(cls): + queryset = cls.objects.filter( + platform__name=GATEWAY_NAME + ) + return queryset diff --git a/apps/assets/models/base.py b/apps/assets/models/base.py index 7920d3798..90fb384e6 100644 --- a/apps/assets/models/base.py +++ b/apps/assets/models/base.py @@ -6,10 +6,10 @@ import sshpubkeys from hashlib import md5 from django.db import models -from django.utils import timezone -from django.utils.translation import ugettext_lazy as _ from django.conf import settings +from django.utils import timezone from django.db.models import QuerySet +from django.utils.translation import ugettext_lazy as _ from common.utils import ( ssh_key_string_to_obj, ssh_key_gen, get_logger, diff --git a/apps/assets/models/domain.py b/apps/assets/models/domain.py index 4abe8aa68..bf33caed6 100644 --- a/apps/assets/models/domain.py +++ b/apps/assets/models/domain.py @@ -1,22 +1,24 @@ # -*- coding: utf-8 -*- # -import socket import uuid +import socket import random - -from django.core.cache import cache import paramiko + from django.db import models +from django.core.cache import cache +from django.db.models.query import QuerySet from django.utils.translation import ugettext_lazy as _ -from common.utils import get_logger, lazyproperty from common.db import fields +from common.utils import get_logger, lazyproperty from orgs.mixins.models import OrgModelMixin from .base import BaseAccount +from ..const import SecretType, GATEWAY_NAME logger = get_logger(__file__) -__all__ = ['Domain', 'Gateway'] +__all__ = ['Domain', 'GatewayMixin'] class Domain(OrgModelMixin): @@ -33,12 +35,9 @@ class Domain(OrgModelMixin): def __str__(self): return self.name - def has_gateway(self): - return self.gateway_set.filter(is_active=True).exists() - @lazyproperty def gateways(self): - return self.gateway_set.filter(is_active=True) + return self.assets.filter(platform__name=GATEWAY_NAME, is_active=True) def select_gateway(self): return self.random_gateway() @@ -53,18 +52,141 @@ class Domain(OrgModelMixin): return random.choice(self.gateways) -class Gateway(BaseAccount): - UNCONNECTIVE_KEY_TMPL = 'asset_unconnective_gateway_{}' - UNCONNECTIVE_SILENCE_PERIOD_KEY_TMPL = 'asset_unconnective_gateway_silence_period_{}' - UNCONNECTIVE_SILENCE_PERIOD_BEGIN_VALUE = 60 * 5 +class GatewayMixin: + id: uuid.UUID + port: int + address: str + accounts: QuerySet + private_key_path: str + private_key_obj: paramiko.RSAKey + UNCONNECTED_KEY_TMPL = 'asset_unconnective_gateway_{}' + UNCONNECTED_SILENCE_PERIOD_KEY_TMPL = 'asset_unconnective_gateway_silence_period_{}' + UNCONNECTED_SILENCE_PERIOD_BEGIN_VALUE = 60 * 5 + def set_unconnected(self): + unconnected_key = self.UNCONNECTED_KEY_TMPL.format(self.id) + unconnected_silence_period_key = self.UNCONNECTED_SILENCE_PERIOD_KEY_TMPL.format(self.id) + unconnected_silence_period = cache.get( + unconnected_silence_period_key, self.UNCONNECTED_SILENCE_PERIOD_BEGIN_VALUE + ) + cache.set(unconnected_silence_period_key, unconnected_silence_period * 2) + cache.set(unconnected_key, unconnected_silence_period, unconnected_silence_period) + + def set_connective(self): + unconnected_key = self.UNCONNECTED_KEY_TMPL.format(self.id) + unconnected_silence_period_key = self.UNCONNECTED_SILENCE_PERIOD_KEY_TMPL.format(self.id) + + cache.delete(unconnected_key) + cache.delete(unconnected_silence_period_key) + + def get_is_unconnected(self): + unconnected_key = self.UNCONNECTED_KEY_TMPL.format(self.id) + return cache.get(unconnected_key, False) + + @property + def is_connective(self): + return not self.get_is_unconnected() + + @is_connective.setter + def is_connective(self, value): + if value: + self.set_connective() + else: + self.set_unconnected() + + def test_connective(self, local_port=None): + # TODO 走ansible runner + if local_port is None: + local_port = self.port + + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + proxy = paramiko.SSHClient() + proxy.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + try: + proxy.connect(self.address, port=self.port, + username=self.username, + password=self.password, + pkey=self.private_key_obj) + except(paramiko.AuthenticationException, + paramiko.BadAuthenticationType, + paramiko.SSHException, + paramiko.ChannelException, + paramiko.ssh_exception.NoValidConnectionsError, + socket.gaierror) as e: + err = str(e) + if err.startswith('[Errno None] Unable to connect to port'): + err = _('Unable to connect to port {port} on {address}') + err = err.format(port=self.port, ip=self.address) + elif err == 'Authentication failed.': + err = _('Authentication failed') + elif err == 'Connect failed': + err = _('Connect failed') + self.is_connective = False + return False, err + + try: + sock = proxy.get_transport().open_channel( + 'direct-tcpip', ('127.0.0.1', local_port), ('127.0.0.1', 0) + ) + client.connect("127.0.0.1", port=local_port, + username=self.username, + password=self.password, + key_filename=self.private_key_path, + sock=sock, + timeout=5) + except (paramiko.SSHException, + paramiko.ssh_exception.SSHException, + paramiko.ChannelException, + paramiko.AuthenticationException, + TimeoutError) as e: + + err = getattr(e, 'text', str(e)) + if err == 'Connect failed': + err = _('Connect failed') + self.is_connective = False + return False, err + finally: + client.close() + self.is_connective = True + return True, None + + @lazyproperty + def username(self): + account = self.accounts.all().first() + if account: + return account.username + logger.error(f'Gateway {self} has no account') + return '' + + def get_secret(self, secret_type): + account = self.accounts.filter(secret_type=secret_type).first() + if account: + return account.secret + logger.error(f'Gateway {self} has no {secret_type} account') + + @lazyproperty + def password(self): + secret_type = SecretType.PASSWORD + return self.get_secret(secret_type) + + @lazyproperty + def private_key(self): + secret_type = SecretType.SSH_KEY + return self.get_secret(secret_type) + + +class Gateway(BaseAccount): class Protocol(models.TextChoices): ssh = 'ssh', 'SSH' name = models.CharField(max_length=128, verbose_name='Name') ip = models.CharField(max_length=128, verbose_name=_('IP'), db_index=True) port = models.IntegerField(default=22, verbose_name=_('Port')) - protocol = models.CharField(choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol")) + protocol = models.CharField( + choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol") + ) domain = models.ForeignKey(Domain, on_delete=models.CASCADE, verbose_name=_("Domain")) comment = models.CharField(max_length=128, blank=True, null=True, verbose_name=_("Comment")) is_active = models.BooleanField(default=True, verbose_name=_("Is active")) @@ -85,91 +207,3 @@ class Gateway(BaseAccount): permissions = [ ('test_gateway', _('Test gateway')) ] - - def set_unconnective(self): - unconnective_key = self.UNCONNECTIVE_KEY_TMPL.format(self.id) - unconnective_silence_period_key = self.UNCONNECTIVE_SILENCE_PERIOD_KEY_TMPL.format(self.id) - - unconnective_silence_period = cache.get(unconnective_silence_period_key, - self.UNCONNECTIVE_SILENCE_PERIOD_BEGIN_VALUE) - cache.set(unconnective_silence_period_key, unconnective_silence_period * 2) - cache.set(unconnective_key, unconnective_silence_period, unconnective_silence_period) - - def set_connective(self): - unconnective_key = self.UNCONNECTIVE_KEY_TMPL.format(self.id) - unconnective_silence_period_key = self.UNCONNECTIVE_SILENCE_PERIOD_KEY_TMPL.format(self.id) - - cache.delete(unconnective_key) - cache.delete(unconnective_silence_period_key) - - def get_is_unconnective(self): - unconnective_key = self.UNCONNECTIVE_KEY_TMPL.format(self.id) - return cache.get(unconnective_key, False) - - @property - def is_connective(self): - return not self.get_is_unconnective() - - @is_connective.setter - def is_connective(self, value): - if value: - self.set_connective() - else: - self.set_unconnective() - - def test_connective(self, local_port=None): - if local_port is None: - local_port = self.port - - client = paramiko.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - proxy = paramiko.SSHClient() - proxy.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - try: - proxy.connect(self.ip, port=self.port, - username=self.username, - password=self.password, - pkey=self.private_key_obj) - except(paramiko.AuthenticationException, - paramiko.BadAuthenticationType, - paramiko.SSHException, - paramiko.ChannelException, - paramiko.ssh_exception.NoValidConnectionsError, - socket.gaierror) as e: - err = str(e) - if err.startswith('[Errno None] Unable to connect to port'): - err = _('Unable to connect to port {port} on {address}') - err = err.format(port=self.port, ip=self.ip) - elif err == 'Authentication failed.': - err = _('Authentication failed') - elif err == 'Connect failed': - err = _('Connect failed') - self.is_connective = False - return False, err - - try: - sock = proxy.get_transport().open_channel( - 'direct-tcpip', ('127.0.0.1', local_port), ('127.0.0.1', 0) - ) - client.connect("127.0.0.1", port=local_port, - username=self.username, - password=self.password, - key_filename=self.private_key_file, - sock=sock, - timeout=5) - except (paramiko.SSHException, - paramiko.ssh_exception.SSHException, - paramiko.ChannelException, - paramiko.AuthenticationException, - TimeoutError) as e: - - err = getattr(e, 'text', str(e)) - if err == 'Connect failed': - err = _('Connect failed') - self.is_connective = False - return False, err - finally: - client.close() - self.is_connective = True - return True, None diff --git a/apps/assets/serializers/domain.py b/apps/assets/serializers/domain.py index 82fd9433f..37d17c814 100644 --- a/apps/assets/serializers/domain.py +++ b/apps/assets/serializers/domain.py @@ -3,11 +3,9 @@ from rest_framework import serializers from django.utils.translation import ugettext_lazy as _ -from common.validators import alphanumeric from orgs.mixins.serializers import BulkOrgResourceModelSerializer from common.drf.serializers import SecretReadableMixin -from ..models import Domain, Gateway -from .base import AuthValidateMixin +from ..models import Domain, Asset class DomainSerializer(BulkOrgResourceModelSerializer): @@ -35,32 +33,23 @@ class DomainSerializer(BulkOrgResourceModelSerializer): @staticmethod def get_gateway_count(obj): - return obj.gateway_set.all().count() + return obj.gateways.count() -class GatewaySerializer(AuthValidateMixin, BulkOrgResourceModelSerializer): +class GatewaySerializer(BulkOrgResourceModelSerializer): is_connective = serializers.BooleanField(required=False, label=_('Connectivity')) class Meta: - model = Gateway - fields_mini = ['id', 'username'] - fields_write_only = [ - 'password', 'private_key', 'public_key', 'passphrase' - ] - fields_small = fields_mini + fields_write_only + [ - 'ip', 'port', 'protocol', + model = Asset + fields_mini = ['id'] + fields_small = fields_mini + [ + 'address', 'port', 'protocol', 'is_active', 'is_connective', 'date_created', 'date_updated', 'created_by', 'comment', ] fields_fk = ['domain'] fields = fields_small + fields_fk - extra_kwargs = { - 'username': {"validators": [alphanumeric]}, - 'password': {'write_only': True}, - 'private_key': {"write_only": True}, - 'public_key': {"write_only": True}, - } class GatewayWithAuthSerializer(SecretReadableMixin, GatewaySerializer): diff --git a/apps/authentication/serializers/connection_token.py b/apps/authentication/serializers/connection_token.py index 5011d73a6..fd3895921 100644 --- a/apps/authentication/serializers/connection_token.py +++ b/apps/authentication/serializers/connection_token.py @@ -1,7 +1,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers -from assets.models import Asset, Gateway, Domain, CommandFilterRule, Account, Platform +from assets.models import Asset, Domain, CommandFilterRule, Account, Platform from authentication.models import ConnectionToken from common.utils import pretty_string from common.utils.random import random_string @@ -130,8 +130,8 @@ class ConnectionTokenGatewaySerializer(serializers.ModelSerializer): """ Gateway """ class Meta: - model = Gateway - fields = ['id', 'ip', 'port', 'username', 'password', 'private_key'] + model = Asset + fields = ['id', 'address', 'port', 'username', 'password', 'private_key'] class ConnectionTokenDomainSerializer(serializers.ModelSerializer): diff --git a/apps/orgs/api.py b/apps/orgs/api.py index 16fde4d69..da9b1530f 100644 --- a/apps/orgs/api.py +++ b/apps/orgs/api.py @@ -14,7 +14,7 @@ from .serializers import ( ) from users.models import User, UserGroup from assets.models import ( - Asset, Domain, Label, Node, Gateway, + Asset, Domain, Label, Node, CommandFilter, CommandFilterRule, GatheredUser ) from perms.models import AssetPermission @@ -27,7 +27,7 @@ logger = get_logger(__file__) # 部分 org 相关的 model,需要清空这些数据之后才能删除该组织 org_related_models = [ - User, UserGroup, Asset, Label, Domain, Gateway, Node, Label, + User, UserGroup, Asset, Label, Domain, Node, Label, CommandFilter, CommandFilterRule, GatheredUser, AssetPermission, ] diff --git a/apps/orgs/caches.py b/apps/orgs/caches.py index 5df387c91..a17cd832b 100644 --- a/apps/orgs/caches.py +++ b/apps/orgs/caches.py @@ -6,7 +6,7 @@ from orgs.utils import current_org, tmp_to_org from common.cache import Cache, IntegerField from common.utils import get_logger from users.models import UserGroup, User -from assets.models import Node, Domain, Gateway, Asset, Account +from assets.models import Node, Domain, Asset, Account from terminal.models import Session from perms.models import AssetPermission @@ -54,7 +54,7 @@ class OrgResourceStatisticsCache(OrgRelatedCache): nodes_amount = IntegerField(queryset=Node.objects) accounts_amount = IntegerField(queryset=Account.objects) domains_amount = IntegerField(queryset=Domain.objects) - gateways_amount = IntegerField(queryset=Gateway.objects) + # gateways_amount = IntegerField(queryset=Gateway.objects) asset_perms_amount = IntegerField(queryset=AssetPermission.objects) total_count_online_users = IntegerField() diff --git a/apps/orgs/signal_handlers/cache.py b/apps/orgs/signal_handlers/cache.py index 1e3343931..b3e06362d 100644 --- a/apps/orgs/signal_handlers/cache.py +++ b/apps/orgs/signal_handlers/cache.py @@ -8,7 +8,7 @@ from users.models import UserGroup, User from users.signals import pre_user_leave_org from terminal.models import Session from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding -from assets.models import Asset, Domain, Gateway +from assets.models import Asset, Domain from orgs.caches import OrgResourceStatisticsCache from orgs.utils import current_org from common.utils import get_logger @@ -75,7 +75,6 @@ def on_user_delete_refresh_cache(sender, instance, **kwargs): class OrgResourceStatisticsRefreshUtil: model_cache_field_mapper = { AssetPermission: ['asset_perms_amount'], - Gateway: ['gateways_amount'], Domain: ['domains_amount'], Node: ['nodes_amount'], Asset: ['assets_amount'],