Merge pull request #539 from erikwilson/update-vendor-kvsql-2

Update vendored kvsql
pull/540/head
Erik Wilson 2019-06-15 12:50:36 -07:00 committed by GitHub
commit 4ea110f746
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 159 additions and 61 deletions

View File

@ -142,7 +142,7 @@ import:
- package: github.com/hashicorp/golang-lru - package: github.com/hashicorp/golang-lru
version: v0.5.0 version: v0.5.0
- package: github.com/ibuildthecloud/kvsql - package: github.com/ibuildthecloud/kvsql
version: 1afc2d8ad7d7e263c1971b05cb37e83aa5562561 version: 79f1f6881e28b90976f070aad6edad8e259057c1
repo: https://github.com/erikwilson/rancher-kvsql.git repo: https://github.com/erikwilson/rancher-kvsql.git
- package: github.com/imdario/mergo - package: github.com/imdario/mergo
version: v0.3.5 version: v0.3.5

View File

@ -122,8 +122,7 @@ golang.org/x/oauth2 a6bd8cefa1811bd24b86f8902872e4e8225f74c4
golang.org/x/time f51c12702a4d776e4c1fa9b0fabab841babae631 golang.org/x/time f51c12702a4d776e4c1fa9b0fabab841babae631
gopkg.in/inf.v0 3887ee99ecf07df5b447e9b00d9c0b2adaa9f3e4 gopkg.in/inf.v0 3887ee99ecf07df5b447e9b00d9c0b2adaa9f3e4
gopkg.in/yaml.v2 v2.2.1 gopkg.in/yaml.v2 v2.2.1
#github.com/ibuildthecloud/kvsql 788464096f5af361d166858efccf26c12dc5b427 github.com/ibuildthecloud/kvsql 79f1f6881e28b90976f070aad6edad8e259057c1 https://github.com/erikwilson/rancher-kvsql.git
github.com/ibuildthecloud/kvsql 1afc2d8ad7d7e263c1971b05cb37e83aa5562561 https://github.com/erikwilson/rancher-kvsql.git
# rootless # rootless
github.com/rootless-containers/rootlesskit v0.4.1 github.com/rootless-containers/rootlesskit v0.4.1

View File

@ -18,6 +18,7 @@ import (
"crypto/tls" "crypto/tls"
"time" "time"
"github.com/coreos/etcd/pkg/transport"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -39,4 +40,6 @@ type Config struct {
DialTimeout time.Duration DialTimeout time.Duration
DialOptions []grpc.DialOption DialOptions []grpc.DialOption
TLSInfo *transport.TLSInfo
} }

View File

