mirror of https://github.com/hashicorp/consul
508 lines
11 KiB
Go
508 lines
11 KiB
Go
// +build !future
|
|
|
|
/*
|
|
Copyright 2018 SAP SE
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package driver
|
|
|
|
import (
|
|
"context"
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"regexp"
|
|
"sync"
|
|
|
|
"github.com/SAP/go-hdb/driver/sqltrace"
|
|
|
|
p "github.com/SAP/go-hdb/internal/protocol"
|
|
)
|
|
|
|
// reBulk matches a statement that starts, after optional whitespace, with the
// case-insensitive keyword "bulk" followed by at least one space. Capture group
// 3 holds the rest of that line ('.' does not match '\n' without the s flag).
var reBulk = regexp.MustCompile("(?i)^(\\s)*(bulk +)(.*)")

// checkBulkInsert reports whether sql carries the driver-specific "bulk"
// prefix. When it does, the returned statement has the leading whitespace and
// the prefix removed; otherwise sql is returned unchanged.
func checkBulkInsert(sql string) (string, bool) {
	isBulk := reBulk.MatchString(sql)
	if !isBulk {
		return sql, isBulk
	}
	stripped := reBulk.ReplaceAllString(sql, "${3}")
	return stripped, isBulk
}
|
|
|
|
// reCall matches a statement that starts, after optional whitespace, with the
// case-insensitive keyword "call" followed by at least one space.
var reCall = regexp.MustCompile("(?i)^(\\s)*(call +)(.*)")

// checkCallProcedure reports whether sql is a stored procedure call
// ("CALL <procedure>...").
func checkCallProcedure(sql string) bool {
	loc := reCall.FindStringIndex(sql)
	return loc != nil
}
|
|
|
|
// database connection
|
|
|
|
func (c *conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
|
|
if c.session.IsBad() {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
prepareQuery, bulkInsert := checkBulkInsert(query)
|
|
var (
|
|
qt p.QueryType
|
|
id uint64
|
|
prmFieldSet *p.ParameterFieldSet
|
|
resultFieldSet *p.ResultFieldSet
|
|
)
|
|
qt, id, prmFieldSet, resultFieldSet, err = c.session.Prepare(prepareQuery)
|
|
if err != nil {
|
|
goto done
|
|
}
|
|
select {
|
|
default:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
if bulkInsert {
|
|
stmt, err = newBulkInsertStmt(c.session, prepareQuery, id, prmFieldSet)
|
|
} else {
|
|
stmt, err = newStmt(qt, c.session, prepareQuery, id, prmFieldSet, resultFieldSet)
|
|
}
|
|
done:
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-done:
|
|
return stmt, err
|
|
}
|
|
}
|
|
|
|
// QueryContext implements the database/sql/driver/QueryerContext interface.
|
|
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
|
|
if c.session.IsBad() {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
if len(args) != 0 {
|
|
return nil, driver.ErrSkip //fast path not possible (prepare needed)
|
|
}
|
|
|
|
// direct execution of call procedure
|
|
// - returns no parameter metadata (sps 82) but only field values
|
|
// --> let's take the 'prepare way' for stored procedures
|
|
if checkCallProcedure(query) {
|
|
return nil, driver.ErrSkip
|
|
}
|
|
|
|
sqltrace.Traceln(query)
|
|
|
|
id, idx, ok := decodeTableQuery(query)
|
|
if ok {
|
|
r := procedureCallResultStore.get(id)
|
|
if r == nil {
|
|
return nil, fmt.Errorf("invalid procedure table query %s", query)
|
|
}
|
|
return r.tableRows(int(idx))
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
var (
|
|
id uint64
|
|
resultFieldSet *p.ResultFieldSet
|
|
fieldValues *p.FieldValues
|
|
attributes p.PartAttributes
|
|
)
|
|
id, resultFieldSet, fieldValues, attributes, err = c.session.QueryDirect(query)
|
|
if err != nil {
|
|
goto done
|
|
}
|
|
select {
|
|
default:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
if id == 0 { // non select query
|
|
rows = noResult
|
|
} else {
|
|
rows, err = newQueryResult(c.session, id, resultFieldSet, fieldValues, attributes)
|
|
}
|
|
done:
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-done:
|
|
return rows, err
|
|
}
|
|
}
|
|
|
|
//statement
|
|
|
|
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
|
|
|
|
if s.session.IsBad() {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
switch s.qt {
|
|
default:
|
|
rows, err = s.defaultQuery(ctx, args)
|
|
case p.QtProcedureCall:
|
|
rows, err = s.procedureCall(ctx, args)
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-done:
|
|
return rows, err
|
|
}
|
|
}
|
|
|
|
func (s *stmt) defaultQuery(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
|
|
|
sqltrace.Tracef("%s %v", s.query, args)
|
|
|
|
rid, values, attributes, err := s.session.Query(s.id, s.prmFieldSet, s.resultFieldSet, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
select {
|
|
default:
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
if rid == 0 { // non select query
|
|
return noResult, nil
|
|
}
|
|
return newQueryResult(s.session, rid, s.resultFieldSet, values, attributes)
|
|
}
|
|
|
|
func (s *stmt) procedureCall(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
|
|
|
sqltrace.Tracef("%s %v", s.query, args)
|
|
|
|
fieldValues, tableResults, err := s.session.Call(s.id, s.prmFieldSet, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
select {
|
|
default:
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
return newProcedureCallResult(s.session, s.prmFieldSet, fieldValues, tableResults)
|
|
}
|
|
|
|
// bulk insert statement
|
|
|
|
// check if bulkInsertStmt implements all required interfaces
|
|
var (
|
|
_ driver.Stmt = (*bulkInsertStmt)(nil)
|
|
_ driver.StmtExecContext = (*bulkInsertStmt)(nil)
|
|
_ driver.StmtQueryContext = (*bulkInsertStmt)(nil)
|
|
_ driver.NamedValueChecker = (*bulkInsertStmt)(nil)
|
|
)
|
|
|
|
// bulkInsertStmt is the statement returned by PrepareContext for queries
// carrying the "bulk" prefix. Rows passed to ExecContext are buffered in args
// and sent to the database in a single session Exec call by execFlush.
type bulkInsertStmt struct {
	session *p.Session
	// query is the statement text with the "bulk" prefix already stripped.
	query string
	// id is the prepared statement id; Close drops it on the session.
	id          uint64
	prmFieldSet *p.ParameterFieldSet
	// numArg counts the rows currently buffered in args.
	numArg int
	// args holds the argument values of all buffered rows, one row after another.
	args []driver.NamedValue
}
|
|
|
|
func newBulkInsertStmt(session *p.Session, query string, id uint64, prmFieldSet *p.ParameterFieldSet) (*bulkInsertStmt, error) {
|
|
return &bulkInsertStmt{session: session, query: query, id: id, prmFieldSet: prmFieldSet, args: make([]driver.NamedValue, 0)}, nil
|
|
}
|
|
|
|
func (s *bulkInsertStmt) Close() error {
|
|
return s.session.DropStatementID(s.id)
|
|
}
|
|
|
|
func (s *bulkInsertStmt) NumInput() int {
|
|
return -1
|
|
}
|
|
|
|
func (s *bulkInsertStmt) Exec(args []driver.Value) (driver.Result, error) {
|
|
panic("deprecated")
|
|
}
|
|
|
|
func (s *bulkInsertStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
|
|
|
|
if s.session.IsBad() {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
sqltrace.Tracef("%s %v", s.query, args)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
if args == nil || len(args) == 0 {
|
|
r, err = s.execFlush()
|
|
} else {
|
|
r, err = s.execBuffer(args)
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-done:
|
|
return r, err
|
|
}
|
|
}
|
|
|
|
func (s *bulkInsertStmt) execFlush() (driver.Result, error) {
|
|
|
|
if s.numArg == 0 {
|
|
return driver.ResultNoRows, nil
|
|
}
|
|
|
|
sqltrace.Traceln("execFlush")
|
|
|
|
result, err := s.session.Exec(s.id, s.prmFieldSet, s.args)
|
|
s.args = s.args[:0]
|
|
s.numArg = 0
|
|
return result, err
|
|
}
|
|
|
|
func (s *bulkInsertStmt) execBuffer(args []driver.NamedValue) (driver.Result, error) {
|
|
|
|
numField := s.prmFieldSet.NumInputField()
|
|
if len(args) != numField {
|
|
return nil, fmt.Errorf("invalid number of arguments %d - %d expected", len(args), numField)
|
|
}
|
|
|
|
var result driver.Result = driver.ResultNoRows
|
|
var err error
|
|
|
|
if s.numArg == maxSmallint { // TODO: check why bigArgument count does not work
|
|
result, err = s.execFlush()
|
|
}
|
|
|
|
s.args = append(s.args, args...)
|
|
s.numArg++
|
|
|
|
return result, err
|
|
}
|
|
|
|
func (s *bulkInsertStmt) Query(args []driver.Value) (driver.Rows, error) {
|
|
panic("deprecated")
|
|
}
|
|
|
|
func (s *bulkInsertStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
|
return nil, fmt.Errorf("query not allowed in context of bulk insert statement %s", s.query)
|
|
}
|
|
|
|
// Deprecated: see NamedValueChecker.
|
|
//func (s *bulkInsertStmt) ColumnConverter(idx int) driver.ValueConverter {
|
|
//}
|
|
|
|
// CheckNamedValue implements NamedValueChecker interface.
|
|
func (s *bulkInsertStmt) CheckNamedValue(nv *driver.NamedValue) error {
|
|
return checkNamedValue(s.prmFieldSet, nv)
|
|
}
|
|
|
|
// callResultStore keeps procedure call results reachable by a numeric key, so
// that table pseudo queries (see encodeTableQuery / decodeTableQuery) can find
// them again. Keys released by del are reused by add.
type callResultStore struct {
	// mu guards all fields below (get takes the read lock, add/del the write lock).
	mu sync.RWMutex
	// store maps keys to live results; add creates it lazily.
	store map[uint64]*procedureCallResult
	// cnt is the highest key handed out so far.
	cnt uint64
	// free holds released keys; add reuses the last one first.
	free []uint64
}
|
|
|
|
func (s *callResultStore) get(k uint64) *procedureCallResult {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
if r, ok := s.store[k]; ok {
|
|
return r
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *callResultStore) add(v *procedureCallResult) uint64 {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var k uint64
|
|
|
|
if s.free == nil || len(s.free) == 0 {
|
|
s.cnt++
|
|
k = s.cnt
|
|
} else {
|
|
size := len(s.free)
|
|
k = s.free[size-1]
|
|
s.free = s.free[:size-1]
|
|
}
|
|
|
|
if s.store == nil {
|
|
s.store = make(map[uint64]*procedureCallResult)
|
|
}
|
|
|
|
s.store[k] = v
|
|
|
|
return k
|
|
}
|
|
|
|
func (s *callResultStore) del(k uint64) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
delete(s.store, k)
|
|
|
|
if s.free == nil {
|
|
s.free = []uint64{k}
|
|
} else {
|
|
s.free = append(s.free, k)
|
|
}
|
|
}
|
|
|
|
// procedureCallResultStore holds the procedure call results that are still
// open, so that conn.QueryContext can resolve table pseudo queries.
var procedureCallResultStore = &callResultStore{}
|
|
|
|
//procedure call result
|
|
|
|
// check if procedureCallResult implements all required interfaces
|
|
var _ driver.Rows = (*procedureCallResult)(nil)
|
|
|
|
// procedureCallResult is the driver.Rows returned for a stored procedure call.
// It yields a single row: the output parameter values, followed by one table
// pseudo query string per table result.
type procedureCallResult struct {
	// id is the key of this result in procedureCallResultStore.
	id          uint64
	session     *p.Session
	prmFieldSet *p.ParameterFieldSet
	// fieldValues holds the output parameter values; Next returns its row 0.
	fieldValues *p.FieldValues
	// _tableRows holds the rows of each table result, indexed by table number.
	_tableRows []driver.Rows
	// columns holds the output parameter names followed by "table <n>" entries.
	columns []string
	// eof is set to io.EOF once Next has returned the single row (or found none).
	eof error
}
|
|
|
|
func newProcedureCallResult(session *p.Session, prmFieldSet *p.ParameterFieldSet, fieldValues *p.FieldValues, tableResults []*p.TableResult) (driver.Rows, error) {
|
|
|
|
fieldIdx := prmFieldSet.NumOutputField()
|
|
columns := make([]string, fieldIdx+len(tableResults))
|
|
|
|
for i := 0; i < fieldIdx; i++ {
|
|
columns[i] = prmFieldSet.OutputField(i).Name()
|
|
}
|
|
|
|
tableRows := make([]driver.Rows, len(tableResults))
|
|
for i, tableResult := range tableResults {
|
|
var err error
|
|
|
|
if tableRows[i], err = newQueryResult(session, tableResult.ID(), tableResult.FieldSet(), tableResult.FieldValues(), tableResult.Attrs()); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
columns[fieldIdx] = fmt.Sprintf("table %d", i)
|
|
|
|
fieldIdx++
|
|
|
|
}
|
|
|
|
result := &procedureCallResult{
|
|
session: session,
|
|
prmFieldSet: prmFieldSet,
|
|
fieldValues: fieldValues,
|
|
_tableRows: tableRows,
|
|
columns: columns,
|
|
}
|
|
id := procedureCallResultStore.add(result)
|
|
result.id = id
|
|
return result, nil
|
|
}
|
|
|
|
func (r *procedureCallResult) Columns() []string {
|
|
return r.columns
|
|
}
|
|
|
|
func (r *procedureCallResult) Close() error {
|
|
procedureCallResultStore.del(r.id)
|
|
return nil
|
|
}
|
|
|
|
func (r *procedureCallResult) Next(dest []driver.Value) error {
|
|
if r.session.IsBad() {
|
|
return driver.ErrBadConn
|
|
}
|
|
|
|
if r.eof != nil {
|
|
return r.eof
|
|
}
|
|
|
|
if r.fieldValues.NumRow() == 0 && len(r._tableRows) == 0 {
|
|
r.eof = io.EOF
|
|
return r.eof
|
|
}
|
|
|
|
if r.fieldValues.NumRow() != 0 {
|
|
r.fieldValues.Row(0, dest)
|
|
}
|
|
|
|
i := r.prmFieldSet.NumOutputField()
|
|
for j := range r._tableRows {
|
|
dest[i] = encodeTableQuery(r.id, uint64(j))
|
|
i++
|
|
}
|
|
|
|
r.eof = io.EOF
|
|
return nil
|
|
}
|
|
|
|
func (r *procedureCallResult) tableRows(idx int) (driver.Rows, error) {
|
|
if idx >= len(r._tableRows) {
|
|
return nil, fmt.Errorf("table row index %d exceeds maximun %d", idx, len(r._tableRows)-1)
|
|
}
|
|
return r._tableRows[idx], nil
|
|
}
|
|
|
|
// helper

// tableQueryPrefix marks the pseudo query strings that procedureCallResult
// hands out for table results and that conn.QueryContext recognizes.
const tableQueryPrefix = "@tq"

// encodeTableQuery builds a table pseudo query: tableQueryPrefix followed by id
// and idx, each written as 8 little-endian bytes.
func encodeTableQuery(id, idx uint64) string {
	n := len(tableQueryPrefix)
	buf := make([]byte, n+16)
	copy(buf[:n], tableQueryPrefix)
	binary.LittleEndian.PutUint64(buf[n:], id)
	binary.LittleEndian.PutUint64(buf[n+8:], idx)
	return string(buf)
}

// decodeTableQuery is the inverse of encodeTableQuery. ok is false, and id/idx
// are zero, when query does not have the exact length and prefix of an encoded
// table query.
func decodeTableQuery(query string) (id, idx uint64, ok bool) {
	n := len(tableQueryPrefix)
	if len(query) != n+16 || query[:n] != tableQueryPrefix {
		return 0, 0, false
	}
	raw := []byte(query[n:])
	return binary.LittleEndian.Uint64(raw[:8]), binary.LittleEndian.Uint64(raw[8:]), true
}
|