mirror of https://github.com/k3s-io/k3s
173 lines
3.7 KiB
Go
173 lines
3.7 KiB
Go
package pgsql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"net/url"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/lib/pq"
|
|
"github.com/rancher/kine/pkg/drivers/generic"
|
|
"github.com/rancher/kine/pkg/logstructured"
|
|
"github.com/rancher/kine/pkg/logstructured/sqllog"
|
|
"github.com/rancher/kine/pkg/server"
|
|
"github.com/rancher/kine/pkg/tls"
|
|
)
|
|
|
|
// defaultDSN is the connection URL used by prepareDSN when the caller
// supplies an empty data source name.
const defaultDSN = "postgres://postgres:postgres@localhost/"
|
|
|
|
// schema lists the statements that setup executes, in order, to create the
// kine table and its indexes. Every statement uses IF NOT EXISTS.
var schema = []string{
	`create table if not exists kine
		(
			id SERIAL PRIMARY KEY,
			name VARCHAR(630),
			created INTEGER,
			deleted INTEGER,
			create_revision INTEGER,
			prev_revision INTEGER,
			lease INTEGER,
			value bytea,
			old_value bytea
		);`,
	`CREATE INDEX IF NOT EXISTS kine_name_index ON kine (name)`,
	`CREATE INDEX IF NOT EXISTS kine_name_id_index ON kine (name,id)`,
	`CREATE UNIQUE INDEX IF NOT EXISTS kine_name_prev_revision_uindex ON kine (name, prev_revision)`,
}

// createDB is the statement prefix to which createDBIfNotExist appends the
// database name.
var createDB = "create database "
|
|
|
|
func New(ctx context.Context, dataSourceName string, tlsInfo tls.Config) (server.Backend, error) {
|
|
parsedDSN, err := prepareDSN(dataSourceName, tlsInfo)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := createDBIfNotExist(parsedDSN); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dialect, err := generic.Open(ctx, "postgres", parsedDSN, "$", true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dialect.TranslateErr = func(err error) error {
|
|
if err, ok := err.(*pq.Error); ok && err.Code == "23505" {
|
|
return server.ErrKeyExists
|
|
}
|
|
return err
|
|
}
|
|
|
|
if err := setup(dialect.DB); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dialect.Migrate(context.Background())
|
|
return logstructured.New(sqllog.New(dialect)), nil
|
|
}
|
|
|
|
func setup(db *sql.DB) error {
|
|
for _, stmt := range schema {
|
|
_, err := db.Exec(stmt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func createDBIfNotExist(dataSourceName string) error {
|
|
u, err := url.Parse(dataSourceName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dbName := strings.SplitN(u.Path, "/", 2)[1]
|
|
db, err := sql.Open("postgres", dataSourceName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer db.Close()
|
|
|
|
err = db.Ping()
|
|
// check if database already exists
|
|
if _, ok := err.(*pq.Error); !ok {
|
|
return err
|
|
}
|
|
if err := err.(*pq.Error); err.Code != "42P04" {
|
|
if err.Code != "3D000" {
|
|
return err
|
|
}
|
|
// database doesn't exit, will try to create it
|
|
u.Path = "/postgres"
|
|
db, err := sql.Open("postgres", u.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer db.Close()
|
|
_, err = db.Exec(createDB + dbName + ";")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// placeholderRe matches a single '?' placeholder. It is compiled once at
// package initialization instead of on every call to q.
var placeholderRe = regexp.MustCompile(`\?`)

// q rewrites a query that uses '?' placeholders into PostgreSQL's numbered
// form: the first '?' becomes "$1", the second "$2", and so on. Every '?'
// in the input is replaced, including any that appear inside quoted
// literals.
func q(sql string) string {
	n := 0
	return placeholderRe.ReplaceAllStringFunc(sql, func(string) string {
		n++
		return "$" + strconv.Itoa(n)
	})
}
|
|
|
|
func prepareDSN(dataSourceName string, tlsInfo tls.Config) (string, error) {
|
|
if len(dataSourceName) == 0 {
|
|
dataSourceName = defaultDSN
|
|
} else {
|
|
dataSourceName = "postgres://" + dataSourceName
|
|
}
|
|
u, err := url.Parse(dataSourceName)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(u.Path) == 0 || u.Path == "/" {
|
|
u.Path = "/kubernetes"
|
|
}
|
|
|
|
queryMap, err := url.ParseQuery(u.RawQuery)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
// set up tls dsn
|
|
params := url.Values{}
|
|
sslmode := ""
|
|
if _, ok := queryMap["sslcert"]; tlsInfo.CertFile != "" && !ok {
|
|
params.Add("sslcert", tlsInfo.CertFile)
|
|
sslmode = "verify-full"
|
|
}
|
|
if _, ok := queryMap["sslkey"]; tlsInfo.KeyFile != "" && !ok {
|
|
params.Add("sslkey", tlsInfo.KeyFile)
|
|
sslmode = "verify-full"
|
|
}
|
|
if _, ok := queryMap["sslrootcert"]; tlsInfo.CAFile != "" && !ok {
|
|
params.Add("sslrootcert", tlsInfo.CAFile)
|
|
sslmode = "verify-full"
|
|
}
|
|
if _, ok := queryMap["sslmode"]; !ok && sslmode != "" {
|
|
params.Add("sslmode", sslmode)
|
|
}
|
|
for k, v := range queryMap {
|
|
params.Add(k, v[0])
|
|
}
|
|
u.RawQuery = params.Encode()
|
|
return u.String(), nil
|
|
}
|