diff --git a/spug_api/apps/host/extend.py b/spug_api/apps/host/extend.py
index 45d3a46..6172176 100644
--- a/spug_api/apps/host/extend.py
+++ b/spug_api/apps/host/extend.py
@@ -20,8 +20,8 @@ class ExtendView(View):
return json_response(error='未找到指定主机')
if not host.is_verified:
return json_response(error='该主机还未验证')
- ssh = host.get_ssh()
- response = fetch_host_extend(ssh)
+ with host.get_ssh() as ssh:
+ response = fetch_host_extend(ssh)
return json_response(response)
return json_response(error=error)
diff --git a/spug_api/apps/host/utils.py b/spug_api/apps/host/utils.py
index feab7d5..cab03fe 100644
--- a/spug_api/apps/host/utils.py
+++ b/spug_api/apps/host/utils.py
@@ -186,68 +186,67 @@ def fetch_host_extend(ssh):
public_ip_address = set()
private_ip_address = set()
response = {'disk': []}
- with ssh:
- code, out = ssh.exec_command_raw('nproc')
- if code != 0:
- code, out = ssh.exec_command_raw("grep -c '^processor' /proc/cpuinfo")
- if code == 0:
- response['cpu'] = int(out.strip())
+ code, out = ssh.exec_command_raw('nproc')
+ if code != 0:
+ code, out = ssh.exec_command_raw("grep -c '^processor' /proc/cpuinfo")
+ if code == 0:
+ response['cpu'] = int(out.strip())
- code, out = ssh.exec_command_raw("cat /etc/os-release | grep PRETTY_NAME | awk -F \\\" '{print $2}'")
- if '/etc/os-release' in out:
- code, out = ssh.exec_command_raw("cat /etc/issue | head -1 | awk '{print $1,$2,$3}'")
- if code == 0:
- response['os_name'] = out.strip()[:50]
+ code, out = ssh.exec_command_raw("cat /etc/os-release | grep PRETTY_NAME | awk -F \\\" '{print $2}'")
+ if '/etc/os-release' in out:
+ code, out = ssh.exec_command_raw("cat /etc/issue | head -1 | awk '{print $1,$2,$3}'")
+ if code == 0:
+ response['os_name'] = out.strip()[:50]
- code, out = ssh.exec_command_raw('hostname -I')
- if code == 0:
- for ip in out.strip().split():
- if len(ip) > 15: # ignore ipv6
- continue
- if ipaddress.ip_address(ip).is_global:
- if len(public_ip_address) < 10:
- public_ip_address.add(ip)
- elif len(private_ip_address) < 10:
- private_ip_address.add(ip)
+ code, out = ssh.exec_command_raw('hostname -I')
+ if code == 0:
+ for ip in out.strip().split():
+ if len(ip) > 15: # ignore ipv6
+ continue
+ if ipaddress.ip_address(ip).is_global:
+ if len(public_ip_address) < 10:
+ public_ip_address.add(ip)
+ elif len(private_ip_address) < 10:
+ private_ip_address.add(ip)
- ssh_hostname = ssh.arguments.get('hostname')
- if ip_validator(ssh_hostname):
- if ipaddress.ip_address(ssh_hostname).is_global:
- if ssh_hostname in public_ip_address:
- public_ip_address.remove(ssh_hostname)
- public_ip_address = [ssh_hostname] + list(public_ip_address)
+ ssh_hostname = ssh.arguments.get('hostname')
+ if ip_validator(ssh_hostname):
+ if ipaddress.ip_address(ssh_hostname).is_global:
+ if ssh_hostname in public_ip_address:
+ public_ip_address.remove(ssh_hostname)
+ public_ip_address = [ssh_hostname] + list(public_ip_address)
+ else:
+ if ssh_hostname in private_ip_address:
+ private_ip_address.remove(ssh_hostname)
+ private_ip_address = [ssh_hostname] + list(private_ip_address)
+
+ code, out = ssh.exec_command_raw('lsblk -dbn -o SIZE -e 11 2> /dev/null')
+ if code == 0:
+ disks = []
+ for item in out.strip().splitlines():
+ item = item.strip()
+ size = math.ceil(int(item) / 1024 / 1024 / 1024)
+ if size > 10:
+ disks.append(size)
+ response['disk'] = disks[:10]
+
+ code, out = ssh.exec_command_raw("dmidecode -t 17 | grep -E 'Size: [0-9]+' | awk '{s+=$2} END {print s,$3}'")
+ if code == 0:
+ fields = out.strip().split()
+ if len(fields) == 2 and fields[1] in ('GB', 'MB'):
+ size, unit = out.strip().split()
+ if unit == 'GB':
+ response['memory'] = size
else:
- if ssh_hostname in private_ip_address:
- private_ip_address.remove(ssh_hostname)
- private_ip_address = [ssh_hostname] + list(private_ip_address)
-
- code, out = ssh.exec_command_raw('lsblk -dbn -o SIZE -e 11 2> /dev/null')
+ response['memory'] = round(int(size) / 1024, 0)
+ if 'memory' not in response:
+ code, out = ssh.exec_command_raw("free -m | awk 'NR==2{print $2}'")
if code == 0:
- disks = []
- for item in out.strip().splitlines():
- item = item.strip()
- size = math.ceil(int(item) / 1024 / 1024 / 1024)
- if size > 10:
- disks.append(size)
- response['disk'] = disks[:10]
+ response['memory'] = math.ceil(int(out) / 1024)
- code, out = ssh.exec_command_raw("dmidecode -t 17 | grep -E 'Size: [0-9]+' | awk '{s+=$2} END {print s,$3}'")
- if code == 0:
- fields = out.strip().split()
- if len(fields) == 2 and fields[1] in ('GB', 'MB'):
- size, unit = out.strip().split()
- if unit == 'GB':
- response['memory'] = size
- else:
- response['memory'] = round(int(size) / 1024, 0)
- if 'memory' not in response:
- code, out = ssh.exec_command_raw("free -m | awk 'NR==2{print $2}'")
- if code == 0:
- response['memory'] = math.ceil(int(out) / 1024)
-
- response['public_ip_address'] = list(public_ip_address)
- response['private_ip_address'] = list(private_ip_address)
- return response
+ response['public_ip_address'] = list(public_ip_address)
+ response['private_ip_address'] = list(private_ip_address)
+ return response
def batch_sync_host(token, hosts, password=None):
@@ -275,7 +274,8 @@ def batch_sync_host(token, hosts, password=None):
def _sync_host_extend(host, private_key=None, public_key=None, password=None, ssh=None):
if not ssh:
kwargs = host.to_dict(selects=('hostname', 'port', 'username'))
- ssh = _get_ssh(kwargs, host.pkey, private_key, public_key, password)
+ with _get_ssh(kwargs, host.pkey, private_key, public_key, password) as ssh:
+ return _sync_host_extend(host, ssh=ssh)
form = AttrDict(fetch_host_extend(ssh))
form.disk = json.dumps(form.disk)
form.public_ip_address = json.dumps(form.public_ip_address)
diff --git a/spug_api/apps/host/views.py b/spug_api/apps/host/views.py
index 8fb6e36..98c640b 100644
--- a/spug_api/apps/host/views.py
+++ b/spug_api/apps/host/views.py
@@ -16,6 +16,7 @@ from libs.ssh import SSH, AuthenticationException
from paramiko.ssh_exception import BadAuthenticationType
from openpyxl import load_workbook
from threading import Thread
+import socket
import uuid
@@ -43,22 +44,7 @@ class HostView(View):
Argument('password', required=False),
).parse(request.body)
if error is None:
- password = form.pop('password')
- private_key, public_key = AppSetting.get_ssh_key()
- try:
- if form.pkey:
- private_key = form.pkey
- elif password:
- with SSH(form.hostname, form.port, form.username, password=password) as ssh:
- ssh.add_public_key(public_key)
-
- with SSH(form.hostname, form.port, form.username, private_key) as ssh:
- ssh.ping()
- except BadAuthenticationType:
- return json_response(error='该主机不支持密钥认证,请参考官方文档,错误代码:E01')
- except AuthenticationException:
- if password:
- return json_response(error='密钥认证失败,请参考官方文档,错误代码:E02')
+ if not _do_host_verify(form):
return json_response('auth fail')
group_ids = form.pop('group_ids')
@@ -70,13 +56,23 @@ class HostView(View):
host = Host.objects.get(pk=form.id)
else:
host = Host.objects.create(created_by=request.user, is_verified=True, **form)
- _sync_host_extend(host, ssh=ssh)
host.groups.set(group_ids)
response = host.to_view()
response['group_ids'] = group_ids
return json_response(response)
return json_response(error=error)
+ @auth('host.host.add|host.host.edit')
+ def put(self, request):
+ form, error = JsonParser(
+ Argument('id', type=int, help='参数错误')
+ ).parse(request.body)
+ if error is None:
+ host = Host.objects.get(pk=form.id)
+ with host.get_ssh() as ssh:
+ _sync_host_extend(host, ssh=ssh)
+ return json_response(error=error)
+
@auth('admin')
def patch(self, request):
form, error = JsonParser(
@@ -115,7 +111,7 @@ class HostView(View):
.annotate(app_name=F('app__name'), env_name=F('env__name')).first()
if deploy:
return json_response(error=f'应用【{deploy.app_name}】在【{deploy.env_name}】的发布配置关联了该主机,请解除关联后再尝试删除该主机')
- task = Task.objects.filter(targets__regex=fr'[^0-9]{form.id}[^0-9]').first()
+ task = Task.objects.filter(targets__regex=regex).first()
if task:
return json_response(error=f'任务计划中的任务【{task.name}】关联了该主机,请解除关联后再尝试删除该主机')
detection = Detection.objects.filter(type__in=('3', '4'), targets__regex=regex).first()
@@ -190,3 +186,43 @@ def batch_valid(request):
Thread(target=batch_sync_host, args=(token, hosts, form.password)).start()
return json_response({'token': token, 'hosts': {x.id: {'name': x.name} for x in hosts}})
return json_response(error=error)
+
+
+def _do_host_verify(form):
+ if form.pkey:
+ try:
+ with SSH(form.hostname, form.port, form.username, form.pkey) as ssh:
+ ssh.ping()
+ return True
+ except BadAuthenticationType:
+ raise Exception('该主机不支持密钥认证,请参考官方文档,错误代码:E01')
+ except AuthenticationException:
+ raise Exception('上传的独立密钥认证失败,请检查该密钥是否能正常连接主机(推荐使用全局密钥)')
+ except socket.timeout:
+ raise Exception('连接主机超时,请检查网络')
+
+ private_key, public_key = AppSetting.get_ssh_key()
+ password = form.pop('password')
+ if password:
+ try:
+ with SSH(form.hostname, form.port, form.username, password=password) as ssh:
+ ssh.add_public_key(public_key)
+ except BadAuthenticationType:
+ raise Exception('该主机不支持密钥认证,请参考官方文档,错误代码:E01')
+ except AuthenticationException:
+ raise Exception('密码连接认证失败,请检查密码是否正确')
+ except socket.timeout:
+ raise Exception('连接主机超时,请检查网络')
+
+ try:
+ with SSH(form.hostname, form.port, form.username, private_key) as ssh:
+ ssh.ping()
+ except BadAuthenticationType:
+ raise Exception('该主机不支持密钥认证,请参考官方文档,错误代码:E01')
+ except AuthenticationException:
+ if password:
+ raise Exception('密钥认证失败,请参考官方文档,错误代码:E02')
+ return False
+ except socket.timeout:
+ raise Exception('连接主机超时,请检查网络')
+ return True
diff --git a/spug_web/src/pages/host/Form.js b/spug_web/src/pages/host/Form.js
index 531a091..3a8a9a3 100644
--- a/spug_web/src/pages/host/Form.js
+++ b/spug_web/src/pages/host/Form.js
@@ -48,17 +48,20 @@ export default observer(function () {
message.success('验证成功');
store.formVisible = false;
store.fetchRecords();
+ store.fetchExtend(res.id)
}
}, () => setLoading(false))
}
function handleConfirm(formData) {
if (formData.password) {
- return http.post('/api/host/', formData).then(res => {
- message.success('验证成功');
- store.formVisible = false;
- store.fetchRecords();
- })
+ return http.post('/api/host/', formData)
+ .then(res => {
+ message.success('验证成功');
+ store.formVisible = false;
+ store.fetchRecords();
+ store.fetchExtend(res.id)
+ })
}
message.error('请输入授权密码')
}
@@ -134,7 +137,7 @@ export default observer(function () {