1.新增eab列表

2.申请证书新增http代理、新增ca选择(zerossl、google)、新增证书算法选择
3.修复数据库连接内存泄漏
pull/117/head
zhangchenhao 2025-05-21 11:31:36 +08:00
parent eb302776a8
commit f64d2b2764
19 changed files with 489 additions and 210 deletions

View File

@ -135,6 +135,148 @@ func DelAccess(c *gin.Context) {
return return
} }
func GetAllEAB(c *gin.Context) {
var form struct {
CA string `form:"ca"`
}
err := c.Bind(&form)
if err != nil {
public.FailMsg(c, err.Error())
return
}
eabList, err := access.GetAllEAB(form.CA)
if err != nil {
public.FailMsg(c, err.Error())
return
}
public.SuccessData(c, eabList, 0)
return
}
func GetEABList(c *gin.Context) {
var form struct {
Search string `form:"search"`
Page int64 `form:"p"`
Limit int64 `form:"limit"`
}
err := c.Bind(&form)
if err != nil {
public.FailMsg(c, err.Error())
return
}
eabList, count, err := access.GetEABList(form.Search, form.Page, form.Limit)
if err != nil {
public.FailMsg(c, err.Error())
return
}
public.SuccessData(c, eabList, count)
return
}
func AddEAB(c *gin.Context) {
var form struct {
Name string `form:"name"`
Kid string `form:"Kid"`
HmacEncoded string `form:"HmacEncoded"`
CA string `form:"ca"`
}
err := c.Bind(&form)
if err != nil {
public.FailMsg(c, err.Error())
return
}
form.Name = strings.TrimSpace(form.Name)
form.Kid = strings.TrimSpace(form.Kid)
form.HmacEncoded = strings.TrimSpace(form.HmacEncoded)
form.CA = strings.TrimSpace(form.CA)
if form.Name == "" {
public.FailMsg(c, "名称不能为空")
return
}
if form.Kid == "" {
public.FailMsg(c, "ID不能为空")
return
}
if form.HmacEncoded == "" {
public.FailMsg(c, "HmacEncoded不能为空")
return
}
if form.CA == "" {
public.FailMsg(c, "CA不能为空")
return
}
err = access.AddEAB(form.Name, form.Kid, form.HmacEncoded, form.CA)
if err != nil {
public.FailMsg(c, err.Error())
}
public.SuccessMsg(c, "添加成功")
return
}
func UpdEAB(c *gin.Context) {
var form struct {
ID string `form:"id"`
Name string `form:"name"`
Kid string `form:"Kid"`
HmacEncoded string `form:"HmacEncoded"`
CA string `form:"ca"`
}
err := c.Bind(&form)
if err != nil {
public.FailMsg(c, err.Error())
return
}
form.Name = strings.TrimSpace(form.Name)
form.Kid = strings.TrimSpace(form.Kid)
form.HmacEncoded = strings.TrimSpace(form.HmacEncoded)
form.CA = strings.TrimSpace(form.CA)
if form.Name == "" {
public.FailMsg(c, "名称不能为空")
return
}
if form.Kid == "" {
public.FailMsg(c, "ID不能为空")
return
}
if form.HmacEncoded == "" {
public.FailMsg(c, "HmacEncoded不能为空")
return
}
if form.CA == "" {
public.FailMsg(c, "CA不能为空")
return
}
err = access.UpdEAB(form.ID, form.Name, form.Kid, form.HmacEncoded, form.CA)
if err != nil {
public.FailMsg(c, err.Error())
}
public.SuccessMsg(c, "修改成功")
return
}
func DelEAB(c *gin.Context) {
var form struct {
ID string `form:"id"`
}
err := c.Bind(&form)
if err != nil {
public.FailMsg(c, err.Error())
return
}
form.ID = strings.TrimSpace(form.ID)
if form.ID == "" {
public.FailMsg(c, "ID不能为空")
return
}
err = access.DelEAB(form.ID)
if err != nil {
public.FailMsg(c, err.Error())
return
}
public.SuccessMsg(c, "删除成功")
return
}
func TestAccess(c *gin.Context) { func TestAccess(c *gin.Context) {
var form struct { var form struct {
ID string `form:"id"` ID string `form:"id"`
@ -149,7 +291,7 @@ func TestAccess(c *gin.Context) {
public.FailMsg(c, "类型不能为空") public.FailMsg(c, "类型不能为空")
return return
} }
var result error var result error
switch form.Type { switch form.Type {
case "btwaf": case "btwaf":
@ -171,12 +313,12 @@ func TestAccess(c *gin.Context) {
default: default:
public.FailMsg(c, "不支持测试的提供商") public.FailMsg(c, "不支持测试的提供商")
} }
if result != nil { if result != nil {
public.FailMsg(c, result.Error()) public.FailMsg(c, result.Error())
return return
} }
public.SuccessMsg(c, "请求测试成功!") public.SuccessMsg(c, "请求测试成功!")
return return
} }

View File

@ -32,7 +32,6 @@ func Sign(c *gin.Context) {
public.FailMsg(c, err.Error()) public.FailMsg(c, err.Error())
return return
} }
s.Connect()
defer s.Close() defer s.Close()
s.TableName = "users" s.TableName = "users"
res, err := s.Where("username=?", []interface{}{form.Username}).Select() res, err := s.Where("username=?", []interface{}{form.Username}).Select()

View File

@ -12,7 +12,6 @@ func GetSqlite() (*public.Sqlite, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
s.TableName = "access" s.TableName = "access"
return s, nil return s, nil
} }

View File

@ -9,7 +9,6 @@ func GetSqliteAT() (*public.Sqlite, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
s.TableName = "access_type" s.TableName = "access_type"
return s, nil return s, nil
} }

View File