@ -5,10 +5,16 @@ import (
"database/sql" "database/sql"
"strings" "strings"
"github.com/coreos/etcd/pkg/transport"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/ibuildthecloud/kvsql/clientv3/driver" "github.com/ibuildthecloud/kvsql/clientv3/driver"
) )
const (
defaultUnixDSN = "root@unix(/var/run/mysqld/mysqld.sock)/"
defaultHostDSN = "root@tcp(127.0.0.1)/"
)
var ( var (
fieldList = "name, value, old_value, old_revision, create_revision, revision, ttl, version, del" fieldList = "name, value, old_value, old_revision, create_revision, revision, ttl, version, del"
baseList = ` baseList = `
@ -46,7 +52,7 @@ INSERT INTO key_value(` + fieldList + `)
} }
nameIdx = "create index name_idx on key_value (name(100))" nameIdx = "create index name_idx on key_value (name(100))"
revisionIdx = "create index revision_idx on key_value (revision)" revisionIdx = "create index revision_idx on key_value (revision)"
createDB = "create database if not exists kubernetes" createDB = "create database if not exists "
) )
func NewMySQL() *driver.Generic { func NewMySQL() *driver.Generic {
@ -65,31 +71,24 @@ func NewMySQL() *driver.Generic {
} }
} }
func Open(dataSourceName string, tlsConfig *tls.Config) (*sql.DB, error) { func Open(dataSourceName string, tlsInfo *transport.TLSInfo) (*sql.DB, error) {
if dataSourceName == "" { tlsConfig, err := tlsInfo.ClientConfig()
dataSourceName = "root@unix(/var/run/mysqld/mysqld.sock)/" if err != nil {
return nil, err
} }
// get database name tlsConfig.MinVersion = tls.VersionTLS11
dsList := strings.Split(dataSourceName, "/") if len(tlsInfo.CertFile) == 0 && len(tlsInfo.KeyFile) == 0 && len(tlsInfo.CAFile) == 0 {
databaseName := dsList[len(dsList)-1] tlsConfig = nil
if databaseName == "" { }
if err := createDBIfNotExist(dataSourceName); err != nil { parsedDSN, err := prepareDSN(dataSourceName, tlsConfig)
return nil, err if err != nil {
} return nil, err
dataSourceName = dataSourceName + "kubernetes" }
if err := createDBIfNotExist(parsedDSN); err != nil {
return nil, err
} }
// setting up tlsConfig db, err := sql.Open("mysql", parsedDSN)
if tlsConfig != nil {
mysql.RegisterTLSConfig("custom", tlsConfig)
if strings.Contains(dataSourceName, "?") {
dataSourceName = dataSourceName + ",tls=custom"
} else {
dataSourceName = dataSourceName + "?tls=custom"
}
}
db, err := sql.Open("mysql", dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -116,13 +115,30 @@ func Open(dataSourceName string, tlsConfig *tls.Config) (*sql.DB, error) {
} }
func createDBIfNotExist(dataSourceName string) error { func createDBIfNotExist(dataSourceName string) error {
config, err := mysql.ParseDSN(dataSourceName)
if err != nil {
return err
}
dbName := config.DBName
db, err := sql.Open("mysql", dataSourceName) db, err := sql.Open("mysql", dataSourceName)
if err != nil { if err != nil {
return err return err
} }
_, err = db.Exec(createDB) _, err = db.Exec(createDB + dbName)
if err != nil { if err != nil {
return err if mysqlError, ok := err.(*mysql.MySQLError); !ok || mysqlError.Number != 1049 {
return err
}
config.DBName = ""
db, err = sql.Open("mysql", config.FormatDSN())
if err != nil {
return err
}
_, err = db.Exec(createDB + dbName)
if err != nil {
return err
}
} }
return nil return nil
} }
@ -130,11 +146,35 @@ func createDBIfNotExist(dataSourceName string) error {
func createIndex(db *sql.DB, indexStmt string) error { func createIndex(db *sql.DB, indexStmt string) error {
_, err := db.Exec(indexStmt) _, err := db.Exec(indexStmt)
if err != nil { if err != nil {
// check if its a duplicate error if mysqlError, ok := err.(*mysql.MySQLError); !ok || mysqlError.Number != 1061 {
if err.(*mysql.MySQLError).Number == 1061 { return err
return nil
} }
return err
} }
return nil return nil
} }
func prepareDSN(dataSourceName string, tlsConfig *tls.Config) (string, error) {
if len(dataSourceName) == 0 {
dataSourceName = defaultUnixDSN
if tlsConfig != nil {
dataSourceName = defaultHostDSN
}
}
config, err := mysql.ParseDSN(dataSourceName)
if err != nil {
return "", err
}
// setting up tlsConfig
if tlsConfig != nil {
mysql.RegisterTLSConfig("custom", tlsConfig)
config.TLSConfig = "custom"
}
dbName := "kubernetes"
if len(config.DBName) > 0 {
dbName = config.DBName
}
config.DBName = dbName
parsedDSN := config.FormatDSN()
return parsedDSN, nil
}

View File

@ -2,14 +2,20 @@ package pgsql
import ( import (
"database/sql" "database/sql"
"net/url"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"github.com/coreos/etcd/pkg/transport"
"github.com/ibuildthecloud/kvsql/clientv3/driver" "github.com/ibuildthecloud/kvsql/clientv3/driver"
"github.com/lib/pq" "github.com/lib/pq"
) )
const (
defaultDSN = "postgres://postgres:postgres@localhost/"
)
var ( var (
fieldList = "name, value, old_value, old_revision, create_revision, revision, ttl, version, del" fieldList = "name, value, old_value, old_revision, create_revision, revision, ttl, version, del"
baseList = ` baseList = `
@ -46,7 +52,7 @@ INSERT INTO key_value(` + fieldList + `)
`create index if not exists name_idx on key_value (name)`, `create index if not exists name_idx on key_value (name)`,
`create index if not exists revision_idx on key_value (revision)`, `create index if not exists revision_idx on key_value (revision)`,
} }
createDB = "create database kubernetes" createDB = "create database "
) )
func NewPGSQL() *driver.Generic { func NewPGSQL() *driver.Generic {
@ -65,22 +71,16 @@ func NewPGSQL() *driver.Generic {
} }
} }
func Open(dataSourceName string) (*sql.DB, error) { func Open(dataSourceName string, tlsInfo *transport.TLSInfo) (*sql.DB, error) {
if dataSourceName == "" { parsedDSN, err := prepareDSN(dataSourceName, tlsInfo)
dataSourceName = "postgres://postgres:postgres@localhost/" if err != nil {
} else { return nil, err
dataSourceName = "postgres://" + dataSourceName
} }
// get database name // get database name
dsList := strings.Split(dataSourceName, "/") if err := createDBIfNotExist(parsedDSN); err != nil {
databaseName := dsList[len(dsList)-1] return nil, err
if databaseName == "" {
if err := createDBIfNotExist(dataSourceName); err != nil {
return nil, err
}
dataSourceName = dataSourceName + "kubernetes"
} }
db, err := sql.Open("postgres", dataSourceName) db, err := sql.Open("postgres", parsedDSN)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -96,15 +96,35 @@ func Open(dataSourceName string) (*sql.DB, error) {
} }
func createDBIfNotExist(dataSourceName string) error { func createDBIfNotExist(dataSourceName string) error {
u, err := url.Parse(dataSourceName)
if err != nil {
return err
}
dbName := strings.SplitN(u.Path, "/", 2)[1]
db, err := sql.Open("postgres", dataSourceName) db, err := sql.Open("postgres", dataSourceName)
if err != nil { if err != nil {
return err return err
} }
_, err = db.Exec(createDB) err = db.Ping()
// check if database already exists // check if database already exists
if err != nil && err.(*pq.Error).Code != "42P04" { if _, ok := err.(*pq.Error); !ok {
return err return err
} }
if err := err.(*pq.Error); err.Code != "42P04" {
if err.Code != "3D000" {
return err
}
// database doesn't exit, will try to create it
u.Path = "/postgres"
db, err := sql.Open("postgres", u.String())
if err != nil {
return err
}
_, err = db.Exec(createDB + dbName + ";")
if err != nil {
return err
}
}
return nil return nil
} }
@ -117,3 +137,46 @@ func q(sql string) string {
return pref + strconv.Itoa(n) return pref + strconv.Itoa(n)
}) })
} }
func prepareDSN(dataSourceName string, tlsInfo *transport.TLSInfo) (string, error) {
if len(dataSourceName) == 0 {
dataSourceName = defaultDSN
} else {
dataSourceName = "postgres://" + dataSourceName
}
u, err := url.Parse(dataSourceName)
if err != nil {
return "", err
}
if len(u.Path) == 0 || u.Path == "/" {
u.Path = "/kubernetes"
}
queryMap, err := url.ParseQuery(u.RawQuery)
if err != nil {
return "", err
}
// set up tls dsn
params := url.Values{}
sslmode := "require"
if _, ok := queryMap["sslcert"]; tlsInfo.CertFile != "" && !ok {
params.Add("sslcert", tlsInfo.CertFile)
sslmode = "verify-full"
}
if _, ok := queryMap["sslkey"]; tlsInfo.KeyFile != "" && !ok {
params.Add("sslkey", tlsInfo.KeyFile)
sslmode = "verify-full"
}
if _, ok := queryMap["sslrootcert"]; tlsInfo.CAFile != "" && !ok {
params.Add("sslrootcert", tlsInfo.CAFile)
sslmode = "verify-full"
}
if _, ok := queryMap["sslmode"]; !ok {
params.Add("sslmode", sslmode)
}
for k, v := range queryMap {
params.Add(k, v[0])
}
u.RawQuery = params.Encode()
return u.String(), nil
}

