feat: 支持 piico 设备国密加密

pull/8800/head
Aaron3S 2022-08-23 17:40:01 +08:00
parent 60cb1f8136
commit 8772cd8c71
11 changed files with 450 additions and 2 deletions

View File

View File

@ -0,0 +1,7 @@
from .device import Device
def open_piico_device(driver_path) -> Device:
d = Device()
d.open(driver_path)
return d

View File

@ -0,0 +1,59 @@
cipher_alg_id = {
"sm4_ebc": 0x00000401,
"sm4_cbc": 0x00000402,
}
class ECCCipher:
def __init__(self, session, public_key, private_key):
self._session = session
self.public_key = public_key
self.private_key = private_key
def encrypt(self, plain_text):
return self._session.ecc_encrypt(self.public_key, plain_text, 0x00020800)
def decrypt(self, cipher_text):
return self._session.ecc_decrypt(self.private_key, cipher_text, 0x00020800)
class EBCCipher:
def __init__(self, session, key_val):
self._session = session
self._key = self.__get_key(key_val)
self._alg = "sm4_ebc"
self._iv = None
def __get_key(self, key_val):
key_val = self.__padding(key_val)
return self._session.import_key(key_val)
@staticmethod
def __padding(val):
# padding
val = bytes(val)
while len(val) % 16 != 0:
val += b'\0'
return val
def encrypt(self, plain_text):
plain_text = self.__padding(plain_text)
cipher_text = self._session.encrypt(plain_text, self._key, cipher_alg_id[self._alg], self._iv)
return bytes(cipher_text)
def decrypt(self, cipher_text):
plain_text = self._session.decrypt(cipher_text, self._key, cipher_alg_id[self._alg], self._iv)
return bytes(plain_text)
def destroy(self):
self._session.destroy_cipher_key(self._key)
self._session.close()
class CBCCipher(EBCCipher):
def __init__(self, session, key, iv):
super().__init__(session, key)
self._iv = iv
self._alg = "sm4_cbc"

View File

@ -0,0 +1,70 @@
from ctypes import *
from .exception import PiicoError
from .session import Session
from .cipher import *
from .digest import *
class Device:
_driver = None
__device = None
def open(self, driver_path="./libpiico_ccmu"):
# load driver
self.__load_driver(driver_path)
# open device
self.__open_device()
def close(self):
if self.__device is None:
raise Exception("device not turned on")
ret = self._driver.SDF_CloseDevice(self.__device)
if not ret == 0:
raise Exception("turn off device failed")
self.__device = None
def new_session(self):
session = c_void_p()
ret = self._driver.SDF_OpenSession(self.__device, pointer(session))
if not ret == 0:
raise Exception("create session failed")
return Session(self._driver, session)
def generate_ecc_key_pair(self):
session = self.new_session()
return session.generate_ecc_key_pair(alg_id=0x00020200)
def generate_random(self, length=64):
session = self.new_session()
return session.generate_random(length)
def new_sm2_ecc_cipher(self, public_key, private_key):
session = self.new_session()
return ECCCipher(session, public_key, private_key)
def new_sm4_ebc_cipher(self, key_val):
session = self.new_session()
return EBCCipher(session, key_val)
def new_sm4_cbc_cipher(self, key_val, iv):
session = self.new_session()
return CBCCipher(session, key_val, iv)
def new_digest(self, mode="sm3"):
session = self.new_session()
return Digest(session, mode)
def __load_driver(self, path):
# check driver status
if self._driver is not None:
raise Exception("already load driver")
# load driver
self._driver = cdll.LoadLibrary(path)
def __open_device(self):
device = c_void_p()
ret = self._driver.SDF_OpenDevice(pointer(device))
if not ret == 0:
raise PiicoError("open piico device failed", ret)
self.__device = device

View File

@ -0,0 +1,32 @@
hash_alg_id = {
"sm3": 0x00000001,
"sha1": 0x00000002,
"sha256": 0x00000004,
"sha512": 0x00000008,
}
class Digest:
def __init__(self, session, alg_name="sm3"):
if hash_alg_id[alg_name] is None:
raise Exception("unsupported hash alg {}".format(alg_name))
self._alg_name = alg_name
self._session = session
self.__init_hash()
def __init_hash(self):
self._session.hash_init(hash_alg_id[self._alg_name])
def update(self, data):
self._session.hash_update(data)
def final(self):
return self._session.hash_final()
def reset(self):
self.__init_hash()
def destroy(self):
self._session.close()

View File

