mirror of https://github.com/jumpserver/jumpserver
perf: 支持配置文件加密 (#8699)
* crypto * perf: 暂存一下 * perf: 支持配置文件加密 * perf: 修改位置 * perf: 优化拆分出去 * stash * perf: js 强制 key 最大 16 * pref: 修改语法 * fix: 修复启用 gm 后,又关闭导致的用户无法登录 Co-authored-by: ibuler <ibuler@qq.com>pull/8705/head
parent
b27b02eb9d
commit
4ecb0b760f
|
@ -49,7 +49,7 @@ class JMSBaseAuthBackend:
|
|||
if not allow:
|
||||
info = 'User {} skip authentication backend {}, because it not in {}'
|
||||
info = info.format(username, backend_name, ','.join(allowed_backend_names))
|
||||
logger.debug(info)
|
||||
logger.info(info)
|
||||
return allow
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import base64
|
||||
import logging
|
||||
import re
|
||||
from Cryptodome.Cipher import AES, PKCS1_v1_5
|
||||
from Cryptodome.Util.Padding import pad
|
||||
from Cryptodome.Random import get_random_bytes
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome import Random
|
||||
|
@ -11,21 +11,25 @@ from django.conf import settings
|
|||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
|
||||
def process_key(key):
|
||||
secret_pattern = re.compile(r'password|secret|key|token', re.IGNORECASE)
|
||||
|
||||
|
||||
def padding_key(key, max_length=32):
|
||||
"""
|
||||
返回32 bytes 的key
|
||||
"""
|
||||
if not isinstance(key, bytes):
|
||||
key = bytes(key, encoding='utf-8')
|
||||
|
||||
if len(key) >= 32:
|
||||
return key[:32]
|
||||
if len(key) >= max_length:
|
||||
return key[:max_length]
|
||||
|
||||
return pad(key, 32)
|
||||
while len(key) % 16 != 0:
|
||||
key += b'\0'
|
||||
return key
|
||||
|
||||
|
||||
class BaseCrypto:
|
||||
|
||||
def encrypt(self, text):
|
||||
return base64.urlsafe_b64encode(
|
||||
self._encrypt(bytes(text, encoding='utf8'))
|
||||
|
@ -45,7 +49,7 @@ class BaseCrypto:
|
|||
|
||||
class GMSM4EcbCrypto(BaseCrypto):
|
||||
def __init__(self, key):
|
||||
self.key = process_key(key)
|
||||
self.key = padding_key(key, 16)
|
||||
self.sm4_encryptor = CryptSM4()
|
||||
self.sm4_encryptor.set_key(self.key, SM4_ENCRYPT)
|
||||
|
||||
|
@ -70,9 +74,8 @@ class AESCrypto:
|
|||
"""
|
||||
|
||||
def __init__(self, key):
|
||||
if len(key) > 32:
|
||||
key = key[:32]
|
||||
self.key = self.to_16(key)
|
||||
self.key = padding_key(key, 32)
|
||||
self.aes = AES.new(self.key, AES.MODE_ECB)
|
||||
|
||||
@staticmethod
|
||||
def to_16(key):
|
||||
|
@ -87,17 +90,15 @@ class AESCrypto:
|
|||
return key # 返回bytes
|
||||
|
||||
def aes(self):
|
||||
return AES.new(self.key, AES.MODE_ECB) # 初始化加密器
|
||||
return AES.new(self.key, AES.MODE_ECB)
|
||||
|
||||
def encrypt(self, text):
|
||||
aes = self.aes()
|
||||
cipher = base64.encodebytes(aes.encrypt(self.to_16(text)))
|
||||
cipher = base64.encodebytes(self.aes.encrypt(self.to_16(text)))
|
||||
return str(cipher, encoding='utf8').replace('\n', '') # 加密
|
||||
|
||||
def decrypt(self, text):
|
||||
aes = self.aes()
|
||||
text_decoded = base64.decodebytes(bytes(text, encoding='utf8'))
|
||||
return str(aes.decrypt(text_decoded).rstrip(b'\0').decode("utf8"))
|
||||
return str(self.aes.decrypt(text_decoded).rstrip(b'\0').decode("utf8"))
|
||||
|
||||
|
||||
class AESCryptoGCM:
|
||||
|
@ -106,7 +107,7 @@ class AESCryptoGCM:
|
|||
"""
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = process_key(key)
|
||||
self.key = padding_key(key)
|
||||
|
||||
def encrypt(self, text):
|
||||
"""
|
||||
|
@ -133,7 +134,6 @@ class AESCryptoGCM:
|
|||
nonce = base64.b64decode(metadata[24:48])
|
||||
tag = base64.b64decode(metadata[48:])
|
||||
ciphertext = base64.b64decode(text[72:])
|
||||
|
||||
cipher = AES.new(self.key, AES.MODE_GCM, nonce=nonce)
|
||||
|
||||
cipher.update(header)
|
||||
|
@ -144,11 +144,10 @@ class AESCryptoGCM:
|
|||
def get_aes_crypto(key=None, mode='GCM'):
|
||||
if key is None:
|
||||
key = settings.SECRET_KEY
|
||||
if mode == 'ECB':
|
||||
a = AESCrypto(key)
|
||||
elif mode == 'GCM':
|
||||
a = AESCryptoGCM(key)
|
||||
return a
|
||||
if mode == 'GCM':
|
||||
return AESCryptoGCM(key)
|
||||
else:
|
||||
return AESCrypto(key)
|
||||
|
||||
|
||||
def get_gm_sm4_ecb_crypto(key=None):
|
||||
|
|
|
@ -196,7 +196,8 @@ def encrypt_password(password, salt=None, algorithm='sha512'):
|
|||
return des_crypt.hash(password, salt=salt[:2])
|
||||
|
||||
support_algorithm = {
|
||||
'sha512': sha512, 'des': des
|
||||
'sha512': sha512,
|
||||
'des': des
|
||||
}
|
||||
|
||||
if isinstance(algorithm, str):
|
||||
|
@ -222,9 +223,6 @@ def ensure_last_char_is_ascii(data):
|
|||
remain = ''
|
||||
|
||||
|
||||
secret_pattern = re.compile(r'password|secret|key', re.IGNORECASE)
|
||||
|
||||
|
||||
def data_to_json(data, sort_keys=True, indent=2, cls=None):
|
||||
if cls is None:
|
||||
cls = DjangoJSONEncoder
|
||||
|
|
|
@ -15,18 +15,23 @@ import errno
|
|||
import json
|
||||
import yaml
|
||||
import copy
|
||||
import base64
|
||||
import logging
|
||||
from importlib import import_module
|
||||
from urllib.parse import urljoin, urlparse
|
||||
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
|
||||
|
||||
from django.urls import reverse_lazy
|
||||
from django.conf import settings
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
PROJECT_DIR = os.path.dirname(BASE_DIR)
|
||||
XPACK_DIR = os.path.join(BASE_DIR, 'xpack')
|
||||
HAS_XPACK = os.path.isdir(XPACK_DIR)
|
||||
|
||||
logger = logging.getLogger('jumpserver.conf')
|
||||
|
||||
|
||||
def import_string(dotted_path):
|
||||
try:
|
||||
|
@ -39,9 +44,9 @@ def import_string(dotted_path):
|
|||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError as err:
|
||||
raise ImportError('Module "%s" does not define a "%s" attribute/class' % (
|
||||
module_path, class_name)
|
||||
) from err
|
||||
raise ImportError(
|
||||
'Module "%s" does not define a "%s" attribute/class' %
|
||||
(module_path, class_name)) from err
|
||||
|
||||
|
||||
def is_absolute_uri(uri):
|
||||
|
@ -80,6 +85,59 @@ class DoesNotExist(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class ConfigCrypto:
|
||||
secret_keys = [
|
||||
'SECRET_KEY', 'DB_PASSWORD', 'REDIS_PASSWORD',
|
||||
]
|
||||
|
||||
def __init__(self, key):
|
||||
self.safe_key = self.process_key(key)
|
||||
self.sm4_encryptor = CryptSM4()
|
||||
self.sm4_encryptor.set_key(self.safe_key, SM4_ENCRYPT)
|
||||
|
||||
self.sm4_decryptor = CryptSM4()
|
||||
self.sm4_decryptor.set_key(self.safe_key, SM4_DECRYPT)
|
||||
|
||||
@staticmethod
|
||||
def process_key(secret_encrypt_key):
|
||||
key = secret_encrypt_key.encode()
|
||||
if len(key) >= 16:
|
||||
key = key[:16]
|
||||
else:
|
||||
key += b'\0' * (16 - len(key))
|
||||
return key
|
||||
|
||||
def encrypt(self, data):
|
||||
data = bytes(data, encoding='utf8')
|
||||
return base64.b64encode(self.sm4_encryptor.crypt_ecb(data)).decode('utf8')
|
||||
|
||||
def decrypt(self, data):
|
||||
data = base64.urlsafe_b64decode(bytes(data, encoding='utf8'))
|
||||
return self.sm4_decryptor.crypt_ecb(data).decode('utf8')
|
||||
|
||||
def decrypt_if_need(self, value, item):
|
||||
if item not in self.secret_keys:
|
||||
return value
|
||||
|
||||
try:
|
||||
plaintext = self.decrypt(value)
|
||||
if plaintext:
|
||||
value = plaintext
|
||||
except Exception as e:
|
||||
logger.error('decrypt %s error: %s', item, e)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def get_secret_encryptor(cls):
|
||||
# 使用 SM4 加密配置文件敏感信息
|
||||
# https://the-x.cn/cryptography/Sm4.aspx
|
||||
secret_encrypt_key = os.environ.get('SECRET_ENCRYPT_KEY', '')
|
||||
if not secret_encrypt_key:
|
||||
return None
|
||||
print('Info: Using SM4 to encrypt config secret value')
|
||||
return cls(secret_encrypt_key)
|
||||
|
||||
|
||||
class Config(dict):
|
||||
"""Works exactly like a dict but provides ways to fill it from files
|
||||
or special dictionaries. There are two common patterns to populate the
|
||||
|
@ -434,6 +492,10 @@ class Config(dict):
|
|||
'HEALTH_CHECK_TOKEN': '',
|
||||
}
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
self.secret_encryptor = ConfigCrypto.get_secret_encryptor()
|
||||
|
||||
@staticmethod
|
||||
def convert_keycloak_to_openid(keycloak_config):
|
||||
"""
|
||||
|
@ -445,7 +507,6 @@ class Config(dict):
|
|||
"""
|
||||
|
||||
openid_config = copy.deepcopy(keycloak_config)
|
||||
|
||||
auth_openid = openid_config.get('AUTH_OPENID')
|
||||
auth_openid_realm_name = openid_config.get('AUTH_OPENID_REALM_NAME')
|
||||
auth_openid_server_url = openid_config.get('AUTH_OPENID_SERVER_URL')
|
||||
|
@ -574,13 +635,12 @@ class Config(dict):
|
|||
def get(self, item):
|
||||
# 再从配置文件中获取
|
||||
value = self.get_from_config(item)
|
||||
if value is not None:
|
||||
return value
|
||||
# 其次从环境变量来
|
||||
value = self.get_from_env(item)
|
||||
if value is not None:
|
||||
return value
|
||||
value = self.defaults.get(item)
|
||||
if value is None:
|
||||
value = self.get_from_env(item)
|
||||
if value is None:
|
||||
value = self.defaults.get(item)
|
||||
if self.secret_encryptor:
|
||||
value = self.secret_encryptor.decrypt_if_need(value, item)
|
||||
return value
|
||||
|
||||
def __getitem__(self, item):
|
||||
|
|
|
@ -316,8 +316,11 @@ PASSWORD_HASHERS = [
|
|||
|
||||
|
||||
GMSSL_ENABLED = CONFIG.GMSSL_ENABLED
|
||||
GM_HASHER = 'common.hashers.PBKDF2SM3PasswordHasher'
|
||||
if GMSSL_ENABLED:
|
||||
PASSWORD_HASHERS.insert(0, 'common.hashers.PBKDF2SM3PasswordHasher')
|
||||
PASSWORD_HASHERS.insert(0, GM_HASHER)
|
||||
else:
|
||||
PASSWORD_HASHERS.append(GM_HASHER)
|
||||
|
||||
# For Debug toolbar
|
||||
INTERNAL_IPS = ["127.0.0.1"]
|
||||
|
|
|
@ -1504,17 +1504,11 @@ function getStatusIcon(status, mapping, title) {
|
|||
|
||||
|
||||
function fillKey(key) {
|
||||
let keySize = 128
|
||||
// 如果超过 key 16 位, 最大取 32 位,需要更改填充
|
||||
if (key.length > 16) {
|
||||
key = key.slice(0, 32)
|
||||
keySize = keySize * 2
|
||||
const KeyLength = 16
|
||||
if (key.length > KeyLength) {
|
||||
key = key.slice(0, KeyLength)
|
||||
}
|
||||
const filledKeyLength = keySize / 8
|
||||
if (key.length >= filledKeyLength) {
|
||||
return key.slice(0, filledKeyLength)
|
||||
}
|
||||
const filledKey = Buffer.alloc(keySize / 8)
|
||||
const filledKey = Buffer.alloc(KeyLength)
|
||||
const keys = Buffer.from(key)
|
||||
for (let i = 0; i < keys.length; i++) {
|
||||
filledKey[i] = keys[i]
|
||||
|
|
Loading…
Reference in New Issue