[v3] perf: migrate gateway to asset (#8928)

* perf: migrate gateway to asset

* perf: asset discriminate gateway

Co-authored-by: feng626 <1304903146@qq.com>
pull/9115/head
fit2bot 2022-11-22 17:33:09 +08:00 committed by GitHub
parent 873b81e639
commit 1a204618f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 258 additions and 151 deletions

View File

@ -6,13 +6,11 @@ from rest_framework.decorators import action
from rest_framework.response import Response
from assets import serializers
from assets.models import Asset
from assets.filters import IpInFilterBackend, LabelFilterBackend, NodeFilterBackend
from assets.models import Asset, Gateway
from assets.tasks import (
push_accounts_to_assets,
test_assets_connectivity_manual,
update_assets_hardware_info_manual,
verify_accounts_connectivity,
push_accounts_to_assets, test_assets_connectivity_manual,
update_assets_hardware_info_manual, verify_accounts_connectivity,
)
from common.drf.filters import BaseFilterSet
from common.mixins.api import SuggestionMixin
@ -74,7 +72,7 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
def gateways(self, *args, **kwargs):
asset = self.get_object()
if not asset.domain:
gateways = Gateway.objects.none()
gateways = Asset.objects.none()
else:
gateways = asset.domain.gateways.filter(protocol="ssh")
return self.get_paginated_response_from_queryset(gateways)

View File

@ -1,4 +1,3 @@
from assets.models import Host
from assets.serializers import HostSerializer
from .asset import AssetViewSet

View File

@ -7,21 +7,20 @@ from rest_framework.serializers import ValidationError
from common.utils import get_logger
from orgs.mixins.api import OrgBulkModelViewSet
from ..models import Domain, Gateway
from ..models import Domain, Host
from .. import serializers
logger = get_logger(__file__)
__all__ = ['DomainViewSet', 'GatewayViewSet', "GatewayTestConnectionApi"]
class DomainViewSet(OrgBulkModelViewSet):
model = Domain
filterset_fields = ("name", )
filterset_fields = ("name",)
search_fields = filterset_fields
serializer_class = serializers.DomainSerializer
ordering_fields = ('name',)
ordering = ('name', )
ordering = ('name',)
def get_serializer_class(self):
if self.request.query_params.get('gateway'):
@ -30,21 +29,26 @@ class DomainViewSet(OrgBulkModelViewSet):
class GatewayViewSet(OrgBulkModelViewSet):
model = Gateway
filterset_fields = ("domain__name", "name", "username", "domain")
search_fields = ("domain__name", "name", "username", )
filterset_fields = ("domain__name", "name", "domain")
search_fields = ("domain__name",)
serializer_class = serializers.GatewaySerializer
def get_queryset(self):
queryset = Host.get_gateway_queryset()
return queryset
class GatewayTestConnectionApi(SingleObjectMixin, APIView):
queryset = Gateway.objects.all()
object = None
rbac_perms = {
'POST': 'assets.test_gateway'
}
def get_queryset(self):
queryset = Host.get_gateway_queryset()
return queryset
def post(self, request, *args, **kwargs):
self.object = self.get_object(Gateway.objects.all())
self.object = self.get_object()
local_port = self.request.data.get('port') or self.object.port
try:
local_port = int(local_port)

View File

@ -1,3 +1,5 @@
from .base import *
from .host import *
from .types import *
from .account import *
from .protocol import *

View File

@ -1,5 +1,7 @@
from .base import BaseType
GATEWAY_NAME = 'Gateway'
class HostTypes(BaseType):
LINUX = 'linux', 'Linux'
@ -67,7 +69,7 @@ class HostTypes(BaseType):
return {
cls.LINUX: [
{'name': 'Linux'},
{'name': 'Gateway'}
{'name': GATEWAY_NAME}
],
cls.UNIX: [
{'name': 'Unix'},

View File

@ -0,0 +1,73 @@
# Generated by Django 3.2.13 on 2022-09-29 11:03
from django.db import migrations
from assets.const.host import GATEWAY_NAME
def _create_account_obj(secret, secret_type, gateway, asset, account_model):
return account_model(
asset=asset,
secret=secret,
org_id=gateway.org_id,
secret_type=secret_type,
username=gateway.username,
name=f'{gateway.name}-{secret_type}-{GATEWAY_NAME.lower()}',
)
def migrate_gateway_to_asset(apps, schema_editor):
db_alias = schema_editor.connection.alias
gateway_model = apps.get_model('assets', 'Gateway')
platform_model = apps.get_model('assets', 'Platform')
gateway_platform = platform_model.objects.using(db_alias).get(name=GATEWAY_NAME)
print('>>> migrate gateway to asset')
asset_dict = {}
host_model = apps.get_model('assets', 'Host')
asset_model = apps.get_model('assets', 'Asset')
protocol_model = apps.get_model('assets', 'Protocol')
gateways = gateway_model.objects.all()
for gateway in gateways:
comment = gateway.comment if gateway.comment else ''
data = {
'comment': comment,
'name': f'{gateway.name}-{GATEWAY_NAME.lower()}',
'address': gateway.ip,
'domain': gateway.domain,
'org_id': gateway.org_id,
'is_active': gateway.is_active,
'platform': gateway_platform,
}
asset = asset_model.objects.using(db_alias).create(**data)
asset_dict[gateway.id] = asset
protocol_model.objects.using(db_alias).create(name='ssh', port=gateway.port, asset=asset)
hosts = [host_model(asset_ptr=asset) for asset in asset_dict.values()]
host_model.objects.using(db_alias).bulk_create(hosts, ignore_conflicts=True)
print('>>> migrate gateway to account')
accounts = []
account_model = apps.get_model('assets', 'Account')
for gateway in gateways:
password = gateway.password
private_key = gateway.private_key
asset = asset_dict[gateway.id]
if password:
accounts.append(_create_account_obj(
password, 'password', gateway, asset, account_model
))
if private_key:
accounts.append(_create_account_obj(
private_key, 'ssh_key', gateway, asset, account_model
))
account_model.objects.using(db_alias).bulk_create(accounts)
class Migration(migrations.Migration):
dependencies = [
('assets', '0111_alter_automationexecution_status'),
]
operations = [
migrations.RunPython(migrate_gateway_to_asset),
]

View File

@ -2,8 +2,8 @@
# -*- coding: utf-8 -*-
#
import logging
import uuid
import logging
from collections import defaultdict
from django.db import models

View File

@ -1,6 +1,13 @@
from assets.const import Category
from assets.const import GATEWAY_NAME
from .common import Asset
class Host(Asset):
pass
@classmethod
def get_gateway_queryset(cls):
queryset = cls.objects.filter(
platform__name=GATEWAY_NAME
)
return queryset

View File

@ -6,10 +6,10 @@ import sshpubkeys
from hashlib import md5
from django.db import models
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from django.utils import timezone
from django.db.models import QuerySet
from django.utils.translation import ugettext_lazy as _
from common.utils import (
ssh_key_string_to_obj, ssh_key_gen, get_logger,

View File

@ -1,22 +1,24 @@
# -*- coding: utf-8 -*-
#
import socket
import uuid
import socket
import random
from django.core.cache import cache
import paramiko
from django.db import models
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.translation import ugettext_lazy as _
from common.utils import get_logger, lazyproperty
from common.db import fields
from common.utils import get_logger, lazyproperty
from orgs.mixins.models import OrgModelMixin
from .base import BaseAccount
from ..const import SecretType, GATEWAY_NAME
logger = get_logger(__file__)
__all__ = ['Domain', 'Gateway']
__all__ = ['Domain', 'GatewayMixin']
class Domain(OrgModelMixin):
@ -33,12 +35,9 @@ class Domain(OrgModelMixin):
def __str__(self):
return self.name
def has_gateway(self):
return self.gateway_set.filter(is_active=True).exists()
@lazyproperty
def gateways(self):
return self.gateway_set.filter(is_active=True)
return self.assets.filter(platform__name=GATEWAY_NAME, is_active=True)
def select_gateway(self):
return self.random_gateway()
@ -53,18 +52,141 @@ class Domain(OrgModelMixin):
return random.choice(self.gateways)
class Gateway(BaseAccount):
UNCONNECTIVE_KEY_TMPL = 'asset_unconnective_gateway_{}'
UNCONNECTIVE_SILENCE_PERIOD_KEY_TMPL = 'asset_unconnective_gateway_silence_period_{}'
UNCONNECTIVE_SILENCE_PERIOD_BEGIN_VALUE = 60 * 5
class GatewayMixin:
id: uuid.UUID
port: int
address: str
accounts: QuerySet
private_key_path: str
private_key_obj: paramiko.RSAKey
UNCONNECTED_KEY_TMPL = 'asset_unconnective_gateway_{}'
UNCONNECTED_SILENCE_PERIOD_KEY_TMPL = 'asset_unconnective_gateway_silence_period_{}'
UNCONNECTED_SILENCE_PERIOD_BEGIN_VALUE = 60 * 5
def set_unconnected(self):
unconnected_key = self.UNCONNECTED_KEY_TMPL.format(self.id)
unconnected_silence_period_key = self.UNCONNECTED_SILENCE_PERIOD_KEY_TMPL.format(self.id)
unconnected_silence_period = cache.get(
unconnected_silence_period_key, self.UNCONNECTED_SILENCE_PERIOD_BEGIN_VALUE
)
cache.set(unconnected_silence_period_key, unconnected_silence_period * 2)
cache.set(unconnected_key, unconnected_silence_period, unconnected_silence_period)
def set_connective(self):
unconnected_key = self.UNCONNECTED_KEY_TMPL.format(self.id)
unconnected_silence_period_key = self.UNCONNECTED_SILENCE_PERIOD_KEY_TMPL.format(self.id)
cache.delete(unconnected_key)
cache.delete(unconnected_silence_period_key)
def get_is_unconnected(self):
unconnected_key = self.UNCONNECTED_KEY_TMPL.format(self.id)
return cache.get(unconnected_key, False)
@property
def is_connective(self):
return not self.get_is_unconnected()
@is_connective.setter
def is_connective(self, value):
if value:
self.set_connective()
else:
self.set_unconnected()
def test_connective(self, local_port=None):
# TODO 走ansible runner
if local_port is None:
local_port = self.port
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
proxy = paramiko.SSHClient()
proxy.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
proxy.connect(self.address, port=self.port,
username=self.username,
password=self.password,
pkey=self.private_key_obj)
except(paramiko.AuthenticationException,
paramiko.BadAuthenticationType,
paramiko.SSHException,
paramiko.ChannelException,
paramiko.ssh_exception.NoValidConnectionsError,
socket.gaierror) as e:
err = str(e)
if err.startswith('[Errno None] Unable to connect to port'):
err = _('Unable to connect to port {port} on {address}')
err = err.format(port=self.port, ip=self.address)
elif err == 'Authentication failed.':
err = _('Authentication failed')
elif err == 'Connect failed':
err = _('Connect failed')
self.is_connective = False
return False, err
try:
sock = proxy.get_transport().open_channel(
'direct-tcpip', ('127.0.0.1', local_port), ('127.0.0.1', 0)
)
client.connect("127.0.0.1", port=local_port,
username=self.username,
password=self.password,
key_filename=self.private_key_path,
sock=sock,
timeout=5)
except (paramiko.SSHException,
paramiko.ssh_exception.SSHException,
paramiko.ChannelException,
paramiko.AuthenticationException,
TimeoutError) as e:
err = getattr(e, 'text', str(e))
if err == 'Connect failed':
err = _('Connect failed')
self.is_connective = False
return False, err
finally:
client.close()
self.is_connective = True
return True, None
@lazyproperty
def username(self):
account = self.accounts.all().first()
if account:
return account.username
logger.error(f'Gateway {self} has no account')
return ''
def get_secret(self, secret_type):
account = self.accounts.filter(secret_type=secret_type).first()
if account:
return account.secret
logger.error(f'Gateway {self} has no {secret_type} account')
@lazyproperty
def password(self):
secret_type = SecretType.PASSWORD
return self.get_secret(secret_type)
@lazyproperty
def private_key(self):
secret_type = SecretType.SSH_KEY
return self.get_secret(secret_type)
class Gateway(BaseAccount):
class Protocol(models.TextChoices):
ssh = 'ssh', 'SSH'
name = models.CharField(max_length=128, verbose_name='Name')
ip = models.CharField(max_length=128, verbose_name=_('IP'), db_index=True)
port = models.IntegerField(default=22, verbose_name=_('Port'))
protocol = models.CharField(choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol"))
protocol = models.CharField(
choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol")
)
domain = models.ForeignKey(Domain, on_delete=models.CASCADE, verbose_name=_("Domain"))
comment = models.CharField(max_length=128, blank=True, null=True, verbose_name=_("Comment"))
is_active = models.BooleanField(default=True, verbose_name=_("Is active"))
@ -85,91 +207,3 @@ class Gateway(BaseAccount):
permissions = [
('test_gateway', _('Test gateway'))
]
def set_unconnective(self):
unconnective_key = self.UNCONNECTIVE_KEY_TMPL.format(self.id)
unconnective_silence_period_key = self.UNCONNECTIVE_SILENCE_PERIOD_KEY_TMPL.format(self.id)
unconnective_silence_period = cache.get(unconnective_silence_period_key,
self.UNCONNECTIVE_SILENCE_PERIOD_BEGIN_VALUE)
cache.set(unconnective_silence_period_key, unconnective_silence_period * 2)
cache.set(unconnective_key, unconnective_silence_period, unconnective_silence_period)
def set_connective(self):
unconnective_key = self.UNCONNECTIVE_KEY_TMPL.format(self.id)
unconnective_silence_period_key = self.UNCONNECTIVE_SILENCE_PERIOD_KEY_TMPL.format(self.id)
cache.delete(unconnective_key)
cache.delete(unconnective_silence_period_key)
def get_is_unconnective(self):
unconnective_key = self.UNCONNECTIVE_KEY_TMPL.format(self.id)
return cache.get(unconnective_key, False)
@property
def is_connective(self):
return not self.get_is_unconnective()
@is_connective.setter
def is_connective(self, value):
if value:
self.set_connective()
else:
self.set_unconnective()
def test_connective(self, local_port=None):
if local_port is None:
local_port = self.port
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
proxy = paramiko.SSHClient()
proxy.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
proxy.connect(self.ip, port=self.port,
username=self.username,
password=self.password,
pkey=self.private_key_obj)
except(paramiko.AuthenticationException,
paramiko.BadAuthenticationType,
paramiko.SSHException,
paramiko.ChannelException,
paramiko.ssh_exception.NoValidConnectionsError,
socket.gaierror) as e:
err = str(e)
if err.startswith('[Errno None] Unable to connect to port'):
err = _('Unable to connect to port {port} on {address}')
err = err.format(port=self.port, ip=self.ip)
elif err == 'Authentication failed.':
err = _('Authentication failed')
elif err == 'Connect failed':
err = _('Connect failed')
self.is_connective = False
return False, err
try:
sock = proxy.get_transport().open_channel(
'direct-tcpip', ('127.0.0.1', local_port), ('127.0.0.1', 0)
)
client.connect("127.0.0.1", port=local_port,
username=self.username,
password=self.password,
key_filename=self.private_key_file,
sock=sock,
timeout=5)
except (paramiko.SSHException,
paramiko.ssh_exception.SSHException,
paramiko.ChannelException,
paramiko.AuthenticationException,
TimeoutError) as e:
err = getattr(e, 'text', str(e))
if err == 'Connect failed':
err = _('Connect failed')
self.is_connective = False
return False, err
finally:
client.close()
self.is_connective = True
return True, None

View File

@ -3,11 +3,9 @@
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from common.validators import alphanumeric
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from common.drf.serializers import SecretReadableMixin
from ..models import Domain, Gateway
from .base import AuthValidateMixin
from ..models import Domain, Asset
class DomainSerializer(BulkOrgResourceModelSerializer):
@ -35,32 +33,23 @@ class DomainSerializer(BulkOrgResourceModelSerializer):
@staticmethod
def get_gateway_count(obj):
return obj.gateway_set.all().count()
return obj.gateways.count()
class GatewaySerializer(AuthValidateMixin, BulkOrgResourceModelSerializer):
class GatewaySerializer(BulkOrgResourceModelSerializer):
is_connective = serializers.BooleanField(required=False, label=_('Connectivity'))
class Meta:
model = Gateway
fields_mini = ['id', 'username']
fields_write_only = [
'password', 'private_key', 'public_key', 'passphrase'
]
fields_small = fields_mini + fields_write_only + [
'ip', 'port', 'protocol',
model = Asset
fields_mini = ['id']
fields_small = fields_mini + [
'address', 'port', 'protocol',
'is_active', 'is_connective',
'date_created', 'date_updated',
'created_by', 'comment',
]
fields_fk = ['domain']
fields = fields_small + fields_fk
extra_kwargs = {
'username': {"validators": [alphanumeric]},
'password': {'write_only': True},
'private_key': {"write_only": True},
'public_key': {"write_only": True},
}
class GatewayWithAuthSerializer(SecretReadableMixin, GatewaySerializer):

View File

@ -1,7 +1,7 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from assets.models import Asset, Gateway, Domain, CommandFilterRule, Account, Platform
from assets.models import Asset, Domain, CommandFilterRule, Account, Platform
from authentication.models import ConnectionToken
from common.utils import pretty_string
from common.utils.random import random_string
@ -130,8 +130,8 @@ class ConnectionTokenGatewaySerializer(serializers.ModelSerializer):
""" Gateway """
class Meta:
model = Gateway
fields = ['id', 'ip', 'port', 'username', 'password', 'private_key']
model = Asset
fields = ['id', 'address', 'port', 'username', 'password', 'private_key']
class ConnectionTokenDomainSerializer(serializers.ModelSerializer):

View File

@ -14,7 +14,7 @@ from .serializers import (
)
from users.models import User, UserGroup
from assets.models import (
Asset, Domain, Label, Node, Gateway,
Asset, Domain, Label, Node,
CommandFilter, CommandFilterRule, GatheredUser
)
from perms.models import AssetPermission
@ -27,7 +27,7 @@ logger = get_logger(__file__)
# 部分 org 相关的 model需要清空这些数据之后才能删除该组织
org_related_models = [
User, UserGroup, Asset, Label, Domain, Gateway, Node, Label,
User, UserGroup, Asset, Label, Domain, Node, Label,
CommandFilter, CommandFilterRule, GatheredUser,
AssetPermission,
]

View File

@ -6,7 +6,7 @@ from orgs.utils import current_org, tmp_to_org
from common.cache import Cache, IntegerField
from common.utils import get_logger
from users.models import UserGroup, User
from assets.models import Node, Domain, Gateway, Asset, Account
from assets.models import Node, Domain, Asset, Account
from terminal.models import Session
from perms.models import AssetPermission
@ -54,7 +54,7 @@ class OrgResourceStatisticsCache(OrgRelatedCache):
nodes_amount = IntegerField(queryset=Node.objects)
accounts_amount = IntegerField(queryset=Account.objects)
domains_amount = IntegerField(queryset=Domain.objects)
gateways_amount = IntegerField(queryset=Gateway.objects)
# gateways_amount = IntegerField(queryset=Gateway.objects)
asset_perms_amount = IntegerField(queryset=AssetPermission.objects)
total_count_online_users = IntegerField()

View File

@ -8,7 +8,7 @@ from users.models import UserGroup, User
from users.signals import pre_user_leave_org
from terminal.models import Session
from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding
from assets.models import Asset, Domain, Gateway
from assets.models import Asset, Domain
from orgs.caches import OrgResourceStatisticsCache
from orgs.utils import current_org
from common.utils import get_logger
@ -75,7 +75,6 @@ def on_user_delete_refresh_cache(sender, instance, **kwargs):
class OrgResourceStatisticsRefreshUtil:
model_cache_field_mapper = {
AssetPermission: ['asset_perms_amount'],
Gateway: ['gateways_amount'],
Domain: ['domains_amount'],
Node: ['nodes_amount'],
Asset: ['assets_amount'],