@ -29,7 +29,7 @@ func (u *MyUser) GetPrivateKey() crypto.PrivateKey {
return u.key return u.key
} }
func SaveUserToDB(db *public.Sqlite, user *MyUser) error { func SaveUserToDB(db *public.Sqlite, user *MyUser, Type string) error {
keyBytes, err := x509.MarshalPKCS8PrivateKey(user.key) keyBytes, err := x509.MarshalPKCS8PrivateKey(user.key)
if err != nil { if err != nil {
return err return err
@ -53,13 +53,13 @@ func SaveUserToDB(db *public.Sqlite, user *MyUser) error {
"reg": regBytes, "reg": regBytes,
"create_time": now, "create_time": now,
"update_time": now, "update_time": now,
"type": "Let's Encrypt", "type": Type,
}) })
return err return err
} }
func LoadUserFromDB(db *public.Sqlite, email string) (*MyUser, error) { func LoadUserFromDB(db *public.Sqlite, email string, Type string) (*MyUser, error) {
data, err := db.Where(`email=?`, []interface{}{email}).Select() data, err := db.Where(`email=? and type=?`, []interface{}{email, Type}).Select()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,17 +22,33 @@ import (
"github.com/go-acme/lego/v4/providers/dns/volcengine" "github.com/go-acme/lego/v4/providers/dns/volcengine"
"github.com/go-acme/lego/v4/providers/dns/westcn" "github.com/go-acme/lego/v4/providers/dns/westcn"
"github.com/go-acme/lego/v4/registration" "github.com/go-acme/lego/v4/registration"
"net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
var AlgorithmMap = map[string]certcrypto.KeyType{
"RSA2048": certcrypto.RSA2048,
"RSA3072": certcrypto.RSA3072,
"RSA4096": certcrypto.RSA4096,
"RSA8192": certcrypto.RSA8192,
"EC256": certcrypto.EC256,
"EC384": certcrypto.EC384,
}
var CADirURLMap = map[string]string{
"Let's Encrypt": "https://acme-v02.api.letsencrypt.org/directory",
"zerossl": "https://acme.zerossl.com/v2/DV90",
"google": "https://dv.acme-v02.api.pki.goog/directory",
}
func GetSqlite() (*public.Sqlite, error) { func GetSqlite() (*public.Sqlite, error) {
s, err := public.NewSqlite("data/data.db", "") s, err := public.NewSqlite("data/data.db", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
s.TableName = "_accounts" s.TableName = "_accounts"
return s, nil return s, nil
} }
@ -77,17 +93,179 @@ func GetDNSProvider(providerName string, creds map[string]string) (challenge.Pro
config.SecretKey = creds["secret_key"] config.SecretKey = creds["secret_key"]
return volcengine.NewDNSProviderConfig(config) return volcengine.NewDNSProviderConfig(config)
// case "godaddy":
// config := godaddy.NewDefaultConfig()
// config.APIKey = creds["api_key"]
// config.APISecret = creds["api_secret"]
// return godaddy.NewDNSProviderConfig(config)
default: default:
return nil, fmt.Errorf("不支持的 DNS Provider: %s", providerName) return nil, fmt.Errorf("不支持的 DNS Provider: %s", providerName)
} }
} }
func GetAcmeClient(db *public.Sqlite, email, algorithm, ca, proxy, eabId string, logger *public.Logger) (*lego.Client, error) {
user, err := LoadUserFromDB(db, email, ca)
if err != nil {
logger.Debug("acme账号不存在注册新账号")
privateKey, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
user = &MyUser{
Email: email,
key: privateKey,
}
config := lego.NewConfig(user)
config.Certificate.KeyType = AlgorithmMap[algorithm]
config.CADirURL = CADirURLMap[ca]
if proxy != "" {
// 构建代理 HTTP 客户端
proxyURL, err := url.Parse(proxy) // 替换为你的代理地址
if err != nil {
return nil, fmt.Errorf("无效的代理地址: %v", err)
}
httpClient := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
Timeout: 30 * time.Second,
}
config.HTTPClient = httpClient
}
client, err := lego.NewClient(config)
if err != nil {
return nil, err
}
logger.Debug("正在注册账号:" + email)
var reg *registration.Resource
switch ca {
case "Let's Encrypt":
reg, err = client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
case "zerossl", "google":
// 获取EAB参数
var eabData map[string]any
if eabId == "" {
data, err := access.GetAllEAB(ca)
if err != nil {
return nil, err
}
if len(data) <= 0 {
return nil, fmt.Errorf("未找到EAB信息")
}
eabData = data[0]
} else {
eabData, err = access.GetEAB(eabId)
if err != nil {
return nil, err
}
if eabData == nil {
return nil, fmt.Errorf("未找到EAB信息")
}
}
Kid := eabData["kid"].(string)
HmacEncoded := eabData["HmacEncoded"].(string)
reg, err = client.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
TermsOfServiceAgreed: true,
Kid: Kid,
HmacEncoded: HmacEncoded,
})
default:
reg, err = client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
}
if err != nil {
return nil, err
}
user.Registration = reg
err = SaveUserToDB(db, user, ca)
if err != nil {
return nil, err
}
logger.Debug("acme账号注册并保存成功")
return client, nil
} else {
config := lego.NewConfig(user)
config.Certificate.KeyType = AlgorithmMap[algorithm]
config.CADirURL = CADirURLMap[ca]
if proxy != "" {
// 构建代理 HTTP 客户端
proxyURL, err := url.Parse(proxy) // 替换为你的代理地址
if err != nil {
return nil, fmt.Errorf("无效的代理地址: %v", err)
}
httpClient := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
Timeout: 30 * time.Second,
}
config.HTTPClient = httpClient
}
// 初始化 ACME 客户端
client, err := lego.NewClient(config)
if err != nil {
return nil, err
}
return client, nil
}
}
func GetCert(runId string, domainArr []string, endDay int, logger *public.Logger) (map[string]any, error) {
if runId == "" {
return nil, fmt.Errorf("参数错误_runId")
}
s, err := public.NewSqlite("data/data.db", "")
if err != nil {
return nil, err
}
s.TableName = "workflow_history"
defer s.Close()
// 查询 workflowId
wh, err := s.Where("id=?", []interface{}{runId}).Select()
if err != nil {
return nil, err
}
if len(wh) <= 0 {
return nil, fmt.Errorf("未获取到对应的workflowId")
}
s.TableName = "cert"
certs, err := s.Where("workflow_id=?", []interface{}{wh[0]["workflow_id"]}).Select()
if err != nil {
return nil, err
}
if len(certs) <= 0 {
return nil, fmt.Errorf("未获取到当前工作流下的证书")
}
layout := "2006-01-02 15:04:05"
var maxDays float64
var maxItem map[string]any
for i := range certs {
if !public.ContainsAllIgnoreBRepeats(strings.Split(certs[i]["domains"].(string), ","), domainArr) {
continue
}
endTimeStr, ok := certs[i]["end_time"].(string)
if !ok {
continue
}
endTime, err := time.Parse(layout, endTimeStr)
if err != nil {
continue
}
diff := endTime.Sub(time.Now()).Hours() / 24
if diff > maxDays {
maxDays = diff
maxItem = certs[i]
}
}
if maxItem == nil {
return nil, fmt.Errorf("未获取到对应的证书")
}
if int(maxDays) <= endDay {
return nil, fmt.Errorf("证书已过期或即将过期,剩余天数:%d 小于%d天", int(maxDays), endDay)
}
// 证书未过期,直接返回
logger.Debug(fmt.Sprintf("上次证书申请成功,域名:%s剩余天数%d 大于%d天已跳过申请复用此证书", maxItem["domains"], int(maxDays), endDay))
return map[string]any{
"cert": maxItem["cert"],
"key": maxItem["key"],
"issuerCert": maxItem["issuer_cert"],
}, nil
}
func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) { func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) {
db, err := GetSqlite() db, err := GetSqlite()
if err != nil { if err != nil {
@ -107,6 +285,44 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) {
if !ok { if !ok {
return nil, fmt.Errorf("参数错误provider") return nil, fmt.Errorf("参数错误provider")
} }
endDay := 30
switch v := cfg["end_day"].(type) {
case float64:
endDay = int(v)
case int:
endDay = v
case string:
if v != "" {
endDay, err = strconv.Atoi(v)
if err != nil {
return nil, fmt.Errorf("参数错误end_day")
}
}
case int64:
endDay = int(v)
}
algorithm, ok := cfg["algorithm"].(string)
if !ok {
algorithm = "RSA2048"
}
ca, ok := cfg["ca"].(string)
if !ok {
ca = "Let's Encrypt"
}
proxy, ok := cfg["proxy"].(string)
if !ok {
proxy = ""
}
var eabId string
switch v := cfg["eabId"].(type) {
case float64:
eabId = strconv.Itoa(int(v))
case string:
eabId = v
default:
eabId = ""
}
var providerID string var providerID string
switch v := cfg["provider_id"].(type) { switch v := cfg["provider_id"].(type) {
case float64: case float64:
@ -178,100 +394,15 @@ func Apply(cfg map[string]any, logger *public.Logger) (map[string]any, error) {
if !ok { if !ok {
return nil, fmt.Errorf("参数错误_runId") return nil, fmt.Errorf("参数错误_runId")
} }
if runId != "" { certData, err := GetCert(runId, domainArr, endDay, logger)
s, err := public.NewSqlite("data/data.db", "") if err != nil {
if err != nil { logger.Debug("未获取到符合条件的本地证书:" + err.Error())
return nil, err } else {
} return certData, nil
s.Connect()
s.TableName = "workflow_history"
defer s.Close()
// 查询 workflowId
wh, err := s.Where("id=?", []interface{}{runId}).Select()
if err != nil {
return nil, err
}
if len(wh) > 0 {
s.TableName = "cert"
certs, err := s.Where("workflow_id=?", []interface{}{wh[0]["workflow_id"]}).Select()
if err != nil {
return nil, err
}
if len(certs) > 0 {
layout := "2006-01-02 15:04:05"
var maxDays float64
var maxItem map[string]any
for i := range certs {
if !public.ContainsAllIgnoreBRepeats(strings.Split(certs[i]["domains"].(string), ","), domainArr) {
continue
}
endTimeStr, ok := certs[i]["end_time"].(string)
if !ok {
continue
}
endTime, err := time.Parse(layout, endTimeStr)
if err != nil {
continue
}
diff := endTime.Sub(time.Now()).Hours() / 24
if diff > maxDays {
maxDays = diff
maxItem = certs[i]
}
}
certObj := maxItem
// 判断证书是否过期
cfgEnd, ok := cfg["end_day"].(int)
if !ok || cfgEnd <= 0 {
cfgEnd = 30
}
if int(maxDays) > cfgEnd {
// 证书未过期,直接返回
logger.Debug(fmt.Sprintf("上次证书申请成功,域名:%s剩余天数%d 大于%d天已跳过申请复用此证书", certObj["domains"], int(maxDays), cfgEnd))
return map[string]any{
"cert": certObj["cert"],
"key": certObj["key"],
"issuerCert": certObj["issuer_cert"],
}, nil
}
}
}
} }
logger.Debug("正在申请证书,域名: " + domains) logger.Debug("正在申请证书,域名: " + domains)
// 创建 ACME 客户端
user, err := LoadUserFromDB(db, email) client, err := GetAcmeClient(db, email, algorithm, ca, proxy, eabId, logger)
if err != nil {
logger.Debug("acme账号不存在注册新账号")
privateKey, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
user = &MyUser{
Email: email,
key: privateKey,
}
config := lego.NewConfig(user)
config.Certificate.KeyType = certcrypto.EC384
client, err := lego.NewClient(config)
if err != nil {
return nil, err
}
logger.Debug("正在注册账号:" + email)
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
return nil, err
}
user.Registration = reg
err = SaveUserToDB(db, user)
if err != nil {
return nil, err
}
logger.Debug("账号注册并保存成功")
}
// 初始化 ACME 客户端
client, err := lego.NewClient(lego.NewConfig(user))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -13,7 +13,6 @@ func GetSqlite() (*public.Sqlite, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
s.TableName = "cert" s.TableName = "cert"
return s, nil return s, nil
} }
@ -26,7 +25,7 @@ func GetList(search string, p, limit int64) ([]map[string]any, int, error) {
return data, 0, err return data, 0, err
} }
defer s.Close() defer s.Close()
var limits []int64 var limits []int64
if p >= 0 && limit >= 0 { if p >= 0 && limit >= 0 {
limits = []int64{0, limit} limits = []int64{0, limit}
@ -35,7 +34,7 @@ func GetList(search string, p, limit int64) ([]map[string]any, int, error) {
limits[1] = p * limit limits[1] = p * limit
} }
} }
if search != "" { if search != "" {
count, err = s.Where("domains like ?", []interface{}{"%" + search + "%"}).Count() count, err = s.Where("domains like ?", []interface{}{"%" + search + "%"}).Count()
data, err = s.Where("domains like ?", []interface{}{"%" + search + "%"}).Limit(limits).Order("create_time", "desc").Select() data, err = s.Where("domains like ?", []interface{}{"%" + search + "%"}).Limit(limits).Order("create_time", "desc").Select()
@ -68,7 +67,6 @@ func AddCert(source, key, cert, issuer, issuerCert, domains, sha256, historyId,
if err != nil { if err != nil {
return err return err
} }
s.Connect()
s.TableName = "workflow_history" s.TableName = "workflow_history"
defer s.Close() defer s.Close()
// 查询 workflowId // 查询 workflowId
@ -80,7 +78,7 @@ func AddCert(source, key, cert, issuer, issuerCert, domains, sha256, historyId,
workflowId = wh[0]["workflow_id"].(string) workflowId = wh[0]["workflow_id"].(string)
} }
} }
now := time.Now().Format("2006-01-02 15:04:05") now := time.Now().Format("2006-01-02 15:04:05")
_, err = s.Insert(map[string]any{ _, err = s.Insert(map[string]any{
"source": source, "source": source,
@ -108,7 +106,7 @@ func SaveCert(source, key, cert, issuerCert, historyId string) (string, error) {
if err := public.ValidateSSLCertificate(cert, key); err != nil { if err := public.ValidateSSLCertificate(cert, key); err != nil {
return "", err return "", err
} }
certObj, err := public.ParseCertificate([]byte(cert)) certObj, err := public.ParseCertificate([]byte(cert))
if err != nil { if err != nil {
return "", fmt.Errorf("解析证书失败: %v", err) return "", fmt.Errorf("解析证书失败: %v", err)
@ -121,23 +119,23 @@ func SaveCert(source, key, cert, issuerCert, historyId string) (string, error) {
if d, _ := GetCert(sha256); d != nil { if d, _ := GetCert(sha256); d != nil {
return sha256, nil return sha256, nil
} }
domainSet := make(map[string]bool) domainSet := make(map[string]bool)
if certObj.Subject.CommonName != "" { if certObj.Subject.CommonName != "" {
domainSet[certObj.Subject.CommonName] = true domainSet[certObj.Subject.CommonName] = true
} }
for _, dns := range certObj.DNSNames { for _, dns := range certObj.DNSNames {
domainSet[dns] = true domainSet[dns] = true
} }
// 转成切片并拼接成逗号分隔的字符串 // 转成切片并拼接成逗号分隔的字符串
var domains []string var domains []string
for domain := range domainSet { for domain := range domainSet {
domains = append(domains, domain) domains = append(domains, domain)
} }
domainList := strings.Join(domains, ",") domainList := strings.Join(domains, ",")
// 提取 CA 名称Issuer 的组织名) // 提取 CA 名称Issuer 的组织名)
caName := "UNKNOWN" caName := "UNKNOWN"
if len(certObj.Issuer.Organization) > 0 { if len(certObj.Issuer.Organization) > 0 {
@ -149,7 +147,7 @@ func SaveCert(source, key, cert, issuerCert, historyId string) (string, error) {
startTime := certObj.NotBefore.Format("2006-01-02 15:04:05") startTime := certObj.NotBefore.Format("2006-01-02 15:04:05")
endTime := certObj.NotAfter.Format("2006-01-02 15:04:05") endTime := certObj.NotAfter.Format("2006-01-02 15:04:05")
endDay := fmt.Sprintf("%d", int(certObj.NotAfter.Sub(time.Now()).Hours()/24)) endDay := fmt.Sprintf("%d", int(certObj.NotAfter.Sub(time.Now()).Hours()/24))
err = AddCert(source, key, cert, caName, issuerCert, domainList, sha256, historyId, startTime, endTime, endDay) err = AddCert(source, key, cert, caName, issuerCert, domainList, sha256, historyId, startTime, endTime, endDay)
if err != nil { if err != nil {
return "", fmt.Errorf("保存证书失败: %v", err) return "", fmt.Errorf("保存证书失败: %v", err)
@ -171,7 +169,7 @@ func DelCert(id string) error {
return err return err
} }
defer s.Close() defer s.Close()
_, err = s.Where("id=?", []interface{}{id}).Delete() _, err = s.Where("id=?", []interface{}{id}).Delete()
if err != nil { if err != nil {
return err return err
@ -185,7 +183,7 @@ func GetCert(id string) (map[string]string, error) {
return nil, err return nil, err
} }
defer s.Close() defer s.Close()
res, err := s.Where("id=? or sha256=?", []interface{}{id, id}).Select() res, err := s.Where("id=? or sha256=?", []interface{}{id, id}).Select()
if err != nil { if err != nil {
return nil, err return nil, err
@ -193,13 +191,13 @@ func GetCert(id string) (map[string]string, error) {
if len(res) == 0 { if len(res) == 0 {
return nil, fmt.Errorf("证书不存在") return nil, fmt.Errorf("证书不存在")
} }
data := map[string]string{ data := map[string]string{
"domains": res[0]["domains"].(string), "domains": res[0]["domains"].(string),
"cert": res[0]["cert"].(string), "cert": res[0]["cert"].(string),
"key": res[0]["key"].(string), "key": res[0]["key"].(string),
} }
return data, nil return data, nil
} }

View File

@ -12,7 +12,6 @@ func GetWorkflowCount() (map[string]any, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
defer s.Close() defer s.Close()
workflow, err := s.Query(`select count(*) as count, workflow, err := s.Query(`select count(*) as count,
count(case when exec_type='auto' then 1 end ) as active, count(case when exec_type='auto' then 1 end ) as active,
@ -71,7 +70,6 @@ func GetSiteMonitorCount() (map[string]any, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
defer s.Close() defer s.Close()
cert, err := s.Query(`select count(*) as count, cert, err := s.Query(`select count(*) as count,
count(case when state='' then 1 end ) as exception count(case when state='' then 1 end ) as exception

View File

@ -16,7 +16,6 @@ func GetSqlite() (*public.Sqlite, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
s.TableName = "report" s.TableName = "report"
return s, nil return s, nil
} }
@ -29,7 +28,7 @@ func GetList(search string, p, limit int64) ([]map[string]any, int, error) {
return data, 0, err return data, 0, err
} }
defer s.Close() defer s.Close()
var limits []int64 var limits []int64
if p >= 0 && limit >= 0 { if p >= 0 && limit >= 0 {
limits = []int64{0, limit} limits = []int64{0, limit}
@ -38,7 +37,7 @@ func GetList(search string, p, limit int64) ([]map[string]any, int, error) {
limits[1] = p * limit limits[1] = p * limit
} }
} }
if search != "" { if search != "" {
count, err = s.Where("name like ?", []interface{}{"%" + search + "%"}).Count() count, err = s.Where("name like ?", []interface{}{"%" + search + "%"}).Count()
data, err = s.Where("name like ?", []interface{}{"%" + search + "%"}).Limit(limits).Order("update_time", "desc").Select() data, err = s.Where("name like ?", []interface{}{"%" + search + "%"}).Limit(limits).Order("update_time", "desc").Select()
@ -66,7 +65,7 @@ func GetReport(id string) (map[string]any, error) {
return nil, fmt.Errorf("没有找到此通知配置") return nil, fmt.Errorf("没有找到此通知配置")
} }
return data[0], nil return data[0], nil
} }
func AddReport(Type, config, name string) error { func AddReport(Type, config, name string) error {
@ -148,7 +147,7 @@ func Notify(params map[string]any) error {
} }
func NotifyMail(params map[string]any) error { func NotifyMail(params map[string]any) error {
if params == nil { if params == nil {
return fmt.Errorf("缺少参数") return fmt.Errorf("缺少参数")
} }
@ -164,18 +163,18 @@ func NotifyMail(params map[string]any) error {
if err != nil { if err != nil {
return fmt.Errorf("解析配置失败: %v", err) return fmt.Errorf("解析配置失败: %v", err)
} }
e := email.NewEmail() e := email.NewEmail()
e.From = config["sender"] e.From = config["sender"]
e.To = []string{config["receiver"]} e.To = []string{config["receiver"]}
e.Subject = params["subject"].(string) e.Subject = params["subject"].(string)
e.Text = []byte(params["body"].(string)) e.Text = []byte(params["body"].(string))
addr := fmt.Sprintf("%s:%s", config["smtpHost"], config["smtpPort"]) addr := fmt.Sprintf("%s:%s", config["smtpHost"], config["smtpPort"])
auth := smtp.PlainAuth("", config["sender"], config["password"], config["smtpHost"]) auth := smtp.PlainAuth("", config["sender"], config["password"], config["smtpHost"])
// 使用 SSL通常是 465 // 使用 SSL通常是 465
if config["smtpPort"] == "465" { if config["smtpPort"] == "465" {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
@ -192,7 +191,7 @@ func NotifyMail(params map[string]any) error {
} }
return nil return nil
} }
// 普通明文发送25端口非推荐 // 普通明文发送25端口非推荐
err = e.Send(addr, auth) err = e.Send(addr, auth)
if err != nil { if err != nil {

View File

@ -29,7 +29,6 @@ func GetSqlite() (*public.Sqlite, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
s.TableName = "site_monitor" s.TableName = "site_monitor"
return s, nil return s, nil
} }

View File

@ -13,7 +13,6 @@ func GetSqlite() (*public.Sqlite, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
s.TableName = "workflow" s.TableName = "workflow"
return s, nil return s, nil
} }
@ -26,7 +25,7 @@ func GetList(search string, p, limit int64) ([]map[string]any, int, error) {
return data, 0, err return data, 0, err
} }
defer s.Close() defer s.Close()
var limits []int64 var limits []int64
if p >= 0 && limit >= 0 { if p >= 0 && limit >= 0 {
limits = []int64{0, limit} limits = []int64{0, limit}
@ -35,7 +34,7 @@ func GetList(search string, p, limit int64) ([]map[string]any, int, error) {
limits[1] = p * limit limits[1] = p * limit
} }
} }
if search != "" { if search != "" {
count, err = s.Where("name like ?", []interface{}{"%" + search + "%"}).Count() count, err = s.Where("name like ?", []interface{}{"%" + search + "%"}).Count()
data, err = s.Where("name like ?", []interface{}{"%" + search + "%"}).Order("update_time", "desc").Limit(limits).Select() data, err = s.Where("name like ?", []interface{}{"%" + search + "%"}).Order("update_time", "desc").Limit(limits).Select()
@ -55,7 +54,7 @@ func AddWorkflow(name, content, execType, active, execTime string) error {
if err != nil { if err != nil {
return fmt.Errorf("检测到工作流配置有问题:%v", err) return fmt.Errorf("检测到工作流配置有问题:%v", err)
} }
s, err := GetSqlite() s, err := GetSqlite()
if err != nil { if err != nil {
return err return err
@ -160,7 +159,7 @@ func ExecuteWorkflow(id string) error {
return fmt.Errorf("工作流正在执行中") return fmt.Errorf("工作流正在执行中")
} }
content := data[0]["content"].(string) content := data[0]["content"].(string)
go func(id, c string) { go func(id, c string) {
// defer wg.Done() // defer wg.Done()
// WorkflowID := strconv.FormatInt(id, 10) // WorkflowID := strconv.FormatInt(id, 10)
@ -219,10 +218,10 @@ func RunNode(node *WorkflowNode, ctx *ExecutionContext) error {
node.Config["_runId"] = ctx.RunID node.Config["_runId"] = ctx.RunID
node.Config["logger"] = ctx.Logger node.Config["logger"] = ctx.Logger
node.Config["NodeId"] = node.Id node.Config["NodeId"] = node.Id
// 执行当前节点 // 执行当前节点
result, err := Executors(node.Type, node.Config) result, err := Executors(node.Type, node.Config)
var status ExecutionStatus var status ExecutionStatus
if err != nil { if err != nil {
status = StatusFailed status = StatusFailed
@ -232,9 +231,9 @@ func RunNode(node *WorkflowNode, ctx *ExecutionContext) error {
} else { } else {
status = StatusSuccess status = StatusSuccess
} }
ctx.SetOutput(node.Id, result, status) ctx.SetOutput(node.Id, result, status)
// 普通的并行 // 普通的并行
if node.Type == "branch" { if node.Type == "branch" {
if len(node.ConditionNodes) > 0 { if len(node.ConditionNodes) > 0 {
@ -270,7 +269,7 @@ func RunNode(node *WorkflowNode, ctx *ExecutionContext) error {
} }
} }
} }
if node.ChildNode != nil { if node.ChildNode != nil {
return RunNode(node.ChildNode, ctx) return RunNode(node.ChildNode, ctx)
} }

View File

@ -13,7 +13,6 @@ func GetSqliteObjWH() (*public.Sqlite, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
s.TableName = "workflow_history" s.TableName = "workflow_history"
return s, nil return s, nil
} }
@ -27,7 +26,7 @@ func GetListWH(id string, p, limit int64) ([]map[string]any, int, error) {
return data, 0, err return data, 0, err
} }
defer s.Close() defer s.Close()
var limits []int64 var limits []int64
if p >= 0 && limit >= 0 { if p >= 0 && limit >= 0 {
limits = []int64{0, limit} limits = []int64{0, limit}
@ -43,7 +42,7 @@ func GetListWH(id string, p, limit int64) ([]map[string]any, int, error) {
count, err = s.Where("workflow_id=?", []interface{}{id}).Count() count, err = s.Where("workflow_id=?", []interface{}{id}).Count()
data, err = s.Where("workflow_id=?", []interface{}{id}).Limit(limits).Order("create_time", "desc").Select() data, err = s.Where("workflow_id=?", []interface{}{id}).Limit(limits).Order("create_time", "desc").Select()
} }
if err != nil { if err != nil {
return data, 0, err return data, 0, err
} }

View File

@ -45,7 +45,7 @@ func SessionAuthMiddleware() gin.HandlerFunc {
} }
// 返回登录页 // 返回登录页
c.Redirect(http.StatusFound, "/login") c.Redirect(http.StatusFound, "/login")
// c.Abort() c.Abort()
return return
} else { } else {
if session.Get("secure") == nil || last == nil { if session.Get("secure") == nil || last == nil {

View File

@ -28,9 +28,9 @@ func init() {
fmt.Fprintf(os.Stderr, "切换目录失败: %v\n", err) fmt.Fprintf(os.Stderr, "切换目录失败: %v\n", err)
os.Exit(1) os.Exit(1)
} }
os.MkdirAll("data", os.ModePerm) os.MkdirAll("data", os.ModePerm)
dbPath := "data/data.db" dbPath := "data/data.db"
_, _ = filepath.Abs(dbPath) _, _ = filepath.Abs(dbPath)
// fmt.Println("数据库路径:", absPath) // fmt.Println("数据库路径:", absPath)
@ -167,7 +167,7 @@ func init() {
workflow_id TEXT not null workflow_id TEXT not null
); );
create table workflow_deploy create table IF NOT EXISTS workflow_deploy
( (
id TEXT, id TEXT,
workflow_id TEXT, workflow_id TEXT,
@ -177,6 +177,19 @@ func init() {
primary key (id, workflow_id) primary key (id, workflow_id)
); );
create table IF NOT EXISTS _eab
(
id integer not null
constraint _eab_pk
primary key autoincrement,
name TEXT,
Kid TEXT not null,
HmacEncoded TEXT not null,
ca TEXT not null,
create_time TEXT,
update_time TEXT
);
`) `)
insertDefaultData(db, "users", "INSERT INTO users (id, username, password, salt) VALUES (1, 'admin', 'xxxxxxx', '&*ghs^&%dag');") insertDefaultData(db, "users", "INSERT INTO users (id, username, password, salt) VALUES (1, 'admin', 'xxxxxxx', '&*ghs^&%dag');")
insertDefaultData(db, "access_type", ` insertDefaultData(db, "access_type", `
@ -187,15 +200,15 @@ func init() {
INSERT INTO access_type (name, type) VALUES ('ssh', 'host'); INSERT INTO access_type (name, type) VALUES ('ssh', 'host');
INSERT INTO access_type (name, type) VALUES ('btpanel', 'host'); INSERT INTO access_type (name, type) VALUES ('btpanel', 'host');
INSERT INTO access_type (name, type) VALUES ('1panel', 'host');`) INSERT INTO access_type (name, type) VALUES ('1panel', 'host');`)
uuidStr := public.GenerateUUID() uuidStr := public.GenerateUUID()
randomStr := public.RandomString(8) randomStr := public.RandomString(8)
port, err := public.GetFreePort() port, err := public.GetFreePort()
if err != nil { if err != nil {
port = 20773 port = 20773
} }
Isql := fmt.Sprintf( Isql := fmt.Sprintf(
`INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ('log_path', 'logs/ALLinSSL.log', '2025-04-15 15:58', '2025-04-15 15:58', 1, null); `INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ('log_path', 'logs/ALLinSSL.log', '2025-04-15 15:58', '2025-04-15 15:58', 1, null);
INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ( 'workflow_log_path', 'logs/workflows/', '2025-04-15 15:58', '2025-04-15 15:58', 1, null); INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ( 'workflow_log_path', 'logs/workflows/', '2025-04-15 15:58', '2025-04-15 15:58', 1, null);
@ -204,26 +217,26 @@ INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES
INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ('session_key', '%s', '2025-04-15 15:58', '2025-04-15 15:58', 1, null); INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ('session_key', '%s', '2025-04-15 15:58', '2025-04-15 15:58', 1, null);
INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ('secure', '/%s', '2025-04-15 15:58', '2025-04-15 15:58', 1, null); INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ('secure', '/%s', '2025-04-15 15:58', '2025-04-15 15:58', 1, null);
INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ('port', '%d', '2025-04-15 15:58', '2025-04-15 15:58', 1, null);`, uuidStr, randomStr, port) INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES ('port', '%d', '2025-04-15 15:58', '2025-04-15 15:58', 1, null);`, uuidStr, randomStr, port)
insertDefaultData(db, "settings", Isql) insertDefaultData(db, "settings", Isql)
InsertIfNotExists(db, "access_type", map[string]any{"name": "cloudflare", "type": "host"}, []string{"name", "type"}, []any{"cloudflare", "host"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "cloudflare", "type": "host"}, []string{"name", "type"}, []any{"cloudflare", "host"})
InsertIfNotExists(db, "access_type", map[string]any{"name": "cloudflare", "type": "dns"}, []string{"name", "type"}, []any{"cloudflare", "dns"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "cloudflare", "type": "dns"}, []string{"name", "type"}, []any{"cloudflare", "dns"})
InsertIfNotExists(db, "access_type", map[string]any{"name": "huaweicloud", "type": "host"}, []string{"name", "type"}, []any{"huaweicloud", "host"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "huaweicloud", "type": "host"}, []string{"name", "type"}, []any{"huaweicloud", "host"})
InsertIfNotExists(db, "access_type", map[string]any{"name": "huaweicloud", "type": "dns"}, []string{"name", "type"}, []any{"huaweicloud", "dns"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "huaweicloud", "type": "dns"}, []string{"name", "type"}, []any{"huaweicloud", "dns"})
InsertIfNotExists(db, "access_type", map[string]any{"name": "baidu", "type": "host"}, []string{"name", "type"}, []any{"baidu", "host"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "baidu", "type": "host"}, []string{"name", "type"}, []any{"baidu", "host"})
InsertIfNotExists(db, "access_type", map[string]any{"name": "baidu", "type": "dns"}, []string{"name", "type"}, []any{"baidu", "dns"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "baidu", "type": "dns"}, []string{"name", "type"}, []any{"baidu", "dns"})
InsertIfNotExists(db, "access_type", map[string]any{"name": "btwaf", "type": "host"}, []string{"name", "type"}, []any{"btwaf", "host"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "btwaf", "type": "host"}, []string{"name", "type"}, []any{"btwaf", "host"})
// 雷池 // 雷池
InsertIfNotExists(db, "access_type", map[string]any{"name": "safeline", "type": "host"}, []string{"name", "type"}, []any{"safeline", "host"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "safeline", "type": "host"}, []string{"name", "type"}, []any{"safeline", "host"})
// 西部数码 // 西部数码
InsertIfNotExists(db, "access_type", map[string]any{"name": "westcn", "type": "dns"}, []string{"name", "type"}, []any{"westcn", "dns"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "westcn", "type": "dns"}, []string{"name", "type"}, []any{"westcn", "dns"})
// 火山引擎 // 火山引擎
InsertIfNotExists(db, "access_type", map[string]any{"name": "volcengine", "type": "dns"}, []string{"name", "type"}, []any{"volcengine", "dns"}) InsertIfNotExists(db, "access_type", map[string]any{"name": "volcengine", "type": "dns"}, []string{"name", "type"}, []any{"volcengine", "dns"})
err = sqlite_migrate.EnsureDatabaseWithTables( err = sqlite_migrate.EnsureDatabaseWithTables(
"data/site_monitor.db", "data/site_monitor.db",
"data/data.db", "data/data.db",
@ -232,7 +245,7 @@ INSERT INTO settings (key, value, create_time, update_time, active, type) VALUES
if err != nil { if err != nil {
fmt.Println("错误:", err) fmt.Println("错误:", err)
} }
db1, err := sql.Open("sqlite", "data/site_monitor.db") db1, err := sql.Open("sqlite", "data/site_monitor.db")
if err != nil { if err != nil {
// fmt.Println("创建数据库失败:", err) // fmt.Println("创建数据库失败:", err)
@ -275,7 +288,7 @@ func insertDefaultData(db *sql.DB, table, insertSQL string) {
// fmt.Println("检查数据行数失败:", err) // fmt.Println("检查数据行数失败:", err)
return return
} }
// 如果表为空,则插入默认数据 // 如果表为空,则插入默认数据
if count == 0 { if count == 0 {
// fmt.Println("表为空,插入默认数据...") // fmt.Println("表为空,插入默认数据...")
@ -309,7 +322,7 @@ func InsertIfNotExists(
whereArgs = append(whereArgs, val) whereArgs = append(whereArgs, val)
i++ i++
} }
// 2. 判断是否存在 // 2. 判断是否存在
query := fmt.Sprintf("SELECT EXISTS(SELECT 1 FROM %s WHERE %s)", table, whereClause) query := fmt.Sprintf("SELECT EXISTS(SELECT 1 FROM %s WHERE %s)", table, whereClause)
var exists bool var exists bool
@ -320,7 +333,7 @@ func InsertIfNotExists(
if exists { if exists {
return nil // 已存在 return nil // 已存在
} }
// 3. 构建 INSERT 语句 // 3. 构建 INSERT 语句
columnList := "" columnList := ""
placeholderList := "" placeholderList := ""
@ -333,11 +346,11 @@ func InsertIfNotExists(
placeholderList += "?" placeholderList += "?"
} }
insertSQL := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, columnList, placeholderList) insertSQL := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, columnList, placeholderList)
_, err = db.Exec(insertSQL, insertValues...) _, err = db.Exec(insertSQL, insertValues...)
if err != nil { if err != nil {
return fmt.Errorf("insert failed: %w", err) return fmt.Errorf("insert failed: %w", err)
} }
return nil return nil
} }

View File

@ -4,6 +4,7 @@ import (
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
@ -67,7 +68,7 @@ func VerifyCertificateAndKey(cert *x509.Certificate, privateKey crypto.PrivateKe
case *rsa.PrivateKey: case *rsa.PrivateKey:
signature, err = rsa.SignPKCS1v15(nil, key, crypto.SHA256, message) signature, err = rsa.SignPKCS1v15(nil, key, crypto.SHA256, message)
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
signature, err = key.Sign(nil, message, crypto.SHA256) signature, err = key.Sign(rand.Reader, message, crypto.SHA256)
case ed25519.PrivateKey: case ed25519.PrivateKey:
signature = ed25519.Sign(key, message) signature = ed25519.Sign(key, message)
default: default:

View File

@ -4,7 +4,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"os" "os"
_ "modernc.org/sqlite" // 使用 pure Go 实现的 SQLite 驱动 _ "modernc.org/sqlite" // 使用 pure Go 实现的 SQLite 驱动
) )
@ -14,50 +14,50 @@ func EnsureDatabaseWithTables(targetDBPath string, baseDBPath string, tables []s
// fmt.Printf("数据库 %s 已存在,跳过迁移。\n", targetDBPath) // fmt.Printf("数据库 %s 已存在,跳过迁移。\n", targetDBPath)
return nil return nil
} }
// fmt.Printf("数据库 %s 不存在,开始从基础数据库迁移表...\n", targetDBPath) // fmt.Printf("数据库 %s 不存在,开始从基础数据库迁移表...\n", targetDBPath)
// 2. 打开源数据库(只读)和目标数据库(新建) // 2. 打开源数据库(只读)和目标数据库(新建)
baseDB, err := sql.Open("sqlite", baseDBPath) baseDB, err := sql.Open("sqlite", baseDBPath)
if err != nil { if err != nil {
return fmt.Errorf("打开基础数据库失败: %v", err) return fmt.Errorf("打开基础数据库失败: %v", err)
} }
defer baseDB.Close() defer baseDB.Close()
targetDB, err := sql.Open("sqlite", targetDBPath) targetDB, err := sql.Open("sqlite", targetDBPath)
if err != nil { if err != nil {
return fmt.Errorf("创建目标数据库失败: %v", err) return fmt.Errorf("创建目标数据库失败: %v", err)
} }
defer targetDB.Close() defer targetDB.Close()
for _, table := range tables { for _, table := range tables {
// 2.1 获取建表语句 // 2.1 获取建表语句
var createSQL string var createSQL string
query := "SELECT sql FROM sqlite_master WHERE type='table' AND name=?" query := "SELECT sql FROM sqlite_master WHERE type='table' AND name=?"
err = baseDB.QueryRow(query, table).Scan(&createSQL) err = baseDB.QueryRow(query, table).Scan(&createSQL)
if err != nil { if err != nil {
return fmt.Errorf("获取表 %s 的结构失败: %v", table, err) return nil
} }
// 2.2 在目标库中创建表 // 2.2 在目标库中创建表
_, err = targetDB.Exec(createSQL) _, err = targetDB.Exec(createSQL)
if err != nil { if err != nil {
return fmt.Errorf("创建表 %s 失败: %v", table, err) return fmt.Errorf("创建表 %s 失败: %v", table, err)
} }
// 2.3 从基础库读取数据并插入目标库 // 2.3 从基础库读取数据并插入目标库
rows, err := baseDB.Query(fmt.Sprintf("SELECT * FROM %s", table)) rows, err := baseDB.Query(fmt.Sprintf("SELECT * FROM %s", table))
if err != nil { if err != nil {
return fmt.Errorf("读取表 %s 数据失败: %v", table, err) return fmt.Errorf("读取表 %s 数据失败: %v", table, err)
} }
cols, _ := rows.Columns() cols, _ := rows.Columns()
values := make([]interface{}, len(cols)) values := make([]interface{}, len(cols))
valuePtrs := make([]interface{}, len(cols)) valuePtrs := make([]interface{}, len(cols))
tx, _ := targetDB.Begin() tx, _ := targetDB.Begin()
stmt, _ := tx.Prepare(buildInsertSQL(table, len(cols))) stmt, _ := tx.Prepare(buildInsertSQL(table, len(cols)))
for rows.Next() { for rows.Next() {
for i := range values { for i := range values {
valuePtrs[i] = &values[i] valuePtrs[i] = &values[i]
@ -65,12 +65,12 @@ func EnsureDatabaseWithTables(targetDBPath string, baseDBPath string, tables []s
rows.Scan(valuePtrs...) rows.Scan(valuePtrs...)
stmt.Exec(values...) stmt.Exec(values...)
} }
stmt.Close() stmt.Close()
tx.Commit() tx.Commit()
rows.Close() rows.Close()
} }
// fmt.Println("迁移完成。") // fmt.Println("迁移完成。")
return nil return nil
} }

View File

@ -22,7 +22,6 @@ func GetSettingIgnoreError(key string) string {
if err != nil { if err != nil {
return "" return ""
} }
s.Connect()
defer s.Close() defer s.Close()
s.TableName = "settings" s.TableName = "settings"
res, err := s.Where("key=?", []interface{}{key}).Select() res, err := s.Where("key=?", []interface{}{key}).Select()
@ -44,7 +43,6 @@ func UpdateSetting(key, val string) error {
if err != nil { if err != nil {
return err return err
} }
s.Connect()
defer s.Close() defer s.Close()
s.TableName = "settings" s.TableName = "settings"
_, err = s.Where("key=?", []interface{}{key}).Update(map[string]any{"value": val}) _, err = s.Where("key=?", []interface{}{key}).Update(map[string]any{"value": val})
@ -60,14 +58,13 @@ func GetSettingsFromType(typ string) ([]map[string]any, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Connect()
defer s.Close() defer s.Close()
s.TableName = "settings" s.TableName = "settings"
res, err := s.Where("type=?", []interface{}{typ}).Select() res, err := s.Where("type=?", []interface{}{typ}).Select()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return res, nil return res, nil
} }
@ -79,14 +76,14 @@ func GetFreePort() (int, error) {
return 0, err return 0, err
} }
defer ln.Close() defer ln.Close()
addr := ln.Addr().String() addr := ln.Addr().String()
// 提取端口号 // 提取端口号
parts := strings.Split(addr, ":") parts := strings.Split(addr, ":")
if len(parts) < 2 { if len(parts) < 2 {
return 0, fmt.Errorf("invalid address: %s", addr) return 0, fmt.Errorf("invalid address: %s", addr)
} }
var port int var port int
fmt.Sscanf(parts[len(parts)-1], "%d", &port) fmt.Sscanf(parts[len(parts)-1], "%d", &port)
return port, nil return port, nil
@ -105,7 +102,7 @@ func RandomString(length int) string {
func RandomStringWithCharset(length int, charset string) (string, error) { func RandomStringWithCharset(length int, charset string) (string, error) {
result := make([]byte, length) result := make([]byte, length)
charsetLen := big.NewInt(int64(len(charset))) charsetLen := big.NewInt(int64(len(charset)))
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
num, err := rand.Int(rand.Reader, charsetLen) num, err := rand.Int(rand.Reader, charsetLen)
if err != nil { if err != nil {
@ -113,7 +110,7 @@ func RandomStringWithCharset(length int, charset string) (string, error) {
} }
result[i] = charset[num.Int64()] result[i] = charset[num.Int64()]
} }
return string(result), nil return string(result), nil
} }
@ -121,7 +118,7 @@ func RandomStringWithCharset(length int, charset string) (string, error) {
func GenerateUUID() string { func GenerateUUID() string {
// 生成一个新的 UUID // 生成一个新的 UUID
uuidStr := strings.ReplaceAll(uuid.New().String(), "-", "") uuidStr := strings.ReplaceAll(uuid.New().String(), "-", "")
// 返回 UUID 的字符串表示 // 返回 UUID 的字符串表示
return uuidStr return uuidStr
} }
@ -131,7 +128,7 @@ func GetLocalIP() (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
for _, iface := range interfaces { for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 { if iface.Flags&net.FlagUp == 0 {
continue // 接口未启用 continue // 接口未启用
@ -139,12 +136,12 @@ func GetLocalIP() (string, error) {
if iface.Flags&net.FlagLoopback != 0 { if iface.Flags&net.FlagLoopback != 0 {
continue // 忽略回环地址 continue // 忽略回环地址
} }
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
continue continue
} }
for _, addr := range addrs { for _, addr := range addrs {
var ip net.IP var ip net.IP
switch v := addr.(type) { switch v := addr.(type) {
@ -153,14 +150,14 @@ func GetLocalIP() (string, error) {
case *net.IPAddr: case *net.IPAddr:
ip = v.IP ip = v.IP
} }
// 只返回 IPv4 内网地址 // 只返回 IPv4 内网地址
if ip != nil && ip.To4() != nil && !ip.IsLoopback() { if ip != nil && ip.To4() != nil && !ip.IsLoopback() {
return ip.String(), nil return ip.String(), nil
} }
} }
} }
return "", fmt.Errorf("没有找到内网 IP") return "", fmt.Errorf("没有找到内网 IP")
} }
@ -170,16 +167,16 @@ func GetPublicIP() (string, error) {
return "", fmt.Errorf("请求失败: %v", err) return "", fmt.Errorf("请求失败: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP状态错误: %v", resp.Status) return "", fmt.Errorf("HTTP状态错误: %v", resp.Status)
} }
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", fmt.Errorf("读取响应失败: %v", err) return "", fmt.Errorf("读取响应失败: %v", err)
} }
return string(body), nil return string(body), nil
} }
@ -189,7 +186,7 @@ func ContainsAllIgnoreBRepeats(a, b []string) bool {
for _, item := range a { for _, item := range a {
setA[item] = struct{}{} setA[item] = struct{}{}
} }
// 遍历 B 的唯一元素,判断是否在 A 中 // 遍历 B 的唯一元素,判断是否在 A 中
seen := make(map[string]struct{}) seen := make(map[string]struct{})
for _, item := range b { for _, item := range b {
@ -207,19 +204,19 @@ func ContainsAllIgnoreBRepeats(a, b []string) bool {
// ExecCommand 执行系统命令,并返回 stdout、stderr 和错误 // ExecCommand 执行系统命令,并返回 stdout、stderr 和错误
func ExecCommand(command string) (string, string, error) { func ExecCommand(command string) (string, string, error) {
var cmd *exec.Cmd var cmd *exec.Cmd
// 根据操作系统选择解释器 // 根据操作系统选择解释器
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
cmd = exec.Command("cmd", "/C", command) cmd = exec.Command("cmd", "/C", command)
} else { } else {
cmd = exec.Command("bash", "-c", command) cmd = exec.Command("bash", "-c", command)
} }
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout cmd.Stdout = &stdout
cmd.Stderr = &stderr cmd.Stderr = &stderr
err := cmd.Run() err := cmd.Run()
return stdout.String(), stderr.String(), err return stdout.String(), stderr.String(), err
} }

View File

@ -8,7 +8,7 @@ import (
func Register(r *gin.Engine) { func Register(r *gin.Engine) {
v1 := r.Group("/v1") v1 := r.Group("/v1")
login := v1.Group("/login") login := v1.Group("/login")
{ {
login.POST("/sign", api.Sign) login.POST("/sign", api.Sign)
@ -44,6 +44,12 @@ func Register(r *gin.Engine) {
access.POST("/upd_access", api.UpdateAccess) access.POST("/upd_access", api.UpdateAccess)
access.POST("/get_all", api.GetAllAccess) access.POST("/get_all", api.GetAllAccess)
access.POST("/test_access", api.TestAccess) access.POST("/test_access", api.TestAccess)
access.POST("/get_eab_list", api.GetEABList)
access.POST("/add_eab", api.AddEAB)
access.POST("/del_eab", api.DelEAB)
access.POST("/upd_eab", api.UpdEAB)
access.POST("/get_all_eab", api.GetAllEAB)
} }
cert := v1.Group("/cert") cert := v1.Group("/cert")
{ {
@ -71,7 +77,7 @@ func Register(r *gin.Engine) {
{ {
overview.POST("/get_overviews", api.GetOverview) overview.POST("/get_overviews", api.GetOverview)
} }
// 1. 提供静态文件服务 // 1. 提供静态文件服务
r.StaticFS("/static", http.Dir("./frontend/static")) // 静态资源路径 r.StaticFS("/static", http.Dir("./frontend/static")) // 静态资源路径
r.StaticFS("/auto-deploy/static", http.Dir("./frontend/static")) // 静态资源路径 r.StaticFS("/auto-deploy/static", http.Dir("./frontend/static")) // 静态资源路径
@ -79,7 +85,7 @@ func Register(r *gin.Engine) {
r.GET("/favicon.ico", func(c *gin.Context) { r.GET("/favicon.ico", func(c *gin.Context) {
c.File("./frontend/favicon.ico") c.File("./frontend/favicon.ico")
}) })
// 3. 前端路由托管:匹配所有其他路由并返回 index.html // 3. 前端路由托管:匹配所有其他路由并返回 index.html
r.NoRoute(func(c *gin.Context) { r.NoRoute(func(c *gin.Context) {
c.File("./frontend/index.html") c.File("./frontend/index.html")

View File

@ -83,8 +83,8 @@ func SiteMonitor() {
os.Remove(path) os.Remove(path)
} }
}() }()
wg.Wait()
} }
} }
} }
wg.Wait()
} }