@ -0,0 +1,71 @@
from ctypes import *
ECCref_MAX_BITS = 512
ECCref_MAX_LEN = int((ECCref_MAX_BITS + 7) / 8)
class EncodeMixin:
def encode(self):
raise NotImplementedError
class ECCrefPublicKey(Structure, EncodeMixin):
_fields_ = [
('bits', c_uint),
('x', c_ubyte * ECCref_MAX_LEN),
('y', c_ubyte * ECCref_MAX_LEN),
]
def encode(self):
return bytes([0x04]) + bytes(self.x[32:]) + bytes(self.y[32:])
class ECCrefPrivateKey(Structure, EncodeMixin):
_fields_ = [
('bits', c_uint,),
('K', c_ubyte * ECCref_MAX_LEN),
]
def encode(self):
return bytes(self.K[32:])
class ECCCipherEncode(EncodeMixin):
def __init__(self):
self.x = None
self.y = None
self.M = None
self.C = None
self.L = None
def encode(self):
c1 = bytes(self.x[32:]) + bytes(self.y[32:])
c2 = bytes(self.C[:self.L])
c3 = bytes(self.M)
return bytes([0x04]) + c1 + c2 + c3
def new_ecc_cipher_cla(length):
_cache = {}
cla_name = "ECCCipher{}".format(length)
if _cache.__contains__(cla_name):
return _cache[cla_name]
else:
cla = type(cla_name, (Structure, ECCCipherEncode), {
"_fields_": [
('x', c_ubyte * ECCref_MAX_LEN),
('y', c_ubyte * ECCref_MAX_LEN),
('M', c_ubyte * 32),
('L', c_uint),
('C', c_ubyte * length)
]
})
_cache[cla_name] = cla
return cla
class ECCKeyPair:
def __init__(self, public_key, private_key):
self.public_key = public_key
self.private_key = private_key

View File

@ -0,0 +1,12 @@
class PiicoError(Exception):
def __init__(self, msg, ret):
super().__init__(self)
self.__ret = ret
self.__msg = msg
def __str__(self):
return "piico error: {} return code: {}".format(self.__msg, self.hex_ret(self.__ret))
@staticmethod
def hex_ret(ret):
return hex(ret & ((1 << 32) - 1))

View File

@ -0,0 +1,36 @@
from ctypes import *
from .ecc import ECCrefPublicKey, ECCrefPrivateKey, ECCKeyPair
from .exception import PiicoError
from .session_mixin import SM3Mixin, SM4Mixin, SM2Mixin
class Session(SM2Mixin, SM3Mixin, SM4Mixin):
def __init__(self, driver, session):
super().__init__()
self._session = session
self._driver = driver
def get_device_info(self):
pass
def generate_random(self, length=64):
random_data = (c_ubyte * length)()
ret = self._driver.SDF_GenerateRandom(self._session, c_int(length), random_data)
if not ret == 0:
raise PiicoError("generate random error", ret)
return bytes(random_data)
def generate_ecc_key_pair(self, alg_id):
public_key = ECCrefPublicKey()
private_key = ECCrefPrivateKey()
ret = self._driver.SDF_GenerateKeyPair_ECC(self._session, c_int(alg_id), c_int(256), pointer(public_key),
pointer(private_key))
if not ret == 0:
raise PiicoError("generate ecc key pair failed", ret)
return ECCKeyPair(public_key.encode(), private_key.encode())
def close(self):
ret = self._driver.SDF_CloseSession(self._session)
if not ret == 0:
raise PiicoError("close session failed", ret)

View File

