perf: 支持配置文件加密 (#8699)

* crypto

* perf: 暂存一下

* perf: 支持配置文件加密

* perf: 修改位置

* perf: 优化拆分出去

* stash

* perf: js 强制 key 最大 16

* pref: 修改语法

* fix: 修复启用 gm 后,又关闭导致的用户无法登录

Co-authored-by: ibuler <ibuler@qq.com>
pull/8705/head
fit2bot 2022-08-05 14:53:23 +08:00 committed by GitHub
parent b27b02eb9d
commit 4ecb0b760f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 104 additions and 50 deletions

View File

@ -49,7 +49,7 @@ class JMSBaseAuthBackend:
if not allow: if not allow:
info = 'User {} skip authentication backend {}, because it not in {}' info = 'User {} skip authentication backend {}, because it not in {}'
info = info.format(username, backend_name, ','.join(allowed_backend_names)) info = info.format(username, backend_name, ','.join(allowed_backend_names))
logger.debug(info) logger.info(info)
return allow return allow

View File

@ -1,7 +1,7 @@
import base64 import base64
import logging import logging
import re
from Cryptodome.Cipher import AES, PKCS1_v1_5 from Cryptodome.Cipher import AES, PKCS1_v1_5
from Cryptodome.Util.Padding import pad
from Cryptodome.Random import get_random_bytes from Cryptodome.Random import get_random_bytes
from Cryptodome.PublicKey import RSA from Cryptodome.PublicKey import RSA
from Cryptodome import Random from Cryptodome import Random
@ -11,21 +11,25 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
def process_key(key): secret_pattern = re.compile(r'password|secret|key|token', re.IGNORECASE)
def padding_key(key, max_length=32):
""" """
返回32 bytes 的key 返回32 bytes 的key
""" """
if not isinstance(key, bytes): if not isinstance(key, bytes):
key = bytes(key, encoding='utf-8') key = bytes(key, encoding='utf-8')
if len(key) >= 32: if len(key) >= max_length:
return key[:32] return key[:max_length]
return pad(key, 32) while len(key) % 16 != 0:
key += b'\0'
return key
class BaseCrypto: class BaseCrypto:
def encrypt(self, text): def encrypt(self, text):
return base64.urlsafe_b64encode( return base64.urlsafe_b64encode(
self._encrypt(bytes(text, encoding='utf8')) self._encrypt(bytes(text, encoding='utf8'))
@ -45,7 +49,7 @@ class BaseCrypto:
class GMSM4EcbCrypto(BaseCrypto): class GMSM4EcbCrypto(BaseCrypto):
def __init__(self, key): def __init__(self, key):
self.key = process_key(key) self.key = padding_key(key, 16)
self.sm4_encryptor = CryptSM4() self.sm4_encryptor = CryptSM4()
self.sm4_encryptor.set_key(self.key, SM4_ENCRYPT) self.sm4_encryptor.set_key(self.key, SM4_ENCRYPT)
@ -70,9 +74,8 @@ class AESCrypto:
""" """
def __init__(self, key): def __init__(self, key):
if len(key) > 32: self.key = padding_key(key, 32)
key = key[:32] self.aes = AES.new(self.key, AES.MODE_ECB)
self.key = self.to_16(key)
@staticmethod @staticmethod
def to_16(key): def to_16(key):
@ -87,17 +90,15 @@ class AESCrypto:
return key # 返回bytes return key # 返回bytes
def aes(self): def aes(self):
return AES.new(self.key, AES.MODE_ECB) # 初始化加密器 return AES.new(self.key, AES.MODE_ECB)
def encrypt(self, text): def encrypt(self, text):
aes = self.aes() cipher = base64.encodebytes(self.aes.encrypt(self.to_16(text)))
cipher = base64.encodebytes(aes.encrypt(self.to_16(text)))
return str(cipher, encoding='utf8').replace('\n', '') # 加密 return str(cipher, encoding='utf8').replace('\n', '') # 加密
def decrypt(self, text): def decrypt(self, text):
aes = self.aes()
text_decoded = base64.decodebytes(bytes(text, encoding='utf8')) text_decoded = base64.decodebytes(bytes(text, encoding='utf8'))
return str(aes.decrypt(text_decoded).rstrip(b'\0').decode("utf8")) return str(self.aes.decrypt(text_decoded).rstrip(b'\0').decode("utf8"))
class AESCryptoGCM: class AESCryptoGCM:
@ -106,7 +107,7 @@ class AESCryptoGCM:
""" """
def __init__(self, key): def __init__(self, key):
self.key = process_key(key) self.key = padding_key(key)
def encrypt(self, text): def encrypt(self, text):
""" """
@ -133,7 +134,6 @@ class AESCryptoGCM:
nonce = base64.b64decode(metadata[24:48]) nonce = base64.b64decode(metadata[24:48])
tag = base64.b64decode(metadata[48:]) tag = base64.b64decode(metadata[48:])
ciphertext = base64.b64decode(text[72:]) ciphertext = base64.b64decode(text[72:])
cipher = AES.new(self.key, AES.MODE_GCM, nonce=nonce) cipher = AES.new(self.key, AES.MODE_GCM, nonce=nonce)
cipher.update(header) cipher.update(header)
@ -144,11 +144,10 @@ class AESCryptoGCM:
def get_aes_crypto(key=None, mode='GCM'): def get_aes_crypto(key=None, mode='GCM'):
if key is None: if key is None:
key = settings.SECRET_KEY key = settings.SECRET_KEY
if mode == 'ECB': if mode == 'GCM':
a = AESCrypto(key) return AESCryptoGCM(key)
elif mode == 'GCM': else:
a = AESCryptoGCM(key) return AESCrypto(key)
return a
def get_gm_sm4_ecb_crypto(key=None): def get_gm_sm4_ecb_crypto(key=None):

View File

@ -196,7 +196,8 @@ def encrypt_password(password, salt=None, algorithm='sha512'):
return des_crypt.hash(password, salt=salt[:2]) return des_crypt.hash(password, salt=salt[:2])
support_algorithm = { support_algorithm = {
'sha512': sha512, 'des': des 'sha512': sha512,
'des': des
} }
if isinstance(algorithm, str): if isinstance(algorithm, str):
@ -222,9 +223,6 @@ def ensure_last_char_is_ascii(data):
remain = '' remain = ''
secret_pattern = re.compile(r'password|secret|key', re.IGNORECASE)
def data_to_json(data, sort_keys=True, indent=2, cls=None): def data_to_json(data, sort_keys=True, indent=2, cls=None):
if cls is None: if cls is None:
cls = DjangoJSONEncoder cls = DjangoJSONEncoder

View File

@ -15,18 +15,23 @@ import errno
import json import json
import yaml import yaml
import copy import copy
import base64
import logging
from importlib import import_module from importlib import import_module
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
from django.urls import reverse_lazy from django.urls import reverse_lazy
from django.conf import settings
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
PROJECT_DIR = os.path.dirname(BASE_DIR) PROJECT_DIR = os.path.dirname(BASE_DIR)
XPACK_DIR = os.path.join(BASE_DIR, 'xpack') XPACK_DIR = os.path.join(BASE_DIR, 'xpack')
HAS_XPACK = os.path.isdir(XPACK_DIR) HAS_XPACK = os.path.isdir(XPACK_DIR)
logger = logging.getLogger('jumpserver.conf')
def import_string(dotted_path): def import_string(dotted_path):
try: try:
@ -39,9 +44,9 @@ def import_string(dotted_path):
try: try:
return getattr(module, class_name) return getattr(module, class_name)
except AttributeError as err: except AttributeError as err:
raise ImportError('Module "%s" does not define a "%s" attribute/class' % ( raise ImportError(
module_path, class_name) 'Module "%s" does not define a "%s" attribute/class' %
) from err (module_path, class_name)) from err
def is_absolute_uri(uri): def is_absolute_uri(uri):
@ -80,6 +85,59 @@ class DoesNotExist(Exception):
pass pass
class ConfigCrypto:
secret_keys = [
'SECRET_KEY', 'DB_PASSWORD', 'REDIS_PASSWORD',
]
def __init__(self, key):
self.safe_key = self.process_key(key)
self.sm4_encryptor = CryptSM4()
self.sm4_encryptor.set_key(self.safe_key, SM4_ENCRYPT)
self.sm4_decryptor = CryptSM4()
self.sm4_decryptor.set_key(self.safe_key, SM4_DECRYPT)
@staticmethod
def process_key(secret_encrypt_key):
key = secret_encrypt_key.encode()
if len(key) >= 16:
key = key[:16]
else:
key += b'\0' * (16 - len(key))
return key
def encrypt(self, data):
data = bytes(data, encoding='utf8')
return base64.b64encode(self.sm4_encryptor.crypt_ecb(data)).decode('utf8')
def decrypt(self, data):
data = base64.urlsafe_b64decode(bytes(data, encoding='utf8'))
return self.sm4_decryptor.crypt_ecb(data).decode('utf8')
def decrypt_if_need(self, value, item):
if item not in self.secret_keys:
return value
try:
plaintext = self.decrypt(value)
if plaintext:
value = plaintext
except Exception as e:
logger.error('decrypt %s error: %s', item, e)
return value
@classmethod
def get_secret_encryptor(cls):
# 使用 SM4 加密配置文件敏感信息
# https://the-x.cn/cryptography/Sm4.aspx
secret_encrypt_key = os.environ.get('SECRET_ENCRYPT_KEY', '')
if not secret_encrypt_key:
return None
print('Info: Using SM4 to encrypt config secret value')
return cls(secret_encrypt_key)
class Config(dict): class Config(dict):
"""Works exactly like a dict but provides ways to fill it from files """Works exactly like a dict but provides ways to fill it from files
or special dictionaries. There are two common patterns to populate the or special dictionaries. There are two common patterns to populate the
@ -434,6 +492,10 @@ class Config(dict):
'HEALTH_CHECK_TOKEN': '', 'HEALTH_CHECK_TOKEN': '',
} }
def __init__(self, *args):
super().__init__(*args)
self.secret_encryptor = ConfigCrypto.get_secret_encryptor()
@staticmethod @staticmethod
def convert_keycloak_to_openid(keycloak_config): def convert_keycloak_to_openid(keycloak_config):
""" """
@ -445,7 +507,6 @@ class Config(dict):
""" """
openid_config = copy.deepcopy(keycloak_config) openid_config = copy.deepcopy(keycloak_config)
auth_openid = openid_config.get('AUTH_OPENID') auth_openid = openid_config.get('AUTH_OPENID')
auth_openid_realm_name = openid_config.get('AUTH_OPENID_REALM_NAME') auth_openid_realm_name = openid_config.get('AUTH_OPENID_REALM_NAME')
auth_openid_server_url = openid_config.get('AUTH_OPENID_SERVER_URL') auth_openid_server_url = openid_config.get('AUTH_OPENID_SERVER_URL')
@ -574,13 +635,12 @@ class Config(dict):
def get(self, item): def get(self, item):
# 再从配置文件中获取 # 再从配置文件中获取
value = self.get_from_config(item) value = self.get_from_config(item)
if value is not None: if value is None:
return value value = self.get_from_env(item)
# 其次从环境变量来 if value is None:
value = self.get_from_env(item) value = self.defaults.get(item)
if value is not None: if self.secret_encryptor:
return value value = self.secret_encryptor.decrypt_if_need(value, item)
value = self.defaults.get(item)
return value return value
def __getitem__(self, item): def __getitem__(self, item):

View File

@ -316,8 +316,11 @@ PASSWORD_HASHERS = [
GMSSL_ENABLED = CONFIG.GMSSL_ENABLED GMSSL_ENABLED = CONFIG.GMSSL_ENABLED
GM_HASHER = 'common.hashers.PBKDF2SM3PasswordHasher'
if GMSSL_ENABLED: if GMSSL_ENABLED:
PASSWORD_HASHERS.insert(0, 'common.hashers.PBKDF2SM3PasswordHasher') PASSWORD_HASHERS.insert(0, GM_HASHER)
else:
PASSWORD_HASHERS.append(GM_HASHER)
# For Debug toolbar # For Debug toolbar
INTERNAL_IPS = ["127.0.0.1"] INTERNAL_IPS = ["127.0.0.1"]

View File

@ -1504,17 +1504,11 @@ function getStatusIcon(status, mapping, title) {
function fillKey(key) { function fillKey(key) {
let keySize = 128 const KeyLength = 16
// 如果超过 key 16 , 最大取 32 需要更改填充 if (key.length > KeyLength) {
if (key.length > 16) { key = key.slice(0, KeyLength)
key = key.slice(0, 32)
keySize = keySize * 2
} }
const filledKeyLength = keySize / 8 const filledKey = Buffer.alloc(KeyLength)
if (key.length >= filledKeyLength) {
return key.slice(0, filledKeyLength)
}
const filledKey = Buffer.alloc(keySize / 8)
const keys = Buffer.from(key) const keys = Buffer.from(key)
for (let i = 0; i < keys.length; i++) { for (let i = 0; i < keys.length; i++) {
filledKey[i] = keys[i] filledKey[i] = keys[i]