diff --git a/backend/app/api/cert.go b/backend/app/api/cert.go index e435b2d..1d05588 100644 --- a/backend/app/api/cert.go +++ b/backend/app/api/cert.go @@ -41,7 +41,7 @@ func UploadCert(c *gin.Context) { } form.Key = strings.TrimSpace(form.Key) form.Cert = strings.TrimSpace(form.Cert) - + if form.Key == "" { public.FailMsg(c, "名称不能为空") return @@ -50,12 +50,12 @@ func UploadCert(c *gin.Context) { public.FailMsg(c, "类型不能为空") return } - err = cert.UploadCert(form.Key, form.Cert) + sha256, err := cert.UploadCert(form.Key, form.Cert) if err != nil { public.FailMsg(c, err.Error()) return } - public.SuccessMsg(c, "添加成功") + public.SuccessData(c, sha256, 0) return } @@ -83,7 +83,7 @@ func DelCert(c *gin.Context) { func DownloadCert(c *gin.Context) { ID := c.Query("id") - + if ID == "" { public.FailMsg(c, "ID不能为空") return @@ -93,11 +93,11 @@ func DownloadCert(c *gin.Context) { public.FailMsg(c, err.Error()) return } - + // 构建 zip 包(内存中) buf := new(bytes.Buffer) zipWriter := zip.NewWriter(buf) - + for filename, content := range certData { if filename == "cert" || filename == "key" { writer, err := zipWriter.Create(filename + ".pem") @@ -118,10 +118,10 @@ func DownloadCert(c *gin.Context) { return } // 设置响应头 - + zipName := strings.ReplaceAll(certData["domains"], ".", "_") zipName = strings.ReplaceAll(zipName, ",", "-") - + c.Header("Content-Type", "application/zip") c.Header("Content-Disposition", "attachment; filename="+zipName+".zip") c.Data(200, "application/zip", buf.Bytes()) diff --git a/backend/internal/cert/apply/apply.go b/backend/internal/cert/apply/apply.go index d4c22d4..60b7f68 100644 --- a/backend/internal/cert/apply/apply.go +++ b/backend/internal/cert/apply/apply.go @@ -39,18 +39,18 @@ func GetDNSProvider(providerName string, creds map[string]string) (challenge.Pro config.SecretID = creds["secret_id"] config.SecretKey = creds["secret_key"] return tencentcloud.NewDNSProviderConfig(config) - + // case "cloudflare": // config := cloudflare.NewDefaultConfig() // config.AuthToken = creds["CLOUDFLARE_API_TOKEN"] // return cloudflare.NewDNSProviderConfig(config) - + case "aliyun": config := alidns.NewDefaultConfig() config.APIKey = creds["access_key"] config.SecretKey = creds["access_secret"] return alidns.NewDNSProviderConfig(config) - + default: return nil, fmt.Errorf("不支持的 DNS Provider: %s", providerName) } @@ -62,7 +62,7 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { return nil, err } defer db.Close() - + email, ok := cfg["email"].(string) if !ok { return nil, fmt.Errorf("参数错误:email") @@ -84,7 +84,11 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { default: return nil, fmt.Errorf("参数错误:provider_id") } - + domainArr := strings.Split(domains, ",") + for i := range domainArr { + domainArr[i] = strings.TrimSpace(domainArr[i]) + } + // 获取上次申请的证书 runId, ok := cfg["_runId"].(string) if !ok { @@ -114,11 +118,17 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { var maxDays float64 var maxItem map[string]any for i := range certs { + if !public.ContainsAllIgnoreBRepeats(strings.Split(certs[i]["domains"].(string), ","), domainArr) { + continue + } endTimeStr, ok := certs[i]["end_time"].(string) if !ok { continue } - endTime, _ := time.Parse(layout, endTimeStr) + endTime, err := time.Parse(layout, endTimeStr) + if err != nil { + continue + } diff := endTime.Sub(time.Now()).Hours() / 24 if diff > maxDays { maxDays = diff @@ -131,10 +141,10 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { if !ok || cfgEnd <= 0 { cfgEnd = 30 } - + if int(maxDays) > cfgEnd { // 证书未过期,直接返回 - logger.Debug(fmt.Sprintf("上次证书申请成功,剩余天数:%d 大于%d天,已跳过申请复用此证书", int(maxDays), cfgEnd)) + logger.Debug(fmt.Sprintf("上次证书申请成功,域名:%s,剩余天数:%d 大于%d天,已跳过申请复用此证书", certObj["domains"], int(maxDays), cfgEnd)) return map[string]any{ "cert": certObj["cert"], "key": certObj["key"], @@ -145,7 +155,7 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { } } logger.Debug("正在申请证书,域名: " + domains) - + user, err := LoadUserFromDB(db, email) if err != nil { logger.Debug("acme账号不存在,注册新账号") @@ -154,10 +164,10 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { Email: email, key: privateKey, } - + config := lego.NewConfig(user) config.Certificate.KeyType = certcrypto.EC384 - + client, err := lego.NewClient(config) if err != nil { return nil, err @@ -168,14 +178,14 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { return nil, err } user.Registration = reg - + err = SaveUserToDB(db, user) if err != nil { return nil, err } logger.Debug("账号注册并保存成功") } - + // 初始化 ACME 客户端 client, err := lego.NewClient(lego.NewConfig(user)) if err != nil { @@ -196,13 +206,13 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { if err != nil { return nil, err } - + // DNS 验证 provider, err := GetDNSProvider(providerStr, providerConfig) if err != nil { return nil, fmt.Errorf("创建 DNS provider 失败: %v", err) } - + err = client.Challenge.SetDNS01Provider(provider, dns01.WrapPreCheck(func(domain, fqdn, value string, check dns01.PreCheckFunc) (bool, error) { // 跳过预检查 @@ -215,29 +225,29 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { if err != nil { return nil, err } - + // fmt.Println(strings.Split(domains, ",")) request := certificate.ObtainRequest{ - Domains: strings.Split(domains, ","), + Domains: domainArr, Bundle: true, } certObj, err := client.Certificate.Obtain(request) if err != nil { return nil, err } - + certStr := string(certObj.Certificate) keyStr := string(certObj.PrivateKey) issuerCertStr := string(certObj.IssuerCertificate) - + // 保存证书和私钥 data := map[string]any{ "cert": certStr, "key": keyStr, "issuerCert": issuerCertStr, } - - err = cert.SaveCert("workflow", keyStr, certStr, issuerCertStr, runId) + + _, err = cert.SaveCert("workflow", keyStr, certStr, issuerCertStr, runId) if err != nil { return nil, err } diff --git a/backend/internal/cert/cert.go b/backend/internal/cert/cert.go index 979bf3a..c9e0473 100644 --- a/backend/internal/cert/cert.go +++ b/backend/internal/cert/cert.go @@ -26,7 +26,7 @@ func GetList(search string, p, limit int64) ([]map[string]any, int, error) { return data, 0, err } defer s.Close() - + var limits []int64 if p >= 0 && limit >= 0 { limits = []int64{0, limit} @@ -35,7 +35,7 @@ func GetList(search string, p, limit int64) ([]map[string]any, int, error) { limits[1] = p * limit } } - + if search != "" { count, err = s.Where("domains like ?", []interface{}{"%" + search + "%"}).Count() data, err = s.Where("domains like ?", []interface{}{"%" + search + "%"}).Limit(limits).Order("create_time", "desc").Select() @@ -80,7 +80,7 @@ func AddCert(source, key, cert, issuer, issuerCert, domains, sha256, historyId, workflowId = wh[0]["workflow_id"].(string) } } - + now := time.Now().Format("2006-01-02 15:04:05") _, err = s.Insert(map[string]any{ "source": source, @@ -104,40 +104,40 @@ func AddCert(source, key, cert, issuer, issuerCert, domains, sha256, historyId, return nil } -func SaveCert(source, key, cert, issuerCert, historyId string) error { +func SaveCert(source, key, cert, issuerCert, historyId string) (string, error) { if err := public.ValidateSSLCertificate(cert, key); err != nil { - return err + return "", err } - + certObj, err := public.ParseCertificate([]byte(cert)) if err != nil { - return fmt.Errorf("解析证书失败: %v", err) + return "", fmt.Errorf("解析证书失败: %v", err) } // SHA256 sha256, err := public.GetSHA256(cert) if err != nil { - return fmt.Errorf("获取 SHA256 失败: %v", err) + return "", fmt.Errorf("获取 SHA256 失败: %v", err) } if d, _ := GetCert(sha256); d != nil { - return nil + return sha256, nil } - + domainSet := make(map[string]bool) - + if certObj.Subject.CommonName != "" { domainSet[certObj.Subject.CommonName] = true } for _, dns := range certObj.DNSNames { domainSet[dns] = true } - + // 转成切片并拼接成逗号分隔的字符串 var domains []string for domain := range domainSet { domains = append(domains, domain) } domainList := strings.Join(domains, ",") - + // 提取 CA 名称(Issuer 的组织名) caName := "UNKNOWN" if len(certObj.Issuer.Organization) > 0 { @@ -149,20 +149,20 @@ func SaveCert(source, key, cert, issuerCert, historyId string) error { startTime := certObj.NotBefore.Format("2006-01-02 15:04:05") endTime := certObj.NotAfter.Format("2006-01-02 15:04:05") endDay := fmt.Sprintf("%d", int(certObj.NotAfter.Sub(time.Now()).Hours()/24)) - + err = AddCert(source, key, cert, caName, issuerCert, domainList, sha256, historyId, startTime, endTime, endDay) if err != nil { - return fmt.Errorf("保存证书失败: %v", err) + return "", fmt.Errorf("保存证书失败: %v", err) } - return nil + return sha256, nil } -func UploadCert(key, cert string) error { - err := SaveCert("upload", key, cert, "", "") +func UploadCert(key, cert string) (string, error) { + sha256, err := SaveCert("upload", key, cert, "", "") if err != nil { - return fmt.Errorf("保存证书失败: %v", err) + return sha256, fmt.Errorf("保存证书失败: %v", err) } - return nil + return sha256, nil } func DelCert(id string) error { @@ -171,7 +171,7 @@ func DelCert(id string) error { return err } defer s.Close() - + _, err = s.Where("id=?", []interface{}{id}).Delete() if err != nil { return err @@ -185,7 +185,7 @@ func GetCert(id string) (map[string]string, error) { return nil, err } defer s.Close() - + res, err := s.Where("id=? or sha256=?", []interface{}{id, id}).Select() if err != nil { return nil, err @@ -193,13 +193,13 @@ func GetCert(id string) (map[string]string, error) { if len(res) == 0 { return nil, fmt.Errorf("证书不存在") } - + data := map[string]string{ "domains": res[0]["domains"].(string), "cert": res[0]["cert"].(string), "key": res[0]["key"].(string), } - + return data, nil } diff --git a/backend/internal/cert/deploy/ssh.go b/backend/internal/cert/deploy/ssh.go index 12ccda3..ad4e331 100644 --- a/backend/internal/cert/deploy/ssh.go +++ b/backend/internal/cert/deploy/ssh.go @@ -133,11 +133,11 @@ func DeploySSH(cfg map[string]any) error { } beforeCmd, ok := cfg["beforeCmd"].(string) if !ok { - return fmt.Errorf("参数错误:beforeCmd") + beforeCmd = "" } afterCmd, ok := cfg["afterCmd"].(string) if !ok { - return fmt.Errorf("参数错误:afterCmd") + afterCmd = "" } providerData, err := access.GetAccess(providerID) if err != nil { diff --git a/backend/internal/workflow/executor.go b/backend/internal/workflow/executor.go index 8633bb1..0c09c55 100644 --- a/backend/internal/workflow/executor.go +++ b/backend/internal/workflow/executor.go @@ -8,6 +8,7 @@ import ( "ALLinSSL/backend/public" "errors" "fmt" + "strconv" ) // var executors map[string]func(map[string]any) (any, error) @@ -33,7 +34,7 @@ func Executors(exec string, params map[string]any) (any, error) { func apply(params map[string]any) (any, error) { logger := params["logger"].(*public.Logger) - + logger.Info("=============申请证书=============") certificate, err := certApply.Apply(params, logger) if err != nil { @@ -67,28 +68,54 @@ func deploy(params map[string]any) (any, error) { func upload(params map[string]any) (any, error) { logger := params["logger"].(*public.Logger) logger.Info("=============上传证书=============") - - keyStr, ok := params["key"].(string) - if !ok { - logger.Error("上传的密钥有误") - logger.Info("=============上传失败=============") - return nil, errors.New("上传的密钥有误") + // 判断证书id走本地还是走旧上传,应在之后的迭代中移除旧代码 + if params["cert_id"] == nil { + keyStr, ok := params["key"].(string) + if !ok { + logger.Error("上传的密钥有误") + logger.Info("=============上传失败=============") + return nil, errors.New("上传的密钥有误") + } + certStr, ok := params["cert"].(string) + if !ok { + logger.Error("上传的证书有误") + logger.Info("=============上传失败=============") + return nil, errors.New("上传的证书有误") + } + _, err := cert.UploadCert(keyStr, certStr) + if err != nil { + logger.Error(err.Error()) + logger.Info("=============上传失败=============") + return nil, err + } + logger.Info("=============上传成功=============") + + return params, nil + } else { + certId := "" + switch v := params["cert_id"].(type) { + case float64: + certId = strconv.Itoa(int(v)) + case string: + certId = v + default: + logger.Info("=============上传证书获取失败=============") + return nil, errors.New("证书 ID 类型错误") + } + certObj, err := cert.GetCert(certId) + if err != nil { + logger.Error(err.Error()) + logger.Info("=============上传证书获取失败=============") + return nil, err + } + if certObj == nil { + logger.Error("证书不存在") + logger.Info("=============上传证书获取失败=============") + return nil, errors.New("证书不存在") + } + logger.Debug(fmt.Sprintf("证书 ID: %s", certId)) + return certObj, nil } - certStr, ok := params["cert"].(string) - if !ok { - logger.Error("上传的证书有误") - logger.Info("=============上传失败=============") - return nil, errors.New("上传的证书有误") - } - err := cert.UploadCert(keyStr, certStr) - if err != nil { - logger.Error(err.Error()) - logger.Info("=============上传失败=============") - return nil, err - } - logger.Info("=============上传成功=============") - - return params, nil } func notify(params map[string]any) (any, error) { diff --git a/backend/public/utils.go b/backend/public/utils.go index a60431e..811b86e 100644 --- a/backend/public/utils.go +++ b/backend/public/utils.go @@ -179,3 +179,24 @@ func GetPublicIP() (string, error) { return string(body), nil } + +func ContainsAllIgnoreBRepeats(a, b []string) bool { + // 构建 A 的集合 + setA := make(map[string]struct{}) + for _, item := range a { + setA[item] = struct{}{} + } + + // 遍历 B 的唯一元素,判断是否在 A 中 + seen := make(map[string]struct{}) + for _, item := range b { + if _, checked := seen[item]; checked { + continue + } + seen[item] = struct{}{} + if _, ok := setA[item]; !ok { + return false + } + } + return true +}