@ -0,0 +1,129 @@
from .ecc import *
from .exception import PiicoError
class BaseMixin:
def __init__(self):
self._driver = None
self._session = None
class SM2Mixin(BaseMixin):
def ecc_encrypt(self, public_key, plain_text, alg_id):
pos = 1
k1 = bytes([0] * 32) + bytes(public_key[pos:pos + 32])
k1 = (c_ubyte * len(k1))(*k1)
pos += 32
k2 = bytes([0] * 32) + bytes(public_key[pos:pos + 32])
pk = ECCrefPublicKey(c_uint(0x40), (c_ubyte * len(k1))(*k1), (c_ubyte * len(k2))(*k2))
plain_text = (c_ubyte * len(plain_text))(*plain_text)
ecc_data = new_ecc_cipher_cla(len(plain_text))()
ret = self._driver.SDF_ExternalEncrypt_ECC(self._session, c_int(alg_id), pointer(pk), plain_text,
c_int(len(plain_text)), pointer(ecc_data))
if not ret == 0:
raise Exception("ecc encrypt failed", ret)
return ecc_data.encode()
def ecc_decrypt(self, private_key, cipher_text, alg_id):
k = bytes([0] * 32) + bytes(private_key[:32])
vk = ECCrefPrivateKey(c_uint(0x40), (c_ubyte * len(k))(*k))
pos = 1
# c1
x = bytes([0] * 32) + bytes(cipher_text[pos:pos + 32])
pos += 32
y = bytes([0] * 32) + bytes(cipher_text[pos:pos + 32])
pos += 32
# c2
c = bytes(cipher_text[pos:-32])
l = len(c)
# c3
m = bytes(cipher_text[-32:])
ecc_data = new_ecc_cipher_cla(l)(
(c_ubyte * 64)(*x),
(c_ubyte * 64)(*y),
(c_ubyte * 32)(*m),
c_uint(l),
(c_ubyte * l)(*c),
)
temp_data = (c_ubyte * l)()
temp_data_length = c_int()
ret = self._driver.SDF_ExternalDecrypt_ECC(self._session, c_int(alg_id), pointer(vk),
pointer(ecc_data),
temp_data, pointer(temp_data_length))
if not ret == 0:
raise Exception("ecc decrypt failed", ret)
return bytes(temp_data[:temp_data_length.value])
class SM3Mixin(BaseMixin):
def hash_init(self, alg_id):
ret = self._driver.SDF_HashInit(self._session, c_int(alg_id), None, None, c_int(0))
if not ret == 0:
raise PiicoError("hash init failed,alg id is {}".format(alg_id), ret)
def hash_update(self, data):
data = (c_ubyte * len(data))(*data)
ret = self._driver.SDF_HashUpdate(self._session, data, c_int(len(data)))
if not ret == 0:
raise PiicoError("hash update failed", ret)
def hash_final(self):
result_data = (c_ubyte * 32)()
result_length = c_int()
ret = self._driver.SDF_HashFinal(self._session, result_data, pointer(result_length))
if not ret == 0:
raise PiicoError("hash final failed", ret)
return bytes(result_data[:result_length.value])
class SM4Mixin(BaseMixin):
def import_key(self, key_val):
# to c lang
key_val = (c_ubyte * len(key_val))(*key_val)
key = c_void_p()
ret = self._driver.SDF_ImportKey(self._session, key_val, c_int(len(key_val)), pointer(key))
if not ret == 0:
raise PiicoError("import key failed", ret)
return key
def destroy_cipher_key(self, key):
ret = self._driver.SDF_DestroyKey(self._session, key)
if not ret == 0:
raise Exception("destroy key failed")
def encrypt(self, plain_text, key, alg, iv=None):
return self.__do_cipher_action(plain_text, key, alg, iv, True)
def decrypt(self, cipher_text, key, alg, iv=None):
return self.__do_cipher_action(cipher_text, key, alg, iv, False)
def __do_cipher_action(self, text, key, alg, iv=None, encrypt=True):
text = (c_ubyte * len(text))(*text)
if iv is not None:
iv = (c_ubyte * len(iv))(*iv)
temp_data = (c_ubyte * len(text))()
temp_data_length = c_int()
if encrypt:
ret = self._driver.SDF_Encrypt(self._session, key, c_int(alg), iv, text, c_int(len(text)), temp_data,
pointer(temp_data_length))
if not ret == 0:
raise PiicoError("encrypt failed", ret)
else:
ret = self._driver.SDF_Decrypt(self._session, key, c_int(alg), iv, text, c_int(len(text)), temp_data,
pointer(temp_data_length))
if not ret == 0:
raise PiicoError("decrypt failed", ret)
return temp_data[:temp_data_length.value]

View File

View File

@ -1,6 +1,7 @@
import base64
import logging
import re
from Cryptodome.Cipher import AES, PKCS1_v1_5
from Cryptodome.Random import get_random_bytes
from Cryptodome.PublicKey import RSA
@ -11,6 +12,7 @@ from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from common.sdk.gm import piico
secret_pattern = re.compile(r'password|secret|key|token', re.IGNORECASE)
@ -64,6 +66,25 @@ class GMSM4EcbCrypto(BaseCrypto):
return self.sm4_decryptor.crypt_ecb(data)
class PiicoSM4EcbCrypto(BaseCrypto):
@staticmethod
def to_16(key):
while len(key) % 16 != 0:
key += b'\0'
return key # 返回bytes
def __init__(self, key, device: piico.Device):
key = padding_key(key, 16)
self.cipher = device.new_sm4_ebc_cipher(key)
def _encrypt(self, data: bytes) -> bytes:
return self.cipher.encrypt(self.to_16(data))
def _decrypt(self, data: bytes) -> bytes:
return self.cipher.decrypt(data)
class AESCrypto:
"""
AES
@ -164,6 +185,11 @@ def get_gm_sm4_ecb_crypto(key=None):
return GMSM4EcbCrypto(key)
def get_piico_gm_sm4_ecb_crypto(device, key=None):
key = key or settings.SECRET_KEY
return PiicoSM4EcbCrypto(key, device)
aes_ecb_crypto = get_aes_crypto(mode='ECB')
aes_crypto = get_aes_crypto(mode='GCM')
gm_sm4_ecb_crypto = get_gm_sm4_ecb_crypto()
@ -183,10 +209,16 @@ class Crypto:
crypt_algo = settings.SECURITY_DATA_CRYPTO_ALGO
if not crypt_algo:
if settings.GMSSL_ENABLED:
if settings.PIICO_DEVICE_ENABLE:
piico_driver_path = settings.PIICO_DRIVER_PATH if settings.PIICO_DRIVER_PATH \
else "./lib/libpiico_ccmu.so"
device = piico.open_piico_device(piico_driver_path)
self.cryptor_map["piico_gm"] = get_piico_gm_sm4_ecb_crypto(device)
crypt_algo = 'piico_gm'
else:
crypt_algo = 'gm'
else:
crypt_algo = 'aes'
cryptor = self.cryptor_map.get(crypt_algo, None)
if cryptor is None:
raise ImproperlyConfigured(