View File

@ -115,15 +115,17 @@ func newKV(cfg Config) (*kv, error) {
} }
driver = sqlite.NewSQLite() driver = sqlite.NewSQLite()
case "mysql": case "mysql":
if db, err = mysql.Open(parts[1], cfg.TLS); err != nil { if db, err = mysql.Open(parts[1], cfg.TLSInfo); err != nil {
return nil, err return nil, err
} }
driver = mysql.NewMySQL() driver = mysql.NewMySQL()
case "postgres": case "postgres":
if db, err = pgsql.Open(parts[1]); err != nil { if db, err = pgsql.Open(parts[1], cfg.TLSInfo); err != nil {
return nil, err return nil, err
} }
driver = pgsql.NewPGSQL() driver = pgsql.NewPGSQL()
default:
return nil, fmt.Errorf("unknown driver type [%s]", parts[0])
} }
if err := driver.Start(context.TODO(), db); err != nil { if err := driver.Start(context.TODO(), db); err != nil {

View File

@ -18,7 +18,6 @@ package factory
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
"time" "time"
@ -67,22 +66,14 @@ func NewKVSQLHealthCheck(c storagebackend.Config) (func() error, error) {
} }
func newETCD3Client(c storagebackend.Config) (*clientv3.Client, error) { func newETCD3Client(c storagebackend.Config) (*clientv3.Client, error) {
tlsInfo := transport.TLSInfo{ tlsInfo := &transport.TLSInfo{
CertFile: c.Transport.CertFile, CertFile: c.Transport.CertFile,
KeyFile: c.Transport.KeyFile, KeyFile: c.Transport.KeyFile,
CAFile: c.Transport.CAFile, CAFile: c.Transport.CAFile,
} }
tlsConfig, err := tlsInfo.ClientConfig()
if err != nil {
return nil, err
}
tlsConfig.MinVersion = tls.VersionTLS11
if len(c.Transport.CertFile) == 0 && len(c.Transport.KeyFile) == 0 && len(c.Transport.CAFile) == 0 {
tlsConfig = nil
}
cfg := clientv3.Config{ cfg := clientv3.Config{
Endpoints: c.Transport.ServerList, Endpoints: c.Transport.ServerList,
TLS: tlsConfig, TLSInfo: tlsInfo,
} }
if len(cfg.Endpoints) == 0 { if len(cfg.Endpoints) == 0 {