package helper import ( "bufio" "database/sql" "errors" "fmt" "io" "strings" "time" "github.com/1Panel-dev/1Panel/backend/global" ) type sourceOption struct { dryRun bool mergeInsert int debug bool } type SourceOption func(*sourceOption) func WithDryRun() SourceOption { return func(o *sourceOption) { o.dryRun = true } } func WithMergeInsert(size int) SourceOption { return func(o *sourceOption) { o.mergeInsert = size } } func WithDebug() SourceOption { return func(o *sourceOption) { o.debug = true } } type dbWrapper struct { DB *sql.DB debug bool dryRun bool } func newDBWrapper(db *sql.DB, dryRun, debug bool) *dbWrapper { return &dbWrapper{ DB: db, dryRun: dryRun, debug: debug, } } func (db *dbWrapper) Exec(query string, args ...interface{}) (sql.Result, error) { if db.debug { global.LOG.Debugf("query %s", query) } if db.dryRun { return nil, nil } return db.DB.Exec(query, args...) } func Source(dns string, reader io.Reader, opts ...SourceOption) error { start := time.Now() global.LOG.Infof("source start at %s", start.Format("2006-01-02 15:04:05")) defer func() { end := time.Now() global.LOG.Infof("source end at %s, cost %s", end.Format("2006-01-02 15:04:05"), end.Sub(start)) }() var err error var db *sql.DB var o sourceOption for _, opt := range opts { opt(&o) } dbName, err := getDBNameFromDNS(dns) if err != nil { global.LOG.Errorf("get db name from dns failed, err: %v", err) return err } db, err = sql.Open("mysql", dns) if err != nil { global.LOG.Errorf("open mysql db failed, err: %v", err) return err } defer db.Close() dbWrapper := newDBWrapper(db, o.dryRun, o.debug) _, err = dbWrapper.Exec(fmt.Sprintf("USE `%s`;", dbName)) if err != nil { global.LOG.Errorf("exec `use %s` failed, err: %v", dbName, err) return err } db.SetConnMaxLifetime(3600) r := bufio.NewReader(reader) _, err = dbWrapper.Exec("SET autocommit=0;") if err != nil { global.LOG.Errorf("exec `set autocommit=0` failed, err: %v", err) return err } for { line, err := r.ReadString(';') if err != nil { if err == io.EOF { break } global.LOG.Errorf("read sql failed, err: %v", err) return err } ssql := string(line) ssql, err = trim(ssql) if err != nil { global.LOG.Errorf("trim sql failed, err: %v", err) return err } if o.mergeInsert > 1 && strings.HasPrefix(ssql, "INSERT INTO") { var insertSQLs []string insertSQLs = append(insertSQLs, ssql) for i := 0; i < o.mergeInsert-1; i++ { line, err := r.ReadString(';') if err != nil { if err == io.EOF { break } global.LOG.Errorf("read merge insert sql failed, err: %v", err) return err } ssql2 := string(line) ssql2, err = trim(ssql2) if err != nil { global.LOG.Errorf("trim merge insert sql failed, err: %v", err) return err } if strings.HasPrefix(ssql2, "INSERT INTO") { insertSQLs = append(insertSQLs, ssql2) continue } break } ssql, err = mergeInsert(insertSQLs) if err != nil { global.LOG.Errorf("do merge insert failed, err: %v", err) return err } } _, err = dbWrapper.Exec(ssql) if err != nil { global.LOG.Errorf("exec sql failed, err: %v", err) return err } } _, err = dbWrapper.Exec("COMMIT;") if err != nil { global.LOG.Errorf("exec `commit` failed, err: %v", err) return err } _, err = dbWrapper.Exec("SET autocommit=1;") if err != nil { global.LOG.Errorf("exec `autocommit=1` failed, err: %v", err) return err } return nil } func mergeInsert(insertSQLs []string) (string, error) { if len(insertSQLs) == 0 { return "", errors.New("no input provided") } builder := strings.Builder{} sql1 := insertSQLs[0] sql1 = strings.TrimSuffix(sql1, ";") builder.WriteString(sql1) for i, insertSQL := range insertSQLs[1:] { if i < len(insertSQLs)-1 { builder.WriteString(",") } valuesIdx := strings.Index(insertSQL, "VALUES") if valuesIdx == -1 { return "", errors.New("invalid SQL: missing VALUES keyword") } sqln := insertSQL[valuesIdx:] sqln = strings.TrimPrefix(sqln, "VALUES") sqln = strings.TrimSuffix(sqln, ";") builder.WriteString(sqln) } builder.WriteString(";") return builder.String(), nil } func trim(s string) (string, error) { s = strings.TrimLeft(s, "\n") s = strings.TrimSpace(s) return s, nil } func getDBNameFromDNS(dns string) (string, error) { ss1 := strings.Split(dns, "/") if len(ss1) == 2 { ss2 := strings.Split(ss1[1], "?") if len(ss2) == 2 { return ss2[0], nil } } return "", fmt.Errorf("dns error: %s", dns) }