mirror of https://github.com/hashicorp/consul
Vendor the vault api
parent
0b5d7277f9
commit
01f82717b4
|
@ -187,7 +187,7 @@
|
|||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2014 Google Inc.
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -0,0 +1,277 @@
|
|||
// Copyright 2016 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package civil implements types for civil time, a time-zone-independent
|
||||
// representation of time that follows the rules of the proleptic
|
||||
// Gregorian calendar with exactly 24-hour days, 60-minute hours, and 60-second
|
||||
// minutes.
|
||||
//
|
||||
// Because they lack location information, these types do not represent unique
|
||||
// moments or intervals of time. Use time.Time for that purpose.
|
||||
package civil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Date represents a date (year, month, day).
|
||||
//
|
||||
// This type does not include location information, and therefore does not
|
||||
// describe a unique 24-hour timespan.
|
||||
type Date struct {
|
||||
Year int // Year (e.g., 2014).
|
||||
Month time.Month // Month of the year (January = 1, ...).
|
||||
Day int // Day of the month, starting at 1.
|
||||
}
|
||||
|
||||
// DateOf returns the Date in which a time occurs in that time's location.
|
||||
func DateOf(t time.Time) Date {
|
||||
var d Date
|
||||
d.Year, d.Month, d.Day = t.Date()
|
||||
return d
|
||||
}
|
||||
|
||||
// ParseDate parses a string in RFC3339 full-date format and returns the date value it represents.
|
||||
func ParseDate(s string) (Date, error) {
|
||||
t, err := time.Parse("2006-01-02", s)
|
||||
if err != nil {
|
||||
return Date{}, err
|
||||
}
|
||||
return DateOf(t), nil
|
||||
}
|
||||
|
||||
// String returns the date in RFC3339 full-date format.
|
||||
func (d Date) String() string {
|
||||
return fmt.Sprintf("%04d-%02d-%02d", d.Year, d.Month, d.Day)
|
||||
}
|
||||
|
||||
// IsValid reports whether the date is valid.
|
||||
func (d Date) IsValid() bool {
|
||||
return DateOf(d.In(time.UTC)) == d
|
||||
}
|
||||
|
||||
// In returns the time corresponding to time 00:00:00 of the date in the location.
|
||||
//
|
||||
// In is always consistent with time.Date, even when time.Date returns a time
|
||||
// on a different day. For example, if loc is America/Indiana/Vincennes, then both
|
||||
// time.Date(1955, time.May, 1, 0, 0, 0, 0, loc)
|
||||
// and
|
||||
// civil.Date{Year: 1955, Month: time.May, Day: 1}.In(loc)
|
||||
// return 23:00:00 on April 30, 1955.
|
||||
//
|
||||
// In panics if loc is nil.
|
||||
func (d Date) In(loc *time.Location) time.Time {
|
||||
return time.Date(d.Year, d.Month, d.Day, 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
// AddDays returns the date that is n days in the future.
|
||||
// n can also be negative to go into the past.
|
||||
func (d Date) AddDays(n int) Date {
|
||||
return DateOf(d.In(time.UTC).AddDate(0, 0, n))
|
||||
}
|
||||
|
||||
// DaysSince returns the signed number of days between the date and s, not including the end day.
|
||||
// This is the inverse operation to AddDays.
|
||||
func (d Date) DaysSince(s Date) (days int) {
|
||||
// We convert to Unix time so we do not have to worry about leap seconds:
|
||||
// Unix time increases by exactly 86400 seconds per day.
|
||||
deltaUnix := d.In(time.UTC).Unix() - s.In(time.UTC).Unix()
|
||||
return int(deltaUnix / 86400)
|
||||
}
|
||||
|
||||
// Before reports whether d1 occurs before d2.
|
||||
func (d1 Date) Before(d2 Date) bool {
|
||||
if d1.Year != d2.Year {
|
||||
return d1.Year < d2.Year
|
||||
}
|
||||
if d1.Month != d2.Month {
|
||||
return d1.Month < d2.Month
|
||||
}
|
||||
return d1.Day < d2.Day
|
||||
}
|
||||
|
||||
// After reports whether d1 occurs after d2.
|
||||
func (d1 Date) After(d2 Date) bool {
|
||||
return d2.Before(d1)
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface.
|
||||
// The output is the result of d.String().
|
||||
func (d Date) MarshalText() ([]byte, error) {
|
||||
return []byte(d.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||
// The date is expected to be a string in a format accepted by ParseDate.
|
||||
func (d *Date) UnmarshalText(data []byte) error {
|
||||
var err error
|
||||
*d, err = ParseDate(string(data))
|
||||
return err
|
||||
}
|
||||
|
||||
// A Time represents a time with nanosecond precision.
|
||||
//
|
||||
// This type does not include location information, and therefore does not
|
||||
// describe a unique moment in time.
|
||||
//
|
||||
// This type exists to represent the TIME type in storage-based APIs like BigQuery.
|
||||
// Most operations on Times are unlikely to be meaningful. Prefer the DateTime type.
|
||||
type Time struct {
|
||||
Hour int // The hour of the day in 24-hour format; range [0-23]
|
||||
Minute int // The minute of the hour; range [0-59]
|
||||
Second int // The second of the minute; range [0-59]
|
||||
Nanosecond int // The nanosecond of the second; range [0-999999999]
|
||||
}
|
||||
|
||||
// TimeOf returns the Time representing the time of day in which a time occurs
|
||||
// in that time's location. It ignores the date.
|
||||
func TimeOf(t time.Time) Time {
|
||||
var tm Time
|
||||
tm.Hour, tm.Minute, tm.Second = t.Clock()
|
||||
tm.Nanosecond = t.Nanosecond()
|
||||
return tm
|
||||
}
|
||||
|
||||
// ParseTime parses a string and returns the time value it represents.
|
||||
// ParseTime accepts an extended form of the RFC3339 partial-time format. After
|
||||
// the HH:MM:SS part of the string, an optional fractional part may appear,
|
||||
// consisting of a decimal point followed by one to nine decimal digits.
|
||||
// (RFC3339 admits only one digit after the decimal point).
|
||||
func ParseTime(s string) (Time, error) {
|
||||
t, err := time.Parse("15:04:05.999999999", s)
|
||||
if err != nil {
|
||||
return Time{}, err
|
||||
}
|
||||
return TimeOf(t), nil
|
||||
}
|
||||
|
||||
// String returns the date in the format described in ParseTime. If Nanoseconds
|
||||
// is zero, no fractional part will be generated. Otherwise, the result will
|
||||
// end with a fractional part consisting of a decimal point and nine digits.
|
||||
func (t Time) String() string {
|
||||
s := fmt.Sprintf("%02d:%02d:%02d", t.Hour, t.Minute, t.Second)
|
||||
if t.Nanosecond == 0 {
|
||||
return s
|
||||
}
|
||||
return s + fmt.Sprintf(".%09d", t.Nanosecond)
|
||||
}
|
||||
|
||||
// IsValid reports whether the time is valid.
|
||||
func (t Time) IsValid() bool {
|
||||
// Construct a non-zero time.
|
||||
tm := time.Date(2, 2, 2, t.Hour, t.Minute, t.Second, t.Nanosecond, time.UTC)
|
||||
return TimeOf(tm) == t
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface.
|
||||
// The output is the result of t.String().
|
||||
func (t Time) MarshalText() ([]byte, error) {
|
||||
return []byte(t.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||
// The time is expected to be a string in a format accepted by ParseTime.
|
||||
func (t *Time) UnmarshalText(data []byte) error {
|
||||
var err error
|
||||
*t, err = ParseTime(string(data))
|
||||
return err
|
||||
}
|
||||
|
||||
// A DateTime represents a date and time.
|
||||
//
|
||||
// This type does not include location information, and therefore does not
|
||||
// describe a unique moment in time.
|
||||
type DateTime struct {
|
||||
Date Date
|
||||
Time Time
|
||||
}
|
||||
|
||||
// Note: We deliberately do not embed Date into DateTime, to avoid promoting AddDays and Sub.
|
||||
|
||||
// DateTimeOf returns the DateTime in which a time occurs in that time's location.
|
||||
func DateTimeOf(t time.Time) DateTime {
|
||||
return DateTime{
|
||||
Date: DateOf(t),
|
||||
Time: TimeOf(t),
|
||||
}
|
||||
}
|
||||
|
||||
// ParseDateTime parses a string and returns the DateTime it represents.
|
||||
// ParseDateTime accepts a variant of the RFC3339 date-time format that omits
|
||||
// the time offset but includes an optional fractional time, as described in
|
||||
// ParseTime. Informally, the accepted format is
|
||||
// YYYY-MM-DDTHH:MM:SS[.FFFFFFFFF]
|
||||
// where the 'T' may be a lower-case 't'.
|
||||
func ParseDateTime(s string) (DateTime, error) {
|
||||
t, err := time.Parse("2006-01-02T15:04:05.999999999", s)
|
||||
if err != nil {
|
||||
t, err = time.Parse("2006-01-02t15:04:05.999999999", s)
|
||||
if err != nil {
|
||||
return DateTime{}, err
|
||||
}
|
||||
}
|
||||
return DateTimeOf(t), nil
|
||||
}
|
||||
|
||||
// String returns the date in the format described in ParseDate.
|
||||
func (dt DateTime) String() string {
|
||||
return dt.Date.String() + "T" + dt.Time.String()
|
||||
}
|
||||
|
||||
// IsValid reports whether the datetime is valid.
|
||||
func (dt DateTime) IsValid() bool {
|
||||
return dt.Date.IsValid() && dt.Time.IsValid()
|
||||
}
|
||||
|
||||
// In returns the time corresponding to the DateTime in the given location.
|
||||
//
|
||||
// If the time is missing or ambigous at the location, In returns the same
|
||||
// result as time.Date. For example, if loc is America/Indiana/Vincennes, then
|
||||
// both
|
||||
// time.Date(1955, time.May, 1, 0, 30, 0, 0, loc)
|
||||
// and
|
||||
// civil.DateTime{
|
||||
// civil.Date{Year: 1955, Month: time.May, Day: 1}},
|
||||
// civil.Time{Minute: 30}}.In(loc)
|
||||
// return 23:30:00 on April 30, 1955.
|
||||
//
|
||||
// In panics if loc is nil.
|
||||
func (dt DateTime) In(loc *time.Location) time.Time {
|
||||
return time.Date(dt.Date.Year, dt.Date.Month, dt.Date.Day, dt.Time.Hour, dt.Time.Minute, dt.Time.Second, dt.Time.Nanosecond, loc)
|
||||
}
|
||||
|
||||
// Before reports whether dt1 occurs before dt2.
|
||||
func (dt1 DateTime) Before(dt2 DateTime) bool {
|
||||
return dt1.In(time.UTC).Before(dt2.In(time.UTC))
|
||||
}
|
||||
|
||||
// After reports whether dt1 occurs after dt2.
|
||||
func (dt1 DateTime) After(dt2 DateTime) bool {
|
||||
return dt2.Before(dt1)
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface.
|
||||
// The output is the result of dt.String().
|
||||
func (dt DateTime) MarshalText() ([]byte, error) {
|
||||
return []byte(dt.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||
// The datetime is expected to be a string in a format accepted by ParseDateTime
|
||||
func (dt *DateTime) UnmarshalText(data []byte) error {
|
||||
var err error
|
||||
*dt, err = ParseDateTime(string(data))
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
Copyright (c) 2014 Ashley Jeffs
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
|
@ -0,0 +1,315 @@
|
|||
![Gabs](gabs_logo.png "Gabs")
|
||||
|
||||
Gabs is a small utility for dealing with dynamic or unknown JSON structures in
|
||||
golang. It's pretty much just a helpful wrapper around the golang
|
||||
`json.Marshal/json.Unmarshal` behaviour and `map[string]interface{}` objects.
|
||||
It does nothing spectacular except for being fabulous.
|
||||
|
||||
https://godoc.org/github.com/Jeffail/gabs
|
||||
|
||||
## How to install:
|
||||
|
||||
``` bash
|
||||
go get github.com/Jeffail/gabs
|
||||
```
|
||||
|
||||
## How to use
|
||||
|
||||
### Parsing and searching JSON
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
import "github.com/Jeffail/gabs"
|
||||
|
||||
jsonParsed, err := gabs.ParseJSON([]byte(`{
|
||||
"outter":{
|
||||
"inner":{
|
||||
"value1":10,
|
||||
"value2":22
|
||||
},
|
||||
"alsoInner":{
|
||||
"value1":20
|
||||
}
|
||||
}
|
||||
}`))
|
||||
|
||||
var value float64
|
||||
var ok bool
|
||||
|
||||
value, ok = jsonParsed.Path("outter.inner.value1").Data().(float64)
|
||||
// value == 10.0, ok == true
|
||||
|
||||
value, ok = jsonParsed.Search("outter", "inner", "value1").Data().(float64)
|
||||
// value == 10.0, ok == true
|
||||
|
||||
value, ok = jsonParsed.Path("does.not.exist").Data().(float64)
|
||||
// value == 0.0, ok == false
|
||||
|
||||
exists := jsonParsed.Exists("outter", "inner", "value1")
|
||||
// exists == true
|
||||
|
||||
exists := jsonParsed.Exists("does", "not", "exist")
|
||||
// exists == false
|
||||
|
||||
exists := jsonParsed.ExistsP("does.not.exist")
|
||||
// exists == false
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
### Iterating objects
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
jsonParsed, _ := gabs.ParseJSON([]byte(`{"object":{ "first": 1, "second": 2, "third": 3 }}`))
|
||||
|
||||
// S is shorthand for Search
|
||||
children, _ := jsonParsed.S("object").ChildrenMap()
|
||||
for key, child := range children {
|
||||
fmt.Printf("key: %v, value: %v\n", key, child.Data().(string))
|
||||
}
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
### Iterating arrays
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
jsonParsed, _ := gabs.ParseJSON([]byte(`{"array":[ "first", "second", "third" ]}`))
|
||||
|
||||
// S is shorthand for Search
|
||||
children, _ := jsonParsed.S("array").Children()
|
||||
for _, child := range children {
|
||||
fmt.Println(child.Data().(string))
|
||||
}
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
Will print:
|
||||
|
||||
```
|
||||
first
|
||||
second
|
||||
third
|
||||
```
|
||||
|
||||
Children() will return all children of an array in order. This also works on
|
||||
objects, however, the children will be returned in a random order.
|
||||
|
||||
### Searching through arrays
|
||||
|
||||
If your JSON structure contains arrays you can still search the fields of the
|
||||
objects within the array, this returns a JSON array containing the results for
|
||||
each element.
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
jsonParsed, _ := gabs.ParseJSON([]byte(`{"array":[ {"value":1}, {"value":2}, {"value":3} ]}`))
|
||||
fmt.Println(jsonParsed.Path("array.value").String())
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
Will print:
|
||||
|
||||
```
|
||||
[1,2,3]
|
||||
```
|
||||
|
||||
### Generating JSON
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
jsonObj := gabs.New()
|
||||
// or gabs.Consume(jsonObject) to work on an existing map[string]interface{}
|
||||
|
||||
jsonObj.Set(10, "outter", "inner", "value")
|
||||
jsonObj.SetP(20, "outter.inner.value2")
|
||||
jsonObj.Set(30, "outter", "inner2", "value3")
|
||||
|
||||
fmt.Println(jsonObj.String())
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
Will print:
|
||||
|
||||
```
|
||||
{"outter":{"inner":{"value":10,"value2":20},"inner2":{"value3":30}}}
|
||||
```
|
||||
|
||||
To pretty-print:
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
fmt.Println(jsonObj.StringIndent("", " "))
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
Will print:
|
||||
|
||||
```
|
||||
{
|
||||
"outter": {
|
||||
"inner": {
|
||||
"value": 10,
|
||||
"value2": 20
|
||||
},
|
||||
"inner2": {
|
||||
"value3": 30
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Generating Arrays
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
jsonObj := gabs.New()
|
||||
|
||||
jsonObj.Array("foo", "array")
|
||||
// Or .ArrayP("foo.array")
|
||||
|
||||
jsonObj.ArrayAppend(10, "foo", "array")
|
||||
jsonObj.ArrayAppend(20, "foo", "array")
|
||||
jsonObj.ArrayAppend(30, "foo", "array")
|
||||
|
||||
fmt.Println(jsonObj.String())
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
Will print:
|
||||
|
||||
```
|
||||
{"foo":{"array":[10,20,30]}}
|
||||
```
|
||||
|
||||
Working with arrays by index:
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
jsonObj := gabs.New()
|
||||
|
||||
// Create an array with the length of 3
|
||||
jsonObj.ArrayOfSize(3, "foo")
|
||||
|
||||
jsonObj.S("foo").SetIndex("test1", 0)
|
||||
jsonObj.S("foo").SetIndex("test2", 1)
|
||||
|
||||
// Create an embedded array with the length of 3
|
||||
jsonObj.S("foo").ArrayOfSizeI(3, 2)
|
||||
|
||||
jsonObj.S("foo").Index(2).SetIndex(1, 0)
|
||||
jsonObj.S("foo").Index(2).SetIndex(2, 1)
|
||||
jsonObj.S("foo").Index(2).SetIndex(3, 2)
|
||||
|
||||
fmt.Println(jsonObj.String())
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
Will print:
|
||||
|
||||
```
|
||||
{"foo":["test1","test2",[1,2,3]]}
|
||||
```
|
||||
|
||||
### Converting back to JSON
|
||||
|
||||
This is the easiest part:
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
jsonParsedObj, _ := gabs.ParseJSON([]byte(`{
|
||||
"outter":{
|
||||
"values":{
|
||||
"first":10,
|
||||
"second":11
|
||||
}
|
||||
},
|
||||
"outter2":"hello world"
|
||||
}`))
|
||||
|
||||
jsonOutput := jsonParsedObj.String()
|
||||
// Becomes `{"outter":{"values":{"first":10,"second":11}},"outter2":"hello world"}`
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
And to serialize a specific segment is as simple as:
|
||||
|
||||
``` go
|
||||
...
|
||||
|
||||
jsonParsedObj := gabs.ParseJSON([]byte(`{
|
||||
"outter":{
|
||||
"values":{
|
||||
"first":10,
|
||||
"second":11
|
||||
}
|
||||
},
|
||||
"outter2":"hello world"
|
||||
}`))
|
||||
|
||||
jsonOutput := jsonParsedObj.Search("outter").String()
|
||||
// Becomes `{"values":{"first":10,"second":11}}`
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
### Merge two containers
|
||||
|
||||
You can merge a JSON structure into an existing one, where collisions will be
|
||||
converted into a JSON array.
|
||||
|
||||
``` go
|
||||
jsonParsed1, _ := ParseJSON([]byte(`{"outter": {"value1": "one"}}`))
|
||||
jsonParsed2, _ := ParseJSON([]byte(`{"outter": {"inner": {"value3": "three"}}, "outter2": {"value2": "two"}}`))
|
||||
|
||||
jsonParsed1.Merge(jsonParsed2)
|
||||
// Becomes `{"outter":{"inner":{"value3":"three"},"value1":"one"},"outter2":{"value2":"two"}}`
|
||||
```
|
||||
|
||||
Arrays are merged:
|
||||
|
||||
``` go
|
||||
jsonParsed1, _ := ParseJSON([]byte(`{"array": ["one"]}`))
|
||||
jsonParsed2, _ := ParseJSON([]byte(`{"array": ["two"]}`))
|
||||
|
||||
jsonParsed1.Merge(jsonParsed2)
|
||||
// Becomes `{"array":["one", "two"]}`
|
||||
```
|
||||
|
||||
### Parsing Numbers
|
||||
|
||||
Gabs uses the `json` package under the bonnet, which by default will parse all
|
||||
number values into `float64`. If you need to parse `Int` values then you should
|
||||
use a `json.Decoder` (https://golang.org/pkg/encoding/json/#Decoder):
|
||||
|
||||
``` go
|
||||
sample := []byte(`{"test":{"int":10, "float":6.66}}`)
|
||||
dec := json.NewDecoder(bytes.NewReader(sample))
|
||||
dec.UseNumber()
|
||||
|
||||
val, err := gabs.ParseJSONDecoder(dec)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to parse: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
intValue, err := val.Path("test.int").Data().(json.Number).Int64()
|
||||
```
|
|
@ -0,0 +1,579 @@
|
|||
/*
|
||||
Copyright (c) 2014 Ashley Jeffs
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
// Package gabs implements a simplified wrapper around creating and parsing JSON.
|
||||
package gabs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
// ErrOutOfBounds - Index out of bounds.
|
||||
ErrOutOfBounds = errors.New("out of bounds")
|
||||
|
||||
// ErrNotObjOrArray - The target is not an object or array type.
|
||||
ErrNotObjOrArray = errors.New("not an object or array")
|
||||
|
||||
// ErrNotObj - The target is not an object type.
|
||||
ErrNotObj = errors.New("not an object")
|
||||
|
||||
// ErrNotArray - The target is not an array type.
|
||||
ErrNotArray = errors.New("not an array")
|
||||
|
||||
// ErrPathCollision - Creating a path failed because an element collided with an existing value.
|
||||
ErrPathCollision = errors.New("encountered value collision whilst building path")
|
||||
|
||||
// ErrInvalidInputObj - The input value was not a map[string]interface{}.
|
||||
ErrInvalidInputObj = errors.New("invalid input object")
|
||||
|
||||
// ErrInvalidInputText - The input data could not be parsed.
|
||||
ErrInvalidInputText = errors.New("input text could not be parsed")
|
||||
|
||||
// ErrInvalidPath - The filepath was not valid.
|
||||
ErrInvalidPath = errors.New("invalid file path")
|
||||
|
||||
// ErrInvalidBuffer - The input buffer contained an invalid JSON string
|
||||
ErrInvalidBuffer = errors.New("input buffer contained invalid JSON")
|
||||
)
|
||||
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
|
||||
// Container - an internal structure that holds a reference to the core interface map of the parsed
|
||||
// json. Use this container to move context.
|
||||
type Container struct {
|
||||
object interface{}
|
||||
}
|
||||
|
||||
// Data - Return the contained data as an interface{}.
|
||||
func (g *Container) Data() interface{} {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return g.object
|
||||
}
|
||||
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
|
||||
// Path - Search for a value using dot notation.
|
||||
func (g *Container) Path(path string) *Container {
|
||||
return g.Search(strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// Search - Attempt to find and return an object within the JSON structure by specifying the
|
||||
// hierarchy of field names to locate the target. If the search encounters an array and has not
|
||||
// reached the end target then it will iterate each object of the array for the target and return
|
||||
// all of the results in a JSON array.
|
||||
func (g *Container) Search(hierarchy ...string) *Container {
|
||||
var object interface{}
|
||||
|
||||
object = g.Data()
|
||||
for target := 0; target < len(hierarchy); target++ {
|
||||
if mmap, ok := object.(map[string]interface{}); ok {
|
||||
object, ok = mmap[hierarchy[target]]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
} else if marray, ok := object.([]interface{}); ok {
|
||||
tmpArray := []interface{}{}
|
||||
for _, val := range marray {
|
||||
tmpGabs := &Container{val}
|
||||
res := tmpGabs.Search(hierarchy[target:]...)
|
||||
if res != nil {
|
||||
tmpArray = append(tmpArray, res.Data())
|
||||
}
|
||||
}
|
||||
if len(tmpArray) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &Container{tmpArray}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return &Container{object}
|
||||
}
|
||||
|
||||
// S - Shorthand method, does the same thing as Search.
|
||||
func (g *Container) S(hierarchy ...string) *Container {
|
||||
return g.Search(hierarchy...)
|
||||
}
|
||||
|
||||
// Exists - Checks whether a path exists.
|
||||
func (g *Container) Exists(hierarchy ...string) bool {
|
||||
return g.Search(hierarchy...) != nil
|
||||
}
|
||||
|
||||
// ExistsP - Checks whether a dot notation path exists.
|
||||
func (g *Container) ExistsP(path string) bool {
|
||||
return g.Exists(strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// Index - Attempt to find and return an object within a JSON array by index.
|
||||
func (g *Container) Index(index int) *Container {
|
||||
if array, ok := g.Data().([]interface{}); ok {
|
||||
if index >= len(array) {
|
||||
return &Container{nil}
|
||||
}
|
||||
return &Container{array[index]}
|
||||
}
|
||||
return &Container{nil}
|
||||
}
|
||||
|
||||
// Children - Return a slice of all the children of the array. This also works for objects, however,
|
||||
// the children returned for an object will NOT be in order and you lose the names of the returned
|
||||
// objects this way.
|
||||
func (g *Container) Children() ([]*Container, error) {
|
||||
if array, ok := g.Data().([]interface{}); ok {
|
||||
children := make([]*Container, len(array))
|
||||
for i := 0; i < len(array); i++ {
|
||||
children[i] = &Container{array[i]}
|
||||
}
|
||||
return children, nil
|
||||
}
|
||||
if mmap, ok := g.Data().(map[string]interface{}); ok {
|
||||
children := []*Container{}
|
||||
for _, obj := range mmap {
|
||||
children = append(children, &Container{obj})
|
||||
}
|
||||
return children, nil
|
||||
}
|
||||
return nil, ErrNotObjOrArray
|
||||
}
|
||||
|
||||
// ChildrenMap - Return a map of all the children of an object.
|
||||
func (g *Container) ChildrenMap() (map[string]*Container, error) {
|
||||
if mmap, ok := g.Data().(map[string]interface{}); ok {
|
||||
children := map[string]*Container{}
|
||||
for name, obj := range mmap {
|
||||
children[name] = &Container{obj}
|
||||
}
|
||||
return children, nil
|
||||
}
|
||||
return nil, ErrNotObj
|
||||
}
|
||||
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
|
||||
// Set - Set the value of a field at a JSON path, any parts of the path that do not exist will be
|
||||
// constructed, and if a collision occurs with a non object type whilst iterating the path an error
|
||||
// is returned.
|
||||
func (g *Container) Set(value interface{}, path ...string) (*Container, error) {
|
||||
if len(path) == 0 {
|
||||
g.object = value
|
||||
return g, nil
|
||||
}
|
||||
var object interface{}
|
||||
if g.object == nil {
|
||||
g.object = map[string]interface{}{}
|
||||
}
|
||||
object = g.object
|
||||
for target := 0; target < len(path); target++ {
|
||||
if mmap, ok := object.(map[string]interface{}); ok {
|
||||
if target == len(path)-1 {
|
||||
mmap[path[target]] = value
|
||||
} else if mmap[path[target]] == nil {
|
||||
mmap[path[target]] = map[string]interface{}{}
|
||||
}
|
||||
object = mmap[path[target]]
|
||||
} else {
|
||||
return &Container{nil}, ErrPathCollision
|
||||
}
|
||||
}
|
||||
return &Container{object}, nil
|
||||
}
|
||||
|
||||
// SetP - Does the same as Set, but using a dot notation JSON path.
|
||||
func (g *Container) SetP(value interface{}, path string) (*Container, error) {
|
||||
return g.Set(value, strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// SetIndex - Set a value of an array element based on the index.
|
||||
func (g *Container) SetIndex(value interface{}, index int) (*Container, error) {
|
||||
if array, ok := g.Data().([]interface{}); ok {
|
||||
if index >= len(array) {
|
||||
return &Container{nil}, ErrOutOfBounds
|
||||
}
|
||||
array[index] = value
|
||||
return &Container{array[index]}, nil
|
||||
}
|
||||
return &Container{nil}, ErrNotArray
|
||||
}
|
||||
|
||||
// Object - Create a new JSON object at a path. Returns an error if the path contains a collision
|
||||
// with a non object type.
|
||||
func (g *Container) Object(path ...string) (*Container, error) {
|
||||
return g.Set(map[string]interface{}{}, path...)
|
||||
}
|
||||
|
||||
// ObjectP - Does the same as Object, but using a dot notation JSON path.
|
||||
func (g *Container) ObjectP(path string) (*Container, error) {
|
||||
return g.Object(strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// ObjectI - Create a new JSON object at an array index. Returns an error if the object is not an
|
||||
// array or the index is out of bounds.
|
||||
func (g *Container) ObjectI(index int) (*Container, error) {
|
||||
return g.SetIndex(map[string]interface{}{}, index)
|
||||
}
|
||||
|
||||
// Array - Create a new JSON array at a path. Returns an error if the path contains a collision with
|
||||
// a non object type.
|
||||
func (g *Container) Array(path ...string) (*Container, error) {
|
||||
return g.Set([]interface{}{}, path...)
|
||||
}
|
||||
|
||||
// ArrayP - Does the same as Array, but using a dot notation JSON path.
|
||||
func (g *Container) ArrayP(path string) (*Container, error) {
|
||||
return g.Array(strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// ArrayI - Create a new JSON array at an array index. Returns an error if the object is not an
|
||||
// array or the index is out of bounds.
|
||||
func (g *Container) ArrayI(index int) (*Container, error) {
|
||||
return g.SetIndex([]interface{}{}, index)
|
||||
}
|
||||
|
||||
// ArrayOfSize - Create a new JSON array of a particular size at a path. Returns an error if the
|
||||
// path contains a collision with a non object type.
|
||||
func (g *Container) ArrayOfSize(size int, path ...string) (*Container, error) {
|
||||
a := make([]interface{}, size)
|
||||
return g.Set(a, path...)
|
||||
}
|
||||
|
||||
// ArrayOfSizeP - Does the same as ArrayOfSize, but using a dot notation JSON path.
|
||||
func (g *Container) ArrayOfSizeP(size int, path string) (*Container, error) {
|
||||
return g.ArrayOfSize(size, strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// ArrayOfSizeI - Create a new JSON array of a particular size at an array index. Returns an error
|
||||
// if the object is not an array or the index is out of bounds.
|
||||
func (g *Container) ArrayOfSizeI(size, index int) (*Container, error) {
|
||||
a := make([]interface{}, size)
|
||||
return g.SetIndex(a, index)
|
||||
}
|
||||
|
||||
// Delete - Delete an element at a JSON path, an error is returned if the element does not exist.
|
||||
func (g *Container) Delete(path ...string) error {
|
||||
var object interface{}
|
||||
|
||||
if g.object == nil {
|
||||
return ErrNotObj
|
||||
}
|
||||
object = g.object
|
||||
for target := 0; target < len(path); target++ {
|
||||
if mmap, ok := object.(map[string]interface{}); ok {
|
||||
if target == len(path)-1 {
|
||||
if _, ok := mmap[path[target]]; ok {
|
||||
delete(mmap, path[target])
|
||||
} else {
|
||||
return ErrNotObj
|
||||
}
|
||||
}
|
||||
object = mmap[path[target]]
|
||||
} else {
|
||||
return ErrNotObj
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteP - Does the same as Delete, but using a dot notation JSON path.
|
||||
func (g *Container) DeleteP(path string) error {
|
||||
return g.Delete(strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// Merge - Merges two gabs-containers
|
||||
func (g *Container) Merge(toMerge *Container) error {
|
||||
var recursiveFnc func(map[string]interface{}, []string) error
|
||||
recursiveFnc = func(mmap map[string]interface{}, path []string) error {
|
||||
for key, value := range mmap {
|
||||
newPath := append(path, key)
|
||||
if g.Exists(newPath...) {
|
||||
target := g.Search(newPath...)
|
||||
switch t := value.(type) {
|
||||
case map[string]interface{}:
|
||||
switch targetV := target.Data().(type) {
|
||||
case map[string]interface{}:
|
||||
if err := recursiveFnc(t, newPath); err != nil {
|
||||
return err
|
||||
}
|
||||
case []interface{}:
|
||||
g.Set(append(targetV, t), newPath...)
|
||||
default:
|
||||
newSlice := append([]interface{}{}, targetV)
|
||||
g.Set(append(newSlice, t), newPath...)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, valueOfSlice := range t {
|
||||
if err := g.ArrayAppend(valueOfSlice, newPath...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
switch targetV := target.Data().(type) {
|
||||
case []interface{}:
|
||||
g.Set(append(targetV, t), newPath...)
|
||||
default:
|
||||
newSlice := append([]interface{}{}, targetV)
|
||||
g.Set(append(newSlice, t), newPath...)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// path doesn't exist. So set the value
|
||||
if _, err := g.Set(value, newPath...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if mmap, ok := toMerge.Data().(map[string]interface{}); ok {
|
||||
return recursiveFnc(mmap, []string{})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
|
||||
/*
|
||||
Array modification/search - Keeping these options simple right now, no need for anything more
|
||||
complicated since you can just cast to []interface{}, modify and then reassign with Set.
|
||||
*/
|
||||
|
||||
// ArrayAppend - Append a value onto a JSON array. If the target is not a JSON array then it will be
|
||||
// converted into one, with its contents as the first element of the array.
|
||||
func (g *Container) ArrayAppend(value interface{}, path ...string) error {
|
||||
if array, ok := g.Search(path...).Data().([]interface{}); ok {
|
||||
array = append(array, value)
|
||||
_, err := g.Set(array, path...)
|
||||
return err
|
||||
}
|
||||
|
||||
newArray := []interface{}{}
|
||||
newArray = append(newArray, g.Search(path...).Data())
|
||||
newArray = append(newArray, value)
|
||||
|
||||
_, err := g.Set(newArray, path...)
|
||||
return err
|
||||
}
|
||||
|
||||
// ArrayAppendP - Append a value onto a JSON array using a dot notation JSON path.
|
||||
func (g *Container) ArrayAppendP(value interface{}, path string) error {
|
||||
return g.ArrayAppend(value, strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// ArrayRemove - Remove an element from a JSON array.
|
||||
func (g *Container) ArrayRemove(index int, path ...string) error {
|
||||
if index < 0 {
|
||||
return ErrOutOfBounds
|
||||
}
|
||||
array, ok := g.Search(path...).Data().([]interface{})
|
||||
if !ok {
|
||||
return ErrNotArray
|
||||
}
|
||||
if index < len(array) {
|
||||
array = append(array[:index], array[index+1:]...)
|
||||
} else {
|
||||
return ErrOutOfBounds
|
||||
}
|
||||
_, err := g.Set(array, path...)
|
||||
return err
|
||||
}
|
||||
|
||||
// ArrayRemoveP - Remove an element from a JSON array using a dot notation JSON path.
|
||||
func (g *Container) ArrayRemoveP(index int, path string) error {
|
||||
return g.ArrayRemove(index, strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// ArrayElement - Access an element from a JSON array.
|
||||
func (g *Container) ArrayElement(index int, path ...string) (*Container, error) {
|
||||
if index < 0 {
|
||||
return &Container{nil}, ErrOutOfBounds
|
||||
}
|
||||
array, ok := g.Search(path...).Data().([]interface{})
|
||||
if !ok {
|
||||
return &Container{nil}, ErrNotArray
|
||||
}
|
||||
if index < len(array) {
|
||||
return &Container{array[index]}, nil
|
||||
}
|
||||
return &Container{nil}, ErrOutOfBounds
|
||||
}
|
||||
|
||||
// ArrayElementP - Access an element from a JSON array using a dot notation JSON path.
|
||||
func (g *Container) ArrayElementP(index int, path string) (*Container, error) {
|
||||
return g.ArrayElement(index, strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
// ArrayCount - Count the number of elements in a JSON array.
|
||||
func (g *Container) ArrayCount(path ...string) (int, error) {
|
||||
if array, ok := g.Search(path...).Data().([]interface{}); ok {
|
||||
return len(array), nil
|
||||
}
|
||||
return 0, ErrNotArray
|
||||
}
|
||||
|
||||
// ArrayCountP - Count the number of elements in a JSON array using a dot notation JSON path.
|
||||
func (g *Container) ArrayCountP(path string) (int, error) {
|
||||
return g.ArrayCount(strings.Split(path, ".")...)
|
||||
}
|
||||
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
|
||||
// Bytes - Converts the contained object back to a JSON []byte blob.
|
||||
func (g *Container) Bytes() []byte {
|
||||
if g.Data() != nil {
|
||||
if bytes, err := json.Marshal(g.object); err == nil {
|
||||
return bytes
|
||||
}
|
||||
}
|
||||
return []byte("{}")
|
||||
}
|
||||
|
||||
// BytesIndent - Converts the contained object to a JSON []byte blob formatted with prefix, indent.
|
||||
func (g *Container) BytesIndent(prefix string, indent string) []byte {
|
||||
if g.object != nil {
|
||||
if bytes, err := json.MarshalIndent(g.object, prefix, indent); err == nil {
|
||||
return bytes
|
||||
}
|
||||
}
|
||||
return []byte("{}")
|
||||
}
|
||||
|
||||
// String - Converts the contained object to a JSON formatted string.
|
||||
func (g *Container) String() string {
|
||||
return string(g.Bytes())
|
||||
}
|
||||
|
||||
// StringIndent - Converts the contained object back to a JSON formatted string with prefix, indent.
|
||||
func (g *Container) StringIndent(prefix string, indent string) string {
|
||||
return string(g.BytesIndent(prefix, indent))
|
||||
}
|
||||
|
||||
// EncodeOpt is a functional option for the EncodeJSON method.
|
||||
type EncodeOpt func(e *json.Encoder)
|
||||
|
||||
// EncodeOptHTMLEscape sets the encoder to escape the JSON for html.
|
||||
func EncodeOptHTMLEscape(doEscape bool) EncodeOpt {
|
||||
return func(e *json.Encoder) {
|
||||
e.SetEscapeHTML(doEscape)
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeOptIndent sets the encoder to indent the JSON output.
|
||||
func EncodeOptIndent(prefix string, indent string) EncodeOpt {
|
||||
return func(e *json.Encoder) {
|
||||
e.SetIndent(prefix, indent)
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeJSON - Encodes the contained object back to a JSON formatted []byte
|
||||
// using a variant list of modifier functions for the encoder being used.
|
||||
// Functions for modifying the output are prefixed with EncodeOpt, e.g.
|
||||
// EncodeOptHTMLEscape.
|
||||
func (g *Container) EncodeJSON(encodeOpts ...EncodeOpt) []byte {
|
||||
var b bytes.Buffer
|
||||
encoder := json.NewEncoder(&b)
|
||||
encoder.SetEscapeHTML(false) // Do not escape by default.
|
||||
for _, opt := range encodeOpts {
|
||||
opt(encoder)
|
||||
}
|
||||
if err := encoder.Encode(g.object); err != nil {
|
||||
return []byte("{}")
|
||||
}
|
||||
result := b.Bytes()
|
||||
if len(result) > 0 {
|
||||
result = result[:len(result)-1]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// New - Create a new gabs JSON object.
|
||||
func New() *Container {
|
||||
return &Container{map[string]interface{}{}}
|
||||
}
|
||||
|
||||
// Consume - Gobble up an already converted JSON object, or a fresh map[string]interface{} object.
|
||||
func Consume(root interface{}) (*Container, error) {
|
||||
return &Container{root}, nil
|
||||
}
|
||||
|
||||
// ParseJSON - Convert a string into a representation of the parsed JSON.
|
||||
func ParseJSON(sample []byte) (*Container, error) {
|
||||
var gabs Container
|
||||
|
||||
if err := json.Unmarshal(sample, &gabs.object); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &gabs, nil
|
||||
}
|
||||
|
||||
// ParseJSONDecoder - Convert a json.Decoder into a representation of the parsed JSON.
|
||||
func ParseJSONDecoder(decoder *json.Decoder) (*Container, error) {
|
||||
var gabs Container
|
||||
|
||||
if err := decoder.Decode(&gabs.object); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &gabs, nil
|
||||
}
|
||||
|
||||
// ParseJSONFile - Read a file and convert into a representation of the parsed JSON.
|
||||
func ParseJSONFile(path string) (*Container, error) {
|
||||
if len(path) > 0 {
|
||||
cBytes, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
container, err := ParseJSON(cBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return container, nil
|
||||
}
|
||||
return nil, ErrInvalidPath
|
||||
}
|
||||
|
||||
// ParseJSONBuffer - Read the contents of a buffer into a representation of the parsed JSON.
|
||||
func ParseJSONBuffer(buffer io.Reader) (*Container, error) {
|
||||
var gabs Container
|
||||
jsonDecoder := json.NewDecoder(buffer)
|
||||
if err := jsonDecoder.Decode(&gabs.object); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &gabs, nil
|
||||
}
|
||||
|
||||
//--------------------------------------------------------------------------------------------------
|
Binary file not shown.
After Width: | Height: | Size: 164 KiB |
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,5 @@
|
|||
SAP HANA Database driver for the Go Programming Language
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
This product includes software developed at
|
||||
SAP SE (http://www.sap.com).
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
Copyright 2017 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
// NullBytes represents an []byte that may be null.
|
||||
// NullBytes implements the Scanner interface so
|
||||
// it can be used as a scan destination, similar to NullString.
|
||||
type NullBytes struct {
|
||||
Bytes []byte
|
||||
Valid bool // Valid is true if Bytes is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (n *NullBytes) Scan(value interface{}) error {
|
||||
n.Bytes, n.Valid = value.([]byte)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n NullBytes) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.Bytes, nil
|
||||
}
|
|
@ -0,0 +1,287 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
/*
|
||||
A Connector represents a hdb driver in a fixed configuration.
|
||||
A Connector can be passed to sql.OpenDB (starting from go 1.10) allowing users to bypass a string based data source name.
|
||||
*/
|
||||
type Connector struct {
|
||||
mu sync.RWMutex
|
||||
host, username, password string
|
||||
locale string
|
||||
bufferSize, fetchSize, timeout int
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
func newConnector() *Connector {
|
||||
return &Connector{
|
||||
fetchSize: DefaultFetchSize,
|
||||
timeout: DefaultTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// NewBasicAuthConnector creates a connector for basic authentication.
|
||||
func NewBasicAuthConnector(host, username, password string) *Connector {
|
||||
c := newConnector()
|
||||
c.host = host
|
||||
c.username = username
|
||||
c.password = password
|
||||
return c
|
||||
}
|
||||
|
||||
// NewDSNConnector creates a connector from a data source name.
|
||||
func NewDSNConnector(dsn string) (*Connector, error) {
|
||||
c := newConnector()
|
||||
|
||||
url, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.host = url.Host
|
||||
|
||||
if url.User != nil {
|
||||
c.username = url.User.Username()
|
||||
c.password, _ = url.User.Password()
|
||||
}
|
||||
|
||||
var certPool *x509.CertPool
|
||||
|
||||
for k, v := range url.Query() {
|
||||
switch k {
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("URL parameter %s is not supported", k)
|
||||
|
||||
case DSNFetchSize:
|
||||
if len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
fetchSize, err := strconv.Atoi(v[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse fetchSize: %s", v[0])
|
||||
}
|
||||
if fetchSize < minFetchSize {
|
||||
c.fetchSize = minFetchSize
|
||||
} else {
|
||||
c.fetchSize = fetchSize
|
||||
}
|
||||
|
||||
case DSNTimeout:
|
||||
if len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
timeout, err := strconv.Atoi(v[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse timeout: %s", v[0])
|
||||
}
|
||||
if timeout < minTimeout {
|
||||
c.timeout = minTimeout
|
||||
} else {
|
||||
c.timeout = timeout
|
||||
}
|
||||
|
||||
case DSNLocale:
|
||||
if len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
c.locale = v[0]
|
||||
|
||||
case DSNTLSServerName:
|
||||
if len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
if c.tlsConfig == nil {
|
||||
c.tlsConfig = &tls.Config{}
|
||||
}
|
||||
c.tlsConfig.ServerName = v[0]
|
||||
|
||||
case DSNTLSInsecureSkipVerify:
|
||||
if len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
var err error
|
||||
b := true
|
||||
if v[0] != "" {
|
||||
b, err = strconv.ParseBool(v[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse InsecureSkipVerify (bool): %s", v[0])
|
||||
}
|
||||
}
|
||||
if c.tlsConfig == nil {
|
||||
c.tlsConfig = &tls.Config{}
|
||||
}
|
||||
c.tlsConfig.InsecureSkipVerify = b
|
||||
|
||||
case DSNTLSRootCAFile:
|
||||
for _, fn := range v {
|
||||
rootPEM, err := ioutil.ReadFile(fn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if certPool == nil {
|
||||
certPool = x509.NewCertPool()
|
||||
}
|
||||
if ok := certPool.AppendCertsFromPEM(rootPEM); !ok {
|
||||
return nil, fmt.Errorf("failed to parse root certificate - filename: %s", fn)
|
||||
}
|
||||
}
|
||||
if certPool != nil {
|
||||
if c.tlsConfig == nil {
|
||||
c.tlsConfig = &tls.Config{}
|
||||
}
|
||||
c.tlsConfig.RootCAs = certPool
|
||||
}
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Host returns the host of the connector.
|
||||
func (c *Connector) Host() string {
|
||||
return c.host
|
||||
}
|
||||
|
||||
// Username returns the username of the connector.
|
||||
func (c *Connector) Username() string {
|
||||
return c.username
|
||||
}
|
||||
|
||||
// Password returns the password of the connector.
|
||||
func (c *Connector) Password() string {
|
||||
return c.password
|
||||
}
|
||||
|
||||
// Locale returns the locale of the connector.
|
||||
func (c *Connector) Locale() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.locale
|
||||
}
|
||||
|
||||
/*
|
||||
SetLocale sets the locale of the connector.
|
||||
|
||||
For more information please see DSNLocale.
|
||||
*/
|
||||
func (c *Connector) SetLocale(locale string) {
|
||||
c.mu.Lock()
|
||||
c.locale = locale
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// FetchSize returns the fetchSize of the connector.
|
||||
func (c *Connector) FetchSize() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.fetchSize
|
||||
}
|
||||
|
||||
/*
|
||||
SetFetchSize sets the fetchSize of the connector.
|
||||
|
||||
For more information please see DSNFetchSize.
|
||||
*/
|
||||
func (c *Connector) SetFetchSize(fetchSize int) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if fetchSize < minFetchSize {
|
||||
fetchSize = minFetchSize
|
||||
}
|
||||
c.fetchSize = fetchSize
|
||||
return nil
|
||||
}
|
||||
|
||||
// Timeout returns the timeout of the connector.
|
||||
func (c *Connector) Timeout() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.timeout
|
||||
}
|
||||
|
||||
/*
|
||||
SetTimeout sets the timeout of the connector.
|
||||
|
||||
For more information please see DSNTimeout.
|
||||
*/
|
||||
func (c *Connector) SetTimeout(timeout int) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if timeout < minTimeout {
|
||||
timeout = minTimeout
|
||||
}
|
||||
c.timeout = timeout
|
||||
return nil
|
||||
}
|
||||
|
||||
// TLSConfig returns the TLS configuration of the connector.
|
||||
func (c *Connector) TLSConfig() *tls.Config {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.tlsConfig
|
||||
}
|
||||
|
||||
// SetTLSConfig sets the TLS configuration of the connector.
|
||||
func (c *Connector) SetTLSConfig(tlsConfig *tls.Config) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.tlsConfig = tlsConfig
|
||||
return nil
|
||||
}
|
||||
|
||||
// BasicAuthDSN return the connector DSN for basic authentication.
|
||||
func (c *Connector) BasicAuthDSN() string {
|
||||
values := url.Values{}
|
||||
if c.locale != "" {
|
||||
values.Set(DSNLocale, c.locale)
|
||||
}
|
||||
if c.fetchSize != 0 {
|
||||
values.Set(DSNFetchSize, fmt.Sprintf("%d", c.fetchSize))
|
||||
}
|
||||
if c.timeout != 0 {
|
||||
values.Set(DSNTimeout, fmt.Sprintf("%d", c.timeout))
|
||||
}
|
||||
return (&url.URL{
|
||||
Scheme: DriverName,
|
||||
User: url.UserPassword(c.username, c.password),
|
||||
Host: c.host,
|
||||
RawQuery: values.Encode(),
|
||||
}).String()
|
||||
}
|
||||
|
||||
// Connect implements the database/sql/driver/Connector interface.
|
||||
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
return newConn(ctx, c)
|
||||
}
|
||||
|
||||
// Driver implements the database/sql/driver/Connector interface.
|
||||
func (c *Connector) Driver() driver.Driver {
|
||||
return drv
|
||||
}
|
|
@ -0,0 +1,363 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
p "github.com/SAP/go-hdb/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
minTinyint = 0
|
||||
maxTinyint = math.MaxUint8
|
||||
minSmallint = math.MinInt16
|
||||
maxSmallint = math.MaxInt16
|
||||
minInteger = math.MinInt32
|
||||
maxInteger = math.MaxInt32
|
||||
minBigint = math.MinInt64
|
||||
maxBigint = math.MaxInt64
|
||||
maxReal = math.MaxFloat32
|
||||
maxDouble = math.MaxFloat64
|
||||
)
|
||||
|
||||
// ErrIntegerOutOfRange means that an integer exceeds the size of the hdb integer field.
|
||||
var ErrIntegerOutOfRange = errors.New("integer out of range error")
|
||||
|
||||
// ErrFloatOutOfRange means that a float exceeds the size of the hdb float field.
|
||||
var ErrFloatOutOfRange = errors.New("float out of range error")
|
||||
|
||||
var typeOfTime = reflect.TypeOf((*time.Time)(nil)).Elem()
|
||||
var typeOfBytes = reflect.TypeOf((*[]byte)(nil)).Elem()
|
||||
|
||||
func checkNamedValue(prmFieldSet *p.ParameterFieldSet, nv *driver.NamedValue) error {
|
||||
idx := nv.Ordinal - 1
|
||||
|
||||
if idx >= prmFieldSet.NumInputField() {
|
||||
return nil
|
||||
}
|
||||
|
||||
f := prmFieldSet.Field(idx)
|
||||
dt := f.TypeCode().DataType()
|
||||
|
||||
value, err := convertNamedValue(idx, f, dt, nv.Value)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nv.Value = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertNamedValue(idx int, f *p.ParameterField, dt p.DataType, v driver.Value) (driver.Value, error) {
|
||||
var err error
|
||||
|
||||
// let fields with own Value converter convert themselves first (e.g. NullInt64, ...)
|
||||
if _, ok := v.(driver.Valuer); ok {
|
||||
if v, err = driver.DefaultParameterConverter.ConvertValue(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch dt {
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("convert named value datatype error: %[1]d - %[1]s", dt)
|
||||
|
||||
case p.DtTinyint:
|
||||
return convertNvInteger(v, minTinyint, maxTinyint)
|
||||
|
||||
case p.DtSmallint:
|
||||
return convertNvInteger(v, minSmallint, maxSmallint)
|
||||
|
||||
case p.DtInteger:
|
||||
return convertNvInteger(v, minInteger, maxInteger)
|
||||
|
||||
case p.DtBigint:
|
||||
return convertNvInteger(v, minBigint, maxBigint)
|
||||
|
||||
case p.DtReal:
|
||||
return convertNvFloat(v, maxReal)
|
||||
|
||||
case p.DtDouble:
|
||||
return convertNvFloat(v, maxDouble)
|
||||
|
||||
case p.DtTime:
|
||||
return convertNvTime(v)
|
||||
|
||||
case p.DtDecimal:
|
||||
return convertNvDecimal(v)
|
||||
|
||||
case p.DtString:
|
||||
return convertNvString(v)
|
||||
|
||||
case p.DtBytes:
|
||||
return convertNvBytes(v)
|
||||
|
||||
case p.DtLob:
|
||||
return convertNvLob(idx, f, v)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// integer types
|
||||
func convertNvInteger(v interface{}, min, max int64) (driver.Value, error) {
|
||||
|
||||
if v == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
switch rv.Kind() {
|
||||
|
||||
// bool is represented in HDB as tinyint
|
||||
case reflect.Bool:
|
||||
return rv.Bool(), nil
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
i64 := rv.Int()
|
||||
if i64 > max || i64 < min {
|
||||
return nil, ErrIntegerOutOfRange
|
||||
}
|
||||
return i64, nil
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
u64 := rv.Uint()
|
||||
if u64 > uint64(max) {
|
||||
return nil, ErrIntegerOutOfRange
|
||||
}
|
||||
return int64(u64), nil
|
||||
case reflect.Ptr:
|
||||
// indirect pointers
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
return convertNvInteger(rv.Elem().Interface(), min, max)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported integer conversion type error %[1]T %[1]v", v)
|
||||
}
|
||||
|
||||
// float types
|
||||
func convertNvFloat(v interface{}, max float64) (driver.Value, error) {
|
||||
|
||||
if v == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
switch rv.Kind() {
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
f64 := rv.Float()
|
||||
if math.Abs(f64) > max {
|
||||
return nil, ErrFloatOutOfRange
|
||||
}
|
||||
return f64, nil
|
||||
case reflect.Ptr:
|
||||
// indirect pointers
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
return convertNvFloat(rv.Elem().Interface(), max)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported float conversion type error %[1]T %[1]v", v)
|
||||
}
|
||||
|
||||
// time
|
||||
func convertNvTime(v interface{}) (driver.Value, error) {
|
||||
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
|
||||
case time.Time:
|
||||
return v, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
|
||||
switch rv.Kind() {
|
||||
|
||||
case reflect.Ptr:
|
||||
// indirect pointers
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
return convertNvTime(rv.Elem().Interface())
|
||||
}
|
||||
|
||||
if rv.Type().ConvertibleTo(typeOfTime) {
|
||||
tv := rv.Convert(typeOfTime)
|
||||
return tv.Interface().(time.Time), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported time conversion type error %[1]T %[1]v", v)
|
||||
}
|
||||
|
||||
// decimal
|
||||
func convertNvDecimal(v interface{}) (driver.Value, error) {
|
||||
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if v, ok := v.([]byte); ok {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported decimal conversion type error %[1]T %[1]v", v)
|
||||
}
|
||||
|
||||
// string
|
||||
func convertNvString(v interface{}) (driver.Value, error) {
|
||||
|
||||
if v == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
|
||||
case string, []byte:
|
||||
return v, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
|
||||
switch rv.Kind() {
|
||||
|
||||
case reflect.String:
|
||||
return rv.String(), nil
|
||||
|
||||
case reflect.Slice:
|
||||
if rv.Type() == typeOfBytes {
|
||||
return rv.Bytes(), nil
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// indirect pointers
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
return convertNvString(rv.Elem().Interface())
|
||||
}
|
||||
|
||||
if rv.Type().ConvertibleTo(typeOfBytes) {
|
||||
bv := rv.Convert(typeOfBytes)
|
||||
return bv.Interface().([]byte), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported character conversion type error %[1]T %[1]v", v)
|
||||
}
|
||||
|
||||
// bytes
|
||||
func convertNvBytes(v interface{}) (driver.Value, error) {
|
||||
|
||||
if v == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
if v, ok := v.([]byte); ok {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
|
||||
switch rv.Kind() {
|
||||
|
||||
case reflect.Slice:
|
||||
if rv.Type() == typeOfBytes {
|
||||
return rv.Bytes(), nil
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// indirect pointers
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
return convertNvBytes(rv.Elem().Interface())
|
||||
}
|
||||
|
||||
if rv.Type().ConvertibleTo(typeOfBytes) {
|
||||
bv := rv.Convert(typeOfBytes)
|
||||
return bv.Interface().([]byte), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported bytes conversion type error %[1]T %[1]v", v)
|
||||
}
|
||||
|
||||
// Lob
|
||||
func convertNvLob(idx int, f *p.ParameterField, v interface{}) (driver.Value, error) {
|
||||
|
||||
if v == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case Lob:
|
||||
if v.rd == nil {
|
||||
return nil, fmt.Errorf("lob error: initial reader %[1]T %[1]v", v)
|
||||
}
|
||||
f.SetLobReader(v.rd)
|
||||
return fmt.Sprintf("<lob %d", idx), nil
|
||||
case *Lob:
|
||||
if v.rd == nil {
|
||||
return nil, fmt.Errorf("lob error: initial reader %[1]T %[1]v", v)
|
||||
}
|
||||
f.SetLobReader(v.rd)
|
||||
return fmt.Sprintf("<lob %d", idx), nil
|
||||
case NullLob:
|
||||
if !v.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
if v.Lob.rd == nil {
|
||||
return nil, fmt.Errorf("lob error: initial reader %[1]T %[1]v", v)
|
||||
}
|
||||
f.SetLobReader(v.Lob.rd)
|
||||
return fmt.Sprintf("<lob %d", idx), nil
|
||||
case *NullLob:
|
||||
if !v.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
if v.Lob.rd == nil {
|
||||
return nil, fmt.Errorf("lob error: initial reader %[1]T %[1]v", v)
|
||||
}
|
||||
f.SetLobReader(v.Lob.rd)
|
||||
return fmt.Sprintf("<lob %d", idx), nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
|
||||
switch rv.Kind() {
|
||||
|
||||
case reflect.Ptr:
|
||||
// indirect pointers
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
return convertNvLob(idx, f, rv.Elem().Interface())
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported lob conversion type error %[1]T %[1]v", v)
|
||||
}
|
|
@ -0,0 +1,377 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"sync"
|
||||
)
|
||||
|
||||
//bigint word size (*--> src/pkg/math/big/arith.go)
|
||||
const (
|
||||
// Compute the size _S of a Word in bytes.
|
||||
_m = ^big.Word(0)
|
||||
_logS = _m>>8&1 + _m>>16&1 + _m>>32&1
|
||||
_S = 1 << _logS
|
||||
)
|
||||
|
||||
const (
|
||||
// http://en.wikipedia.org/wiki/Decimal128_floating-point_format
|
||||
dec128Digits = 34
|
||||
dec128Bias = 6176
|
||||
dec128MinExp = -6176
|
||||
dec128MaxExp = 6111
|
||||
)
|
||||
|
||||
const (
|
||||
decimalSize = 16 //number of bytes
|
||||
)
|
||||
|
||||
var natZero = big.NewInt(0)
|
||||
var natOne = big.NewInt(1)
|
||||
var natTen = big.NewInt(10)
|
||||
|
||||
var nat = []*big.Int{
|
||||
natOne, //10^0
|
||||
natTen, //10^1
|
||||
big.NewInt(100), //10^2
|
||||
big.NewInt(1000), //10^3
|
||||
big.NewInt(10000), //10^4
|
||||
big.NewInt(100000), //10^5
|
||||
big.NewInt(1000000), //10^6
|
||||
big.NewInt(10000000), //10^7
|
||||
big.NewInt(100000000), //10^8
|
||||
big.NewInt(1000000000), //10^9
|
||||
big.NewInt(10000000000), //10^10
|
||||
}
|
||||
|
||||
const lg10 = math.Ln10 / math.Ln2 // ~log2(10)
|
||||
|
||||
var maxDecimal = new(big.Int).SetBytes([]byte{0x01, 0xED, 0x09, 0xBE, 0xAD, 0x87, 0xC0, 0x37, 0x8D, 0x8E, 0x63, 0xFF, 0xFF, 0xFF, 0xFF})
|
||||
|
||||
type decFlags byte
|
||||
|
||||
const (
|
||||
dfNotExact decFlags = 1 << iota
|
||||
dfOverflow
|
||||
dfUnderflow
|
||||
)
|
||||
|
||||
// ErrDecimalOutOfRange means that a big.Rat exceeds the size of hdb decimal fields.
|
||||
var ErrDecimalOutOfRange = errors.New("decimal out of range error")
|
||||
|
||||
// big.Int free list
|
||||
var bigIntFree = sync.Pool{
|
||||
New: func() interface{} { return new(big.Int) },
|
||||
}
|
||||
|
||||
// big.Rat free list
|
||||
var bigRatFree = sync.Pool{
|
||||
New: func() interface{} { return new(big.Rat) },
|
||||
}
|
||||
|
||||
// A Decimal is the driver representation of a database decimal field value as big.Rat.
|
||||
type Decimal big.Rat
|
||||
|
||||
// Scan implements the database/sql/Scanner interface.
|
||||
func (d *Decimal) Scan(src interface{}) error {
|
||||
|
||||
b, ok := src.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("decimal: invalid data type %T", src)
|
||||
}
|
||||
|
||||
if len(b) != decimalSize {
|
||||
return fmt.Errorf("decimal: invalid size %d of %v - %d expected", len(b), b, decimalSize)
|
||||
}
|
||||
|
||||
if (b[15] & 0x60) == 0x60 {
|
||||
return fmt.Errorf("decimal: format (infinity, nan, ...) not supported : %v", b)
|
||||
}
|
||||
|
||||
v := (*big.Rat)(d)
|
||||
p := v.Num()
|
||||
q := v.Denom()
|
||||
|
||||
neg, exp := decodeDecimal(b, p)
|
||||
|
||||
switch {
|
||||
case exp < 0:
|
||||
q.Set(exp10(exp * -1))
|
||||
case exp == 0:
|
||||
q.Set(natOne)
|
||||
case exp > 0:
|
||||
p.Mul(p, exp10(exp))
|
||||
q.Set(natOne)
|
||||
}
|
||||
|
||||
if neg {
|
||||
v.Neg(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the database/sql/Valuer interface.
|
||||
func (d Decimal) Value() (driver.Value, error) {
|
||||
m := bigIntFree.Get().(*big.Int)
|
||||
neg, exp, df := convertRatToDecimal((*big.Rat)(&d), m, dec128Digits, dec128MinExp, dec128MaxExp)
|
||||
|
||||
var v driver.Value
|
||||
var err error
|
||||
|
||||
switch {
|
||||
default:
|
||||
v, err = encodeDecimal(m, neg, exp)
|
||||
case df&dfUnderflow != 0: // set to zero
|
||||
m.Set(natZero)
|
||||
v, err = encodeDecimal(m, false, 0)
|
||||
case df&dfOverflow != 0:
|
||||
err = ErrDecimalOutOfRange
|
||||
}
|
||||
|
||||
// performance (avoid expensive defer)
|
||||
bigIntFree.Put(m)
|
||||
|
||||
return v, err
|
||||
}
|
||||
|
||||
func convertRatToDecimal(x *big.Rat, m *big.Int, digits, minExp, maxExp int) (bool, int, decFlags) {
|
||||
|
||||
neg := x.Sign() < 0 //store sign
|
||||
|
||||
if x.Num().Cmp(natZero) == 0 { // zero
|
||||
m.Set(natZero)
|
||||
return neg, 0, 0
|
||||
}
|
||||
|
||||
c := bigRatFree.Get().(*big.Rat).Abs(x) // copy && abs
|
||||
a := c.Num()
|
||||
b := c.Denom()
|
||||
|
||||
exp, shift := 0, 0
|
||||
|
||||
if c.IsInt() {
|
||||
exp = digits10(a) - 1
|
||||
} else {
|
||||
shift = digits10(a) - digits10(b)
|
||||
switch {
|
||||
case shift < 0:
|
||||
a.Mul(a, exp10(shift*-1))
|
||||
case shift > 0:
|
||||
b.Mul(b, exp10(shift))
|
||||
}
|
||||
if a.Cmp(b) == -1 {
|
||||
exp = shift - 1
|
||||
} else {
|
||||
exp = shift
|
||||
}
|
||||
}
|
||||
|
||||
var df decFlags
|
||||
|
||||
switch {
|
||||
default:
|
||||
exp = max(exp-digits+1, minExp)
|
||||
case exp < minExp:
|
||||
df |= dfUnderflow
|
||||
exp = exp - digits + 1
|
||||
}
|
||||
|
||||
if exp > maxExp {
|
||||
df |= dfOverflow
|
||||
}
|
||||
|
||||
shift = exp - shift
|
||||
switch {
|
||||
case shift < 0:
|
||||
a.Mul(a, exp10(shift*-1))
|
||||
case exp > 0:
|
||||
b.Mul(b, exp10(shift))
|
||||
}
|
||||
|
||||
m.QuoRem(a, b, a) // reuse a as rest
|
||||
if a.Cmp(natZero) != 0 {
|
||||
// round (business >= 0.5 up)
|
||||
df |= dfNotExact
|
||||
if a.Add(a, a).Cmp(b) >= 0 {
|
||||
m.Add(m, natOne)
|
||||
if m.Cmp(exp10(digits)) == 0 {
|
||||
shift := min(digits, maxExp-exp)
|
||||
if shift < 1 { // overflow -> shift one at minimum
|
||||
df |= dfOverflow
|
||||
shift = 1
|
||||
}
|
||||
m.Set(exp10(digits - shift))
|
||||
exp += shift
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// norm
|
||||
for exp < maxExp {
|
||||
a.QuoRem(m, natTen, b) // reuse a, b
|
||||
if b.Cmp(natZero) != 0 {
|
||||
break
|
||||
}
|
||||
m.Set(a)
|
||||
exp++
|
||||
}
|
||||
|
||||
// performance (avoid expensive defer)
|
||||
bigRatFree.Put(c)
|
||||
|
||||
return neg, exp, df
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// performance: tested with reference work variable
|
||||
// - but int.Set is expensive, so let's live with big.Int creation for n >= len(nat)
|
||||
func exp10(n int) *big.Int {
|
||||
if n < len(nat) {
|
||||
return nat[n]
|
||||
}
|
||||
r := big.NewInt(int64(n))
|
||||
return r.Exp(natTen, r, nil)
|
||||
}
|
||||
|
||||
func digits10(p *big.Int) int {
|
||||
k := p.BitLen() // 2^k <= p < 2^(k+1) - 1
|
||||
//i := int(float64(k) / lg10) //minimal digits base 10
|
||||
//i := int(float64(k) / lg10) //minimal digits base 10
|
||||
i := k * 100 / 332
|
||||
if i < 1 {
|
||||
i = 1
|
||||
}
|
||||
|
||||
for ; ; i++ {
|
||||
if p.Cmp(exp10(i)) < 0 {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeDecimal(b []byte, m *big.Int) (bool, int) {
|
||||
|
||||
neg := (b[15] & 0x80) != 0
|
||||
exp := int((((uint16(b[15])<<8)|uint16(b[14]))<<1)>>2) - dec128Bias
|
||||
|
||||
b14 := b[14] // save b[14]
|
||||
b[14] &= 0x01 // keep the mantissa bit (rest: sign and exp)
|
||||
|
||||
//most significand byte
|
||||
msb := 14
|
||||
for msb > 0 {
|
||||
if b[msb] != 0 {
|
||||
break
|
||||
}
|
||||
msb--
|
||||
}
|
||||
|
||||
//calc number of words
|
||||
numWords := (msb / _S) + 1
|
||||
w := make([]big.Word, numWords)
|
||||
|
||||
k := numWords - 1
|
||||
d := big.Word(0)
|
||||
for i := msb; i >= 0; i-- {
|
||||
d |= big.Word(b[i])
|
||||
if k*_S == i {
|
||||
w[k] = d
|
||||
k--
|
||||
d = 0
|
||||
}
|
||||
d <<= 8
|
||||
}
|
||||
b[14] = b14 // restore b[14]
|
||||
m.SetBits(w)
|
||||
return neg, exp
|
||||
}
|
||||
|
||||
func encodeDecimal(m *big.Int, neg bool, exp int) (driver.Value, error) {
|
||||
|
||||
b := make([]byte, decimalSize)
|
||||
|
||||
// little endian bigint words (significand) -> little endian db decimal format
|
||||
j := 0
|
||||
for _, d := range m.Bits() {
|
||||
for i := 0; i < 8; i++ {
|
||||
b[j] = byte(d)
|
||||
d >>= 8
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
exp += dec128Bias
|
||||
b[14] |= (byte(exp) << 1)
|
||||
b[15] = byte(uint16(exp) >> 7)
|
||||
|
||||
if neg {
|
||||
b[15] |= 0x80
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// NullDecimal represents an Decimal that may be null.
|
||||
// NullDecimal implements the Scanner interface so
|
||||
// it can be used as a scan destination, similar to NullString.
|
||||
type NullDecimal struct {
|
||||
Decimal *Decimal
|
||||
Valid bool // Valid is true if Decimal is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (n *NullDecimal) Scan(value interface{}) error {
|
||||
var b []byte
|
||||
|
||||
b, n.Valid = value.([]byte)
|
||||
if !n.Valid {
|
||||
return nil
|
||||
}
|
||||
if n.Decimal == nil {
|
||||
return fmt.Errorf("invalid decimal value %v", n.Decimal)
|
||||
}
|
||||
return n.Decimal.Scan(b)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n NullDecimal) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
if n.Decimal == nil {
|
||||
return nil, fmt.Errorf("invalid decimal value %v", n.Decimal)
|
||||
}
|
||||
return n.Decimal.Value()
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package driver is a native Go SAP HANA driver implementation for the database/sql package.
|
||||
package driver
|
|
@ -0,0 +1,542 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/SAP/go-hdb/driver/sqltrace"
|
||||
|
||||
p "github.com/SAP/go-hdb/internal/protocol"
|
||||
)
|
||||
|
||||
// DriverVersion is the version number of the hdb driver.
|
||||
const DriverVersion = "0.12.0"
|
||||
|
||||
// DriverName is the driver name to use with sql.Open for hdb databases.
|
||||
const DriverName = "hdb"
|
||||
|
||||
// Transaction isolation levels supported by hdb.
|
||||
const (
|
||||
LevelReadCommitted = "READ COMMITTED"
|
||||
LevelRepeatableRead = "REPEATABLE READ"
|
||||
LevelSerializable = "SERIALIZABLE"
|
||||
)
|
||||
|
||||
// Access modes supported by hdb.
|
||||
const (
|
||||
modeReadOnly = "READ ONLY"
|
||||
modeReadWrite = "READ WRITE"
|
||||
)
|
||||
|
||||
// map sql isolation level to hdb isolation level.
|
||||
var isolationLevel = map[driver.IsolationLevel]string{
|
||||
driver.IsolationLevel(sql.LevelDefault): LevelReadCommitted,
|
||||
driver.IsolationLevel(sql.LevelReadCommitted): LevelReadCommitted,
|
||||
driver.IsolationLevel(sql.LevelRepeatableRead): LevelRepeatableRead,
|
||||
driver.IsolationLevel(sql.LevelSerializable): LevelSerializable,
|
||||
}
|
||||
|
||||
// map sql read only flag to hdb access mode.
|
||||
var readOnly = map[bool]string{
|
||||
true: modeReadOnly,
|
||||
false: modeReadWrite,
|
||||
}
|
||||
|
||||
// ErrUnsupportedIsolationLevel is the error raised if a transaction is started with a not supported isolation level.
|
||||
var ErrUnsupportedIsolationLevel = errors.New("Unsupported isolation level")
|
||||
|
||||
// ErrNestedTransaction is the error raised if a tranasction is created within a transaction as this is not supported by hdb.
|
||||
var ErrNestedTransaction = errors.New("Nested transactions are not supported")
|
||||
|
||||
// needed for testing
|
||||
const driverDataFormatVersion = 1
|
||||
|
||||
// queries
|
||||
const (
|
||||
pingQuery = "select 1 from dummy"
|
||||
isolationLevelStmt = "set transaction isolation level %s"
|
||||
accessModeStmt = "set transaction %s"
|
||||
)
|
||||
|
||||
// bulk statement
|
||||
const noFlush = "$nf"
|
||||
|
||||
// NoFlush is to be used as parameter in bulk inserts.
|
||||
var NoFlush = sql.Named(noFlush, nil)
|
||||
|
||||
var drv = &hdbDrv{}
|
||||
|
||||
func init() {
|
||||
sql.Register(DriverName, drv)
|
||||
}
|
||||
|
||||
// driver
|
||||
|
||||
// check if driver implements all required interfaces
|
||||
var (
|
||||
_ driver.Driver = (*hdbDrv)(nil)
|
||||
)
|
||||
|
||||
type hdbDrv struct{}
|
||||
|
||||
func (d *hdbDrv) Open(dsn string) (driver.Conn, error) {
|
||||
connector, err := NewDSNConnector(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return connector.Connect(context.Background())
|
||||
}
|
||||
|
||||
// database connection
|
||||
|
||||
// check if conn implements all required interfaces
|
||||
var (
|
||||
_ driver.Conn = (*conn)(nil)
|
||||
_ driver.ConnPrepareContext = (*conn)(nil)
|
||||
_ driver.Pinger = (*conn)(nil)
|
||||
_ driver.ConnBeginTx = (*conn)(nil)
|
||||
_ driver.ExecerContext = (*conn)(nil)
|
||||
//go 1.9 issue (ExecerContext is only called if Execer is implemented)
|
||||
_ driver.Execer = (*conn)(nil)
|
||||
_ driver.QueryerContext = (*conn)(nil)
|
||||
//go 1.9 issue (QueryerContext is only called if Queryer is implemented)
|
||||
// QueryContext is needed for stored procedures with table output parameters.
|
||||
_ driver.Queryer = (*conn)(nil)
|
||||
_ driver.NamedValueChecker = (*conn)(nil)
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
session *p.Session
|
||||
}
|
||||
|
||||
func newConn(ctx context.Context, c *Connector) (driver.Conn, error) {
|
||||
session, err := p.NewSession(ctx, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &conn{session: session}, nil
|
||||
}
|
||||
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
panic("deprecated")
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
c.session.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) Begin() (driver.Tx, error) {
|
||||
panic("deprecated")
|
||||
}
|
||||
|
||||
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
|
||||
|
||||
if c.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
if c.session.InTx() {
|
||||
return nil, ErrNestedTransaction
|
||||
}
|
||||
|
||||
level, ok := isolationLevel[opts.Isolation]
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedIsolationLevel
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// set isolation level
|
||||
if _, err = c.ExecContext(ctx, fmt.Sprintf(isolationLevelStmt, level), nil); err != nil {
|
||||
goto done
|
||||
}
|
||||
// set access mode
|
||||
if _, err = c.ExecContext(ctx, fmt.Sprintf(accessModeStmt, readOnly[opts.ReadOnly]), nil); err != nil {
|
||||
goto done
|
||||
}
|
||||
c.session.SetInTx(true)
|
||||
tx = newTx(c.session)
|
||||
done:
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return tx, err
|
||||
}
|
||||
}
|
||||
|
||||
// Exec implements the database/sql/driver/Execer interface.
|
||||
// delete after go 1.9 compatibility is given up.
|
||||
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||
panic("deprecated")
|
||||
}
|
||||
|
||||
// ExecContext implements the database/sql/driver/ExecerContext interface.
|
||||
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
|
||||
if c.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
if len(args) != 0 {
|
||||
return nil, driver.ErrSkip //fast path not possible (prepare needed)
|
||||
}
|
||||
|
||||
sqltrace.Traceln(query)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r, err = c.session.ExecDirect(query)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return r, err
|
||||
}
|
||||
}
|
||||
|
||||
// Queryer implements the database/sql/driver/Queryer interface.
|
||||
// delete after go 1.9 compatibility is given up.
|
||||
func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
panic("deprecated")
|
||||
}
|
||||
|
||||
func (c *conn) Ping(ctx context.Context) (err error) {
|
||||
if c.session.IsBad() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_, err = c.QueryContext(ctx, pingQuery, nil)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-done:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// CheckNamedValue implements NamedValueChecker interface.
|
||||
// implemented for conn:
|
||||
// if querier or execer is called, sql checks parameters before in case of
|
||||
// parameters the method can be 'skipped' and force the prepare path
|
||||
// --> guarantee that a valid driver value is returned
|
||||
// --> if not implemented, Lob need to have a pseudo Value method to return a valid driver value
|
||||
func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
|
||||
switch nv.Value.(type) {
|
||||
case Lob, *Lob:
|
||||
nv.Value = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//transaction
|
||||
|
||||
// check if tx implements all required interfaces
|
||||
var (
|
||||
_ driver.Tx = (*tx)(nil)
|
||||
)
|
||||
|
||||
type tx struct {
|
||||
session *p.Session
|
||||
}
|
||||
|
||||
func newTx(session *p.Session) *tx {
|
||||
return &tx{
|
||||
session: session,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tx) Commit() error {
|
||||
if t.session.IsBad() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
return t.session.Commit()
|
||||
}
|
||||
|
||||
func (t *tx) Rollback() error {
|
||||
if t.session.IsBad() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
return t.session.Rollback()
|
||||
}
|
||||
|
||||
//statement
|
||||
|
||||
// check if stmt implements all required interfaces
|
||||
var (
|
||||
_ driver.Stmt = (*stmt)(nil)
|
||||
_ driver.StmtExecContext = (*stmt)(nil)
|
||||
_ driver.StmtQueryContext = (*stmt)(nil)
|
||||
_ driver.NamedValueChecker = (*stmt)(nil)
|
||||
)
|
||||
|
||||
type stmt struct {
|
||||
qt p.QueryType
|
||||
session *p.Session
|
||||
query string
|
||||
id uint64
|
||||
prmFieldSet *p.ParameterFieldSet
|
||||
resultFieldSet *p.ResultFieldSet
|
||||
}
|
||||
|
||||
func newStmt(qt p.QueryType, session *p.Session, query string, id uint64, prmFieldSet *p.ParameterFieldSet, resultFieldSet *p.ResultFieldSet) (*stmt, error) {
|
||||
return &stmt{qt: qt, session: session, query: query, id: id, prmFieldSet: prmFieldSet, resultFieldSet: resultFieldSet}, nil
|
||||
}
|
||||
|
||||
func (s *stmt) Close() error {
|
||||
return s.session.DropStatementID(s.id)
|
||||
}
|
||||
|
||||
func (s *stmt) NumInput() int {
|
||||
return s.prmFieldSet.NumInputField()
|
||||
}
|
||||
|
||||
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
panic("deprecated")
|
||||
}
|
||||
|
||||
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
|
||||
if s.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
numField := s.prmFieldSet.NumInputField()
|
||||
if len(args) != numField {
|
||||
return nil, fmt.Errorf("invalid number of arguments %d - %d expected", len(args), numField)
|
||||
}
|
||||
|
||||
sqltrace.Tracef("%s %v", s.query, args)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r, err = s.session.Exec(s.id, s.prmFieldSet, args)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return r, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stmt) Query(args []driver.Value) (rows driver.Rows, err error) {
|
||||
panic("deprecated")
|
||||
}
|
||||
|
||||
// Deprecated: see NamedValueChecker.
|
||||
//func (s *stmt) ColumnConverter(idx int) driver.ValueConverter {
|
||||
//}
|
||||
|
||||
// CheckNamedValue implements NamedValueChecker interface.
|
||||
func (s *stmt) CheckNamedValue(nv *driver.NamedValue) error {
|
||||
if nv.Name == noFlush {
|
||||
//...
|
||||
|
||||
print("remove variable")
|
||||
|
||||
return driver.ErrRemoveArgument
|
||||
}
|
||||
return checkNamedValue(s.prmFieldSet, nv)
|
||||
}
|
||||
|
||||
// driver.Rows drop-in replacement if driver Query or QueryRow is used for statements that doesn't return rows
|
||||
var noColumns = []string{}
|
||||
var noResult = new(noResultType)
|
||||
|
||||
// check if noResultType implements all required interfaces
|
||||
var (
|
||||
_ driver.Rows = (*noResultType)(nil)
|
||||
)
|
||||
|
||||
type noResultType struct{}
|
||||
|
||||
func (r *noResultType) Columns() []string { return noColumns }
|
||||
func (r *noResultType) Close() error { return nil }
|
||||
func (r *noResultType) Next(dest []driver.Value) error { return io.EOF }
|
||||
|
||||
// rows
|
||||
type rows struct {
|
||||
}
|
||||
|
||||
// query result
|
||||
|
||||
// check if queryResult implements all required interfaces
|
||||
var (
|
||||
_ driver.Rows = (*queryResult)(nil)
|
||||
_ driver.RowsColumnTypeDatabaseTypeName = (*queryResult)(nil) // go 1.8
|
||||
_ driver.RowsColumnTypeLength = (*queryResult)(nil) // go 1.8
|
||||
_ driver.RowsColumnTypeNullable = (*queryResult)(nil) // go 1.8
|
||||
_ driver.RowsColumnTypePrecisionScale = (*queryResult)(nil) // go 1.8
|
||||
_ driver.RowsColumnTypeScanType = (*queryResult)(nil) // go 1.8
|
||||
)
|
||||
|
||||
type queryResult struct {
|
||||
session *p.Session
|
||||
id uint64
|
||||
resultFieldSet *p.ResultFieldSet
|
||||
fieldValues *p.FieldValues
|
||||
pos int
|
||||
attrs p.PartAttributes
|
||||
columns []string
|
||||
lastErr error
|
||||
}
|
||||
|
||||
func newQueryResult(session *p.Session, id uint64, resultFieldSet *p.ResultFieldSet, fieldValues *p.FieldValues, attrs p.PartAttributes) (driver.Rows, error) {
|
||||
columns := make([]string, resultFieldSet.NumField())
|
||||
for i := 0; i < len(columns); i++ {
|
||||
columns[i] = resultFieldSet.Field(i).Name()
|
||||
}
|
||||
|
||||
return &queryResult{
|
||||
session: session,
|
||||
id: id,
|
||||
resultFieldSet: resultFieldSet,
|
||||
fieldValues: fieldValues,
|
||||
attrs: attrs,
|
||||
columns: columns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *queryResult) Columns() []string {
|
||||
return r.columns
|
||||
}
|
||||
|
||||
func (r *queryResult) Close() error {
|
||||
// if lastError is set, attrs are nil
|
||||
if r.lastErr != nil {
|
||||
return r.lastErr
|
||||
}
|
||||
|
||||
if !r.attrs.ResultsetClosed() {
|
||||
return r.session.CloseResultsetID(r.id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *queryResult) Next(dest []driver.Value) error {
|
||||
if r.session.IsBad() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
if r.pos >= r.fieldValues.NumRow() {
|
||||
if r.attrs.LastPacket() {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
if r.attrs, err = r.session.FetchNext(r.id, r.resultFieldSet, r.fieldValues); err != nil {
|
||||
r.lastErr = err //fieldValues and attrs are nil
|
||||
return err
|
||||
}
|
||||
|
||||
if r.attrs.NoRows() {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
r.pos = 0
|
||||
|
||||
}
|
||||
|
||||
r.fieldValues.Row(r.pos, dest)
|
||||
r.pos++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *queryResult) ColumnTypeDatabaseTypeName(idx int) string {
|
||||
return r.resultFieldSet.Field(idx).TypeCode().TypeName()
|
||||
}
|
||||
|
||||
func (r *queryResult) ColumnTypeLength(idx int) (int64, bool) {
|
||||
return r.resultFieldSet.Field(idx).TypeLength()
|
||||
}
|
||||
|
||||
func (r *queryResult) ColumnTypePrecisionScale(idx int) (int64, int64, bool) {
|
||||
return r.resultFieldSet.Field(idx).TypePrecisionScale()
|
||||
}
|
||||
|
||||
func (r *queryResult) ColumnTypeNullable(idx int) (bool, bool) {
|
||||
return r.resultFieldSet.Field(idx).Nullable(), true
|
||||
}
|
||||
|
||||
var (
|
||||
scanTypeUnknown = reflect.TypeOf(new(interface{})).Elem()
|
||||
scanTypeTinyint = reflect.TypeOf(uint8(0))
|
||||
scanTypeSmallint = reflect.TypeOf(int16(0))
|
||||
scanTypeInteger = reflect.TypeOf(int32(0))
|
||||
scanTypeBigint = reflect.TypeOf(int64(0))
|
||||
scanTypeReal = reflect.TypeOf(float32(0.0))
|
||||
scanTypeDouble = reflect.TypeOf(float64(0.0))
|
||||
scanTypeTime = reflect.TypeOf(time.Time{})
|
||||
scanTypeString = reflect.TypeOf(string(""))
|
||||
scanTypeBytes = reflect.TypeOf([]byte{})
|
||||
scanTypeDecimal = reflect.TypeOf(Decimal{})
|
||||
scanTypeLob = reflect.TypeOf(Lob{})
|
||||
)
|
||||
|
||||
func (r *queryResult) ColumnTypeScanType(idx int) reflect.Type {
|
||||
switch r.resultFieldSet.Field(idx).TypeCode().DataType() {
|
||||
default:
|
||||
return scanTypeUnknown
|
||||
case p.DtTinyint:
|
||||
return scanTypeTinyint
|
||||
case p.DtSmallint:
|
||||
return scanTypeSmallint
|
||||
case p.DtInteger:
|
||||
return scanTypeInteger
|
||||
case p.DtBigint:
|
||||
return scanTypeBigint
|
||||
case p.DtReal:
|
||||
return scanTypeReal
|
||||
case p.DtDouble:
|
||||
return scanTypeDouble
|
||||
case p.DtTime:
|
||||
return scanTypeTime
|
||||
case p.DtDecimal:
|
||||
return scanTypeDecimal
|
||||
case p.DtString:
|
||||
return scanTypeString
|
||||
case p.DtBytes:
|
||||
return scanTypeBytes
|
||||
case p.DtLob:
|
||||
return scanTypeLob
|
||||
}
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
// +build future
|
||||
|
||||
/*
|
||||
Copyright 2018 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
//"database/sql"
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/SAP/go-hdb/driver/sqltrace"
|
||||
|
||||
p "github.com/SAP/go-hdb/internal/protocol"
|
||||
)
|
||||
|
||||
// database connection
|
||||
|
||||
func (c *conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
|
||||
if c.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
var (
|
||||
qt p.QueryType
|
||||
id uint64
|
||||
prmFieldSet *p.ParameterFieldSet
|
||||
resultFieldSet *p.ResultFieldSet
|
||||
)
|
||||
qt, id, prmFieldSet, resultFieldSet, err = c.session.Prepare(query)
|
||||
if err != nil {
|
||||
goto done
|
||||
}
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
stmt, err = newStmt(qt, c.session, query, id, prmFieldSet, resultFieldSet)
|
||||
done:
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return stmt, err
|
||||
}
|
||||
}
|
||||
|
||||
// QueryContext implements the database/sql/driver/QueryerContext interface.
|
||||
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
|
||||
if c.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
if len(args) != 0 {
|
||||
return nil, driver.ErrSkip //fast path not possible (prepare needed)
|
||||
}
|
||||
|
||||
// direct execution of call procedure
|
||||
// - returns no parameter metadata (sps 82) but only field values
|
||||
// --> let's take the 'prepare way' for stored procedures
|
||||
// if checkCallProcedure(query) {
|
||||
// return nil, driver.ErrSkip
|
||||
// }
|
||||
|
||||
sqltrace.Traceln(query)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
var (
|
||||
id uint64
|
||||
resultFieldSet *p.ResultFieldSet
|
||||
fieldValues *p.FieldValues
|
||||
attributes p.PartAttributes
|
||||
)
|
||||
id, resultFieldSet, fieldValues, attributes, err = c.session.QueryDirect(query)
|
||||
if err != nil {
|
||||
goto done
|
||||
}
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
if id == 0 { // non select query
|
||||
rows = noResult
|
||||
} else {
|
||||
rows, err = newQueryResult(c.session, id, resultFieldSet, fieldValues, attributes)
|
||||
}
|
||||
done:
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
|
||||
//statement
|
||||
|
||||
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
|
||||
|
||||
if s.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
rows, err = s.defaultQuery(ctx, args)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stmt) defaultQuery(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
|
||||
sqltrace.Tracef("%s %v", s.query, args)
|
||||
|
||||
rid, values, attributes, err := s.session.Query(s.id, s.prmFieldSet, s.resultFieldSet, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
if rid == 0 { // non select query
|
||||
return noResult, nil
|
||||
}
|
||||
return newQueryResult(s.session, rid, s.resultFieldSet, values, attributes)
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
// +build go1.10
|
||||
|
||||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
// driver
|
||||
|
||||
// check if driver implements all required interfaces
|
||||
var _ driver.DriverContext = (*hdbDrv)(nil)
|
||||
|
||||
func (d *hdbDrv) OpenConnector(dsn string) (driver.Connector, error) {
|
||||
return NewDSNConnector(dsn)
|
||||
}
|
|
@ -0,0 +1,507 @@
|
|||
// +build !future
|
||||
|
||||
/*
|
||||
Copyright 2018 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"github.com/SAP/go-hdb/driver/sqltrace"
|
||||
|
||||
p "github.com/SAP/go-hdb/internal/protocol"
|
||||
)
|
||||
|
||||
var reBulk = regexp.MustCompile("(?i)^(\\s)*(bulk +)(.*)")
|
||||
|
||||
func checkBulkInsert(sql string) (string, bool) {
|
||||
if reBulk.MatchString(sql) {
|
||||
return reBulk.ReplaceAllString(sql, "${3}"), true
|
||||
}
|
||||
return sql, false
|
||||
}
|
||||
|
||||
var reCall = regexp.MustCompile("(?i)^(\\s)*(call +)(.*)")
|
||||
|
||||
func checkCallProcedure(sql string) bool {
|
||||
return reCall.MatchString(sql)
|
||||
}
|
||||
|
||||
// database connection
|
||||
|
||||
func (c *conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
|
||||
if c.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
prepareQuery, bulkInsert := checkBulkInsert(query)
|
||||
var (
|
||||
qt p.QueryType
|
||||
id uint64
|
||||
prmFieldSet *p.ParameterFieldSet
|
||||
resultFieldSet *p.ResultFieldSet
|
||||
)
|
||||
qt, id, prmFieldSet, resultFieldSet, err = c.session.Prepare(prepareQuery)
|
||||
if err != nil {
|
||||
goto done
|
||||
}
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
if bulkInsert {
|
||||
stmt, err = newBulkInsertStmt(c.session, prepareQuery, id, prmFieldSet)
|
||||
} else {
|
||||
stmt, err = newStmt(qt, c.session, prepareQuery, id, prmFieldSet, resultFieldSet)
|
||||
}
|
||||
done:
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return stmt, err
|
||||
}
|
||||
}
|
||||
|
||||
// QueryContext implements the database/sql/driver/QueryerContext interface.
|
||||
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
|
||||
if c.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
if len(args) != 0 {
|
||||
return nil, driver.ErrSkip //fast path not possible (prepare needed)
|
||||
}
|
||||
|
||||
// direct execution of call procedure
|
||||
// - returns no parameter metadata (sps 82) but only field values
|
||||
// --> let's take the 'prepare way' for stored procedures
|
||||
if checkCallProcedure(query) {
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
|
||||
sqltrace.Traceln(query)
|
||||
|
||||
id, idx, ok := decodeTableQuery(query)
|
||||
if ok {
|
||||
r := procedureCallResultStore.get(id)
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("invalid procedure table query %s", query)
|
||||
}
|
||||
return r.tableRows(int(idx))
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
var (
|
||||
id uint64
|
||||
resultFieldSet *p.ResultFieldSet
|
||||
fieldValues *p.FieldValues
|
||||
attributes p.PartAttributes
|
||||
)
|
||||
id, resultFieldSet, fieldValues, attributes, err = c.session.QueryDirect(query)
|
||||
if err != nil {
|
||||
goto done
|
||||
}
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
if id == 0 { // non select query
|
||||
rows = noResult
|
||||
} else {
|
||||
rows, err = newQueryResult(c.session, id, resultFieldSet, fieldValues, attributes)
|
||||
}
|
||||
done:
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
|
||||
//statement
|
||||
|
||||
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
|
||||
|
||||
if s.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
switch s.qt {
|
||||
default:
|
||||
rows, err = s.defaultQuery(ctx, args)
|
||||
case p.QtProcedureCall:
|
||||
rows, err = s.procedureCall(ctx, args)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stmt) defaultQuery(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
|
||||
sqltrace.Tracef("%s %v", s.query, args)
|
||||
|
||||
rid, values, attributes, err := s.session.Query(s.id, s.prmFieldSet, s.resultFieldSet, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
if rid == 0 { // non select query
|
||||
return noResult, nil
|
||||
}
|
||||
return newQueryResult(s.session, rid, s.resultFieldSet, values, attributes)
|
||||
}
|
||||
|
||||
func (s *stmt) procedureCall(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
|
||||
sqltrace.Tracef("%s %v", s.query, args)
|
||||
|
||||
fieldValues, tableResults, err := s.session.Call(s.id, s.prmFieldSet, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
return newProcedureCallResult(s.session, s.prmFieldSet, fieldValues, tableResults)
|
||||
}
|
||||
|
||||
// bulk insert statement
|
||||
|
||||
// check if bulkInsertStmt implements all required interfaces
|
||||
var (
|
||||
_ driver.Stmt = (*bulkInsertStmt)(nil)
|
||||
_ driver.StmtExecContext = (*bulkInsertStmt)(nil)
|
||||
_ driver.StmtQueryContext = (*bulkInsertStmt)(nil)
|
||||
_ driver.NamedValueChecker = (*bulkInsertStmt)(nil)
|
||||
)
|
||||
|
||||
type bulkInsertStmt struct {
|
||||
session *p.Session
|
||||
query string
|
||||
id uint64
|
||||
prmFieldSet *p.ParameterFieldSet
|
||||
numArg int
|
||||
args []driver.NamedValue
|
||||
}
|
||||
|
||||
func newBulkInsertStmt(session *p.Session, query string, id uint64, prmFieldSet *p.ParameterFieldSet) (*bulkInsertStmt, error) {
|
||||
return &bulkInsertStmt{session: session, query: query, id: id, prmFieldSet: prmFieldSet, args: make([]driver.NamedValue, 0)}, nil
|
||||
}
|
||||
|
||||
func (s *bulkInsertStmt) Close() error {
|
||||
return s.session.DropStatementID(s.id)
|
||||
}
|
||||
|
||||
func (s *bulkInsertStmt) NumInput() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func (s *bulkInsertStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
panic("deprecated")
|
||||
}
|
||||
|
||||
func (s *bulkInsertStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
|
||||
|
||||
if s.session.IsBad() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
sqltrace.Tracef("%s %v", s.query, args)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
if args == nil || len(args) == 0 {
|
||||
r, err = s.execFlush()
|
||||
} else {
|
||||
r, err = s.execBuffer(args)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return r, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkInsertStmt) execFlush() (driver.Result, error) {
|
||||
|
||||
if s.numArg == 0 {
|
||||
return driver.ResultNoRows, nil
|
||||
}
|
||||
|
||||
sqltrace.Traceln("execFlush")
|
||||
|
||||
result, err := s.session.Exec(s.id, s.prmFieldSet, s.args)
|
||||
s.args = s.args[:0]
|
||||
s.numArg = 0
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *bulkInsertStmt) execBuffer(args []driver.NamedValue) (driver.Result, error) {
|
||||
|
||||
numField := s.prmFieldSet.NumInputField()
|
||||
if len(args) != numField {
|
||||
return nil, fmt.Errorf("invalid number of arguments %d - %d expected", len(args), numField)
|
||||
}
|
||||
|
||||
var result driver.Result = driver.ResultNoRows
|
||||
var err error
|
||||
|
||||
if s.numArg == maxSmallint { // TODO: check why bigArgument count does not work
|
||||
result, err = s.execFlush()
|
||||
}
|
||||
|
||||
s.args = append(s.args, args...)
|
||||
s.numArg++
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *bulkInsertStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
panic("deprecated")
|
||||
}
|
||||
|
||||
func (s *bulkInsertStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
return nil, fmt.Errorf("query not allowed in context of bulk insert statement %s", s.query)
|
||||
}
|
||||
|
||||
// Deprecated: see NamedValueChecker.
|
||||
//func (s *bulkInsertStmt) ColumnConverter(idx int) driver.ValueConverter {
|
||||
//}
|
||||
|
||||
// CheckNamedValue implements NamedValueChecker interface.
|
||||
func (s *bulkInsertStmt) CheckNamedValue(nv *driver.NamedValue) error {
|
||||
return checkNamedValue(s.prmFieldSet, nv)
|
||||
}
|
||||
|
||||
//call result store
|
||||
type callResultStore struct {
|
||||
mu sync.RWMutex
|
||||
store map[uint64]*procedureCallResult
|
||||
cnt uint64
|
||||
free []uint64
|
||||
}
|
||||
|
||||
func (s *callResultStore) get(k uint64) *procedureCallResult {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if r, ok := s.store[k]; ok {
|
||||
return r
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *callResultStore) add(v *procedureCallResult) uint64 {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var k uint64
|
||||
|
||||
if s.free == nil || len(s.free) == 0 {
|
||||
s.cnt++
|
||||
k = s.cnt
|
||||
} else {
|
||||
size := len(s.free)
|
||||
k = s.free[size-1]
|
||||
s.free = s.free[:size-1]
|
||||
}
|
||||
|
||||
if s.store == nil {
|
||||
s.store = make(map[uint64]*procedureCallResult)
|
||||
}
|
||||
|
||||
s.store[k] = v
|
||||
|
||||
return k
|
||||
}
|
||||
|
||||
func (s *callResultStore) del(k uint64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.store, k)
|
||||
|
||||
if s.free == nil {
|
||||
s.free = []uint64{k}
|
||||
} else {
|
||||
s.free = append(s.free, k)
|
||||
}
|
||||
}
|
||||
|
||||
var procedureCallResultStore = new(callResultStore)
|
||||
|
||||
//procedure call result
|
||||
|
||||
// check if procedureCallResult implements all required interfaces
|
||||
var _ driver.Rows = (*procedureCallResult)(nil)
|
||||
|
||||
type procedureCallResult struct {
|
||||
id uint64
|
||||
session *p.Session
|
||||
prmFieldSet *p.ParameterFieldSet
|
||||
fieldValues *p.FieldValues
|
||||
_tableRows []driver.Rows
|
||||
columns []string
|
||||
eof error
|
||||
}
|
||||
|
||||
func newProcedureCallResult(session *p.Session, prmFieldSet *p.ParameterFieldSet, fieldValues *p.FieldValues, tableResults []*p.TableResult) (driver.Rows, error) {
|
||||
|
||||
fieldIdx := prmFieldSet.NumOutputField()
|
||||
columns := make([]string, fieldIdx+len(tableResults))
|
||||
|
||||
for i := 0; i < fieldIdx; i++ {
|
||||
columns[i] = prmFieldSet.OutputField(i).Name()
|
||||
}
|
||||
|
||||
tableRows := make([]driver.Rows, len(tableResults))
|
||||
for i, tableResult := range tableResults {
|
||||
var err error
|
||||
|
||||
if tableRows[i], err = newQueryResult(session, tableResult.ID(), tableResult.FieldSet(), tableResult.FieldValues(), tableResult.Attrs()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns[fieldIdx] = fmt.Sprintf("table %d", i)
|
||||
|
||||
fieldIdx++
|
||||
|
||||
}
|
||||
|
||||
result := &procedureCallResult{
|
||||
session: session,
|
||||
prmFieldSet: prmFieldSet,
|
||||
fieldValues: fieldValues,
|
||||
_tableRows: tableRows,
|
||||
columns: columns,
|
||||
}
|
||||
id := procedureCallResultStore.add(result)
|
||||
result.id = id
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *procedureCallResult) Columns() []string {
|
||||
return r.columns
|
||||
}
|
||||
|
||||
func (r *procedureCallResult) Close() error {
|
||||
procedureCallResultStore.del(r.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *procedureCallResult) Next(dest []driver.Value) error {
|
||||
if r.session.IsBad() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
if r.eof != nil {
|
||||
return r.eof
|
||||
}
|
||||
|
||||
if r.fieldValues.NumRow() == 0 && len(r._tableRows) == 0 {
|
||||
r.eof = io.EOF
|
||||
return r.eof
|
||||
}
|
||||
|
||||
if r.fieldValues.NumRow() != 0 {
|
||||
r.fieldValues.Row(0, dest)
|
||||
}
|
||||
|
||||
i := r.prmFieldSet.NumOutputField()
|
||||
for j := range r._tableRows {
|
||||
dest[i] = encodeTableQuery(r.id, uint64(j))
|
||||
i++
|
||||
}
|
||||
|
||||
r.eof = io.EOF
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *procedureCallResult) tableRows(idx int) (driver.Rows, error) {
|
||||
if idx >= len(r._tableRows) {
|
||||
return nil, fmt.Errorf("table row index %d exceeds maximun %d", idx, len(r._tableRows)-1)
|
||||
}
|
||||
return r._tableRows[idx], nil
|
||||
}
|
||||
|
||||
// helper
|
||||
const tableQueryPrefix = "@tq"
|
||||
|
||||
func encodeTableQuery(id, idx uint64) string {
|
||||
start := len(tableQueryPrefix)
|
||||
b := make([]byte, start+8+8)
|
||||
copy(b, tableQueryPrefix)
|
||||
binary.LittleEndian.PutUint64(b[start:start+8], id)
|
||||
binary.LittleEndian.PutUint64(b[start+8:start+8+8], idx)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func decodeTableQuery(query string) (uint64, uint64, bool) {
|
||||
size := len(query)
|
||||
start := len(tableQueryPrefix)
|
||||
if size != start+8+8 {
|
||||
return 0, 0, false
|
||||
}
|
||||
if query[:start] != tableQueryPrefix {
|
||||
return 0, 0, false
|
||||
}
|
||||
id := binary.LittleEndian.Uint64([]byte(query[start : start+8]))
|
||||
idx := binary.LittleEndian.Uint64([]byte(query[start+8 : start+8+8]))
|
||||
return id, idx, true
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
// DSN parameters. For parameter client locale see http://help.sap.com/hana/SAP_HANA_SQL_Command_Network_Protocol_Reference_en.pdf.
|
||||
const (
|
||||
DSNLocale = "locale" // Client locale as described in the protocol reference.
|
||||
DSNTimeout = "timeout" // Driver side connection timeout in seconds.
|
||||
DSNFetchSize = "fetchSize" // Maximum number of fetched records from database by database/sql/driver/Rows.Next().
|
||||
)
|
||||
|
||||
/*
|
||||
DSN TLS parameters.
|
||||
For more information please see https://golang.org/pkg/crypto/tls/#Config.
|
||||
For more flexibility in TLS configuration please see driver.Connector.
|
||||
*/
|
||||
const (
|
||||
DSNTLSRootCAFile = "TLSRootCAFile" // Path,- filename to root certificate(s).
|
||||
DSNTLSServerName = "TLSServerName" // ServerName to verify the hostname.
|
||||
DSNTLSInsecureSkipVerify = "TLSInsecureSkipVerify" // Controls whether a client verifies the server's certificate chain and host name.
|
||||
)
|
||||
|
||||
// DSN default values.
|
||||
const (
|
||||
DefaultTimeout = 300 // Default value connection timeout (300 seconds = 5 minutes).
|
||||
DefaultFetchSize = 128 // Default value fetchSize.
|
||||
)
|
||||
|
||||
// DSN minimal values.
|
||||
const (
|
||||
minTimeout = 0 // Minimal timeout value.
|
||||
minFetchSize = 1 // Minimal fetchSize value.
|
||||
)
|
||||
|
||||
/*
|
||||
DSN is here for the purposes of documentation only. A DSN string is an URL string with the following format
|
||||
|
||||
"hdb://<username>:<password>@<host address>:<port number>"
|
||||
|
||||
and optional query parameters (see DSN query parameters and DSN query default values).
|
||||
|
||||
Example:
|
||||
"hdb://myuser:mypassword@localhost:30015?timeout=60"
|
||||
|
||||
Examples TLS connection:
|
||||
"hdb://myuser:mypassword@localhost:39013?TLSRootCAFile=trust.pem"
|
||||
"hdb://myuser:mypassword@localhost:39013?TLSRootCAFile=trust.pem&TLSServerName=hostname"
|
||||
"hdb://myuser:mypassword@localhost:39013?TLSInsecureSkipVerify"
|
||||
*/
|
||||
type DSN string
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
// HDB error levels.
|
||||
const (
|
||||
HdbWarning = 0
|
||||
HdbError = 1
|
||||
HdbFatalError = 2
|
||||
)
|
||||
|
||||
// Error represents errors send by the database server.
|
||||
type Error interface {
|
||||
Error() string // Implements the golang error interface.
|
||||
NumError() int // NumError returns the number of errors.
|
||||
SetIdx(idx int) // Sets the error index in case number of errors are greater 1 in the range of 0 <= index < NumError().
|
||||
StmtNo() int // Returns the statement number of the error in multi statement contexts (e.g. bulk insert).
|
||||
Code() int // Code return the database error code.
|
||||
Position() int // Position returns the start position of erroneous sql statements sent to the database server.
|
||||
Level() int // Level return one of the database server predefined error levels.
|
||||
Text() string // Text return the error description sent from database server.
|
||||
IsWarning() bool // IsWarning returns true if the HDB error level equals 0.
|
||||
IsError() bool // IsError returns true if the HDB error level equals 1.
|
||||
IsFatal() bool // IsFatal returns true if the HDB error level equals 2.
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var reSimple = regexp.MustCompile("^[_A-Z][_#$A-Z0-9]*$")
|
||||
|
||||
// Identifier in hdb SQL statements like schema or table name.
|
||||
type Identifier string
|
||||
|
||||
// RandomIdentifier returns a random Identifier prefixed by the prefix parameter.
|
||||
// This function is used to generate database objects with random names for test and example code.
|
||||
func RandomIdentifier(prefix string) Identifier {
|
||||
b := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||
panic(err.Error()) // rand should never fail
|
||||
}
|
||||
return Identifier(fmt.Sprintf("%s%x", prefix, b))
|
||||
}
|
||||
|
||||
// String implements Stringer interface.
|
||||
func (i Identifier) String() string {
|
||||
s := string(i)
|
||||
if reSimple.MatchString(s) {
|
||||
return s
|
||||
}
|
||||
return strconv.Quote(s)
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// A Lob is the driver representation of a database large object field.
|
||||
// A Lob object uses an io.Reader object as source for writing content to a database lob field.
|
||||
// A Lob object uses an io.Writer object as destination for reading content from a database lob field.
|
||||
// A Lob can be created by contructor method NewLob with io.Reader and io.Writer as parameters or
|
||||
// created by new, setting io.Reader and io.Writer by SetReader and SetWriter methods.
|
||||
type Lob struct {
|
||||
rd io.Reader
|
||||
wr io.Writer
|
||||
}
|
||||
|
||||
// NewLob creates a new Lob instance with the io.Reader and io.Writer given as parameters.
|
||||
func NewLob(rd io.Reader, wr io.Writer) *Lob {
|
||||
return &Lob{rd: rd, wr: wr}
|
||||
}
|
||||
|
||||
// SetReader sets the io.Reader source for a lob field to be written to database
|
||||
// and return *Lob, to enable simple call chaining.
|
||||
func (l *Lob) SetReader(rd io.Reader) *Lob {
|
||||
l.rd = rd
|
||||
return l
|
||||
}
|
||||
|
||||
// SetWriter sets the io.Writer destination for a lob field to be read from database
|
||||
// and return *Lob, to enable simple call chaining.
|
||||
func (l *Lob) SetWriter(wr io.Writer) *Lob {
|
||||
l.wr = wr
|
||||
return l
|
||||
}
|
||||
|
||||
type writerSetter interface {
|
||||
SetWriter(w io.Writer) error
|
||||
}
|
||||
|
||||
// Scan implements the database/sql/Scanner interface.
|
||||
func (l *Lob) Scan(src interface{}) error {
|
||||
|
||||
if l.wr == nil {
|
||||
return fmt.Errorf("lob error: initial reader %[1]T %[1]v", l)
|
||||
}
|
||||
|
||||
ws, ok := src.(writerSetter)
|
||||
if !ok {
|
||||
return fmt.Errorf("lob: invalid scan type %T", src)
|
||||
}
|
||||
|
||||
if err := ws.SetWriter(l.wr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NullLob represents an Lob that may be null.
|
||||
// NullLob implements the Scanner interface so
|
||||
// it can be used as a scan destination, similar to NullString.
|
||||
type NullLob struct {
|
||||
Lob *Lob
|
||||
Valid bool // Valid is true if Lob is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the database/sql/Scanner interface.
|
||||
func (l *NullLob) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
l.Valid = false
|
||||
return nil
|
||||
}
|
||||
if err := l.Lob.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
l.Valid = true
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package sqltrace implements driver sql trace functions.
|
||||
package sqltrace
|
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqltrace
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type sqlTrace struct {
|
||||
mu sync.RWMutex //protects field on
|
||||
on bool
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
func newSQLTrace() *sqlTrace {
|
||||
return &sqlTrace{
|
||||
Logger: log.New(os.Stdout, "hdb ", log.Ldate|log.Ltime|log.Lshortfile),
|
||||
}
|
||||
}
|
||||
|
||||
var tracer = newSQLTrace()
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&tracer.on, "hdb.sqlTrace", false, "enabling hdb sql trace")
|
||||
}
|
||||
|
||||
// On returns if tracing methods output is active.
|
||||
func On() bool {
|
||||
tracer.mu.RLock()
|
||||
on := tracer.on
|
||||
tracer.mu.RUnlock()
|
||||
return on
|
||||
}
|
||||
|
||||
// SetOn sets tracing methods output active or inactive.
|
||||
func SetOn(on bool) {
|
||||
tracer.mu.Lock()
|
||||
tracer.on = on
|
||||
tracer.mu.Unlock()
|
||||
}
|
||||
|
||||
// Trace calls trace logger Print method to print to the trace logger.
|
||||
func Trace(v ...interface{}) {
|
||||
if On() {
|
||||
tracer.Print(v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Tracef calls trace logger Printf method to print to the trace logger.
|
||||
func Tracef(format string, v ...interface{}) {
|
||||
if On() {
|
||||
tracer.Printf(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Traceln calls trace logger Println method to print to the trace logger.
|
||||
func Traceln(v ...interface{}) {
|
||||
if On() {
|
||||
tracer.Println(v...)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NullTime represents an time.Time that may be null.
|
||||
// NullTime implements the Scanner interface so
|
||||
// it can be used as a scan destination, similar to NullString.
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (n *NullTime) Scan(value interface{}) error {
|
||||
n.Time, n.Valid = value.(time.Time)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n NullTime) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.Time, nil
|
||||
}
|
|
@ -0,0 +1,414 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package bufio implements buffered I/O for database read and writes on basis of the standard Go bufio package.
|
||||
package bufio
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// Reader is a bufio.Reader extended by methods needed for hdb protocol.
|
||||
type Reader struct {
|
||||
rd *bufio.Reader
|
||||
err error
|
||||
b [8]byte // scratch buffer (8 Bytes)
|
||||
tr transform.Transformer
|
||||
}
|
||||
|
||||
// NewReader creates a new Reader instance.
|
||||
func NewReader(r io.Reader) *Reader {
|
||||
return &Reader{
|
||||
rd: bufio.NewReader(r),
|
||||
tr: unicode.Cesu8ToUtf8Transformer,
|
||||
}
|
||||
}
|
||||
|
||||
// NewReaderSize creates a new Reader instance with given size for bufio.Reader.
|
||||
func NewReaderSize(r io.Reader, size int) *Reader {
|
||||
return &Reader{
|
||||
rd: bufio.NewReaderSize(r, size),
|
||||
tr: unicode.Cesu8ToUtf8Transformer,
|
||||
}
|
||||
}
|
||||
|
||||
// GetError returns reader error
|
||||
func (r *Reader) GetError() error {
|
||||
err := r.err
|
||||
r.err = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip skips cnt bytes from reading.
|
||||
func (r *Reader) Skip(cnt int) {
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
_, r.err = r.rd.Discard(cnt)
|
||||
}
|
||||
|
||||
// ReadB reads and returns a byte.
|
||||
func (r *Reader) ReadB() byte { // ReadB as sig differs from ReadByte (vet issues)
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
var b byte
|
||||
b, r.err = r.rd.ReadByte()
|
||||
return b
|
||||
}
|
||||
|
||||
// ReadFull implements io.ReadFull on Reader.
|
||||
func (r *Reader) ReadFull(p []byte) {
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
_, r.err = io.ReadFull(r.rd, p)
|
||||
}
|
||||
|
||||
// ReadBool reads and returns a boolean.
|
||||
func (r *Reader) ReadBool() bool {
|
||||
if r.err != nil {
|
||||
return false
|
||||
}
|
||||
return !(r.ReadB() == 0)
|
||||
}
|
||||
|
||||
// ReadInt8 reads and returns an int8.
|
||||
func (r *Reader) ReadInt8() int8 {
|
||||
return int8(r.ReadB())
|
||||
}
|
||||
|
||||
// ReadInt16 reads and returns an int16.
|
||||
func (r *Reader) ReadInt16() int16 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
if _, r.err = io.ReadFull(r.rd, r.b[:2]); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
return int16(binary.LittleEndian.Uint16(r.b[:2]))
|
||||
}
|
||||
|
||||
// ReadUint16 reads and returns an uint16.
|
||||
func (r *Reader) ReadUint16() uint16 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
if _, r.err = io.ReadFull(r.rd, r.b[:2]); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
return binary.LittleEndian.Uint16(r.b[:2])
|
||||
}
|
||||
|
||||
// ReadInt32 reads and returns an int32.
|
||||
func (r *Reader) ReadInt32() int32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
if _, r.err = io.ReadFull(r.rd, r.b[:4]); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
return int32(binary.LittleEndian.Uint32(r.b[:4]))
|
||||
}
|
||||
|
||||
// ReadUint32 reads and returns an uint32.
|
||||
func (r *Reader) ReadUint32() uint32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
if _, r.err = io.ReadFull(r.rd, r.b[:4]); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
return binary.LittleEndian.Uint32(r.b[:4])
|
||||
}
|
||||
|
||||
// ReadInt64 reads and returns an int64.
|
||||
func (r *Reader) ReadInt64() int64 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
if _, r.err = io.ReadFull(r.rd, r.b[:8]); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
return int64(binary.LittleEndian.Uint64(r.b[:8]))
|
||||
}
|
||||
|
||||
// ReadUint64 reads and returns an uint64.
|
||||
func (r *Reader) ReadUint64() uint64 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
if _, r.err = io.ReadFull(r.rd, r.b[:8]); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
return binary.LittleEndian.Uint64(r.b[:8])
|
||||
}
|
||||
|
||||
// ReadFloat32 reads and returns a float32.
|
||||
func (r *Reader) ReadFloat32() float32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
if _, r.err = io.ReadFull(r.rd, r.b[:4]); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
bits := binary.LittleEndian.Uint32(r.b[:4])
|
||||
return math.Float32frombits(bits)
|
||||
}
|
||||
|
||||
// ReadFloat64 reads and returns a float64.
|
||||
func (r *Reader) ReadFloat64() float64 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
if _, r.err = io.ReadFull(r.rd, r.b[:8]); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
bits := binary.LittleEndian.Uint64(r.b[:8])
|
||||
return math.Float64frombits(bits)
|
||||
}
|
||||
|
||||
// ReadCesu8 reads a size CESU-8 encoded byte sequence and returns an UTF-8 byte slice.
|
||||
func (r *Reader) ReadCesu8(size int) []byte {
|
||||
if r.err != nil {
|
||||
return nil
|
||||
}
|
||||
p := make([]byte, size)
|
||||
if _, r.err = io.ReadFull(r.rd, p); r.err != nil {
|
||||
return nil
|
||||
}
|
||||
r.tr.Reset()
|
||||
var n int
|
||||
if n, _, r.err = r.tr.Transform(p, p, true); r.err != nil { // inplace transformation
|
||||
return nil
|
||||
}
|
||||
return p[:n]
|
||||
}
|
||||
|
||||
const writerBufferSize = 4096
|
||||
|
||||
// Writer is a bufio.Writer extended by methods needed for hdb protocol.
|
||||
type Writer struct {
|
||||
wr *bufio.Writer
|
||||
err error
|
||||
b []byte // scratch buffer (min 8 Bytes)
|
||||
tr transform.Transformer
|
||||
}
|
||||
|
||||
// NewWriter creates a new Writer instance.
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
return &Writer{
|
||||
wr: bufio.NewWriter(w),
|
||||
b: make([]byte, writerBufferSize),
|
||||
tr: unicode.Utf8ToCesu8Transformer,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWriterSize creates a new Writer instance with given size for bufio.Writer.
|
||||
func NewWriterSize(w io.Writer, size int) *Writer {
|
||||
return &Writer{
|
||||
wr: bufio.NewWriterSize(w, size),
|
||||
b: make([]byte, writerBufferSize),
|
||||
tr: unicode.Utf8ToCesu8Transformer,
|
||||
}
|
||||
}
|
||||
|
||||
// Flush writes any buffered data to the underlying io.Writer.
|
||||
func (w *Writer) Flush() error {
|
||||
if w.err != nil {
|
||||
return w.err
|
||||
}
|
||||
return w.wr.Flush()
|
||||
}
|
||||
|
||||
// WriteZeroes writes cnt zero byte values.
|
||||
func (w *Writer) WriteZeroes(cnt int) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// zero out scratch area
|
||||
l := cnt
|
||||
if l > len(w.b) {
|
||||
l = len(w.b)
|
||||
}
|
||||
for i := 0; i < l; i++ {
|
||||
w.b[i] = 0
|
||||
}
|
||||
|
||||
for i := 0; i < cnt; {
|
||||
j := cnt - i
|
||||
if j > len(w.b) {
|
||||
j = len(w.b)
|
||||
}
|
||||
n, _ := w.wr.Write(w.b[:j])
|
||||
i += n
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes the contents of p.
|
||||
func (w *Writer) Write(p []byte) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
w.wr.Write(p)
|
||||
}
|
||||
|
||||
// WriteB writes a byte.
|
||||
func (w *Writer) WriteB(b byte) { // WriteB as sig differs from WriteByte (vet issues)
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
w.wr.WriteByte(b)
|
||||
}
|
||||
|
||||
// WriteBool writes a boolean.
|
||||
func (w *Writer) WriteBool(v bool) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
if v {
|
||||
w.wr.WriteByte(1)
|
||||
} else {
|
||||
w.wr.WriteByte(0)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteInt8 writes an int8.
|
||||
func (w *Writer) WriteInt8(i int8) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
w.wr.WriteByte(byte(i))
|
||||
}
|
||||
|
||||
// WriteInt16 writes an int16.
|
||||
func (w *Writer) WriteInt16(i int16) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
binary.LittleEndian.PutUint16(w.b[:2], uint16(i))
|
||||
w.wr.Write(w.b[:2])
|
||||
}
|
||||
|
||||
// WriteUint16 writes an uint16.
|
||||
func (w *Writer) WriteUint16(i uint16) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
binary.LittleEndian.PutUint16(w.b[:2], i)
|
||||
w.wr.Write(w.b[:2])
|
||||
}
|
||||
|
||||
// WriteInt32 writes an int32.
|
||||
func (w *Writer) WriteInt32(i int32) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
binary.LittleEndian.PutUint32(w.b[:4], uint32(i))
|
||||
w.wr.Write(w.b[:4])
|
||||
}
|
||||
|
||||
// WriteUint32 writes an uint32.
|
||||
func (w *Writer) WriteUint32(i uint32) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
binary.LittleEndian.PutUint32(w.b[:4], i)
|
||||
w.wr.Write(w.b[:4])
|
||||
}
|
||||
|
||||
// WriteInt64 writes an int64.
|
||||
func (w *Writer) WriteInt64(i int64) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
binary.LittleEndian.PutUint64(w.b[:8], uint64(i))
|
||||
w.wr.Write(w.b[:8])
|
||||
}
|
||||
|
||||
// WriteUint64 writes an uint64.
|
||||
func (w *Writer) WriteUint64(i uint64) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
binary.LittleEndian.PutUint64(w.b[:8], i)
|
||||
w.wr.Write(w.b[:8])
|
||||
}
|
||||
|
||||
// WriteFloat32 writes a float32.
|
||||
func (w *Writer) WriteFloat32(f float32) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
bits := math.Float32bits(f)
|
||||
binary.LittleEndian.PutUint32(w.b[:4], bits)
|
||||
w.wr.Write(w.b[:4])
|
||||
}
|
||||
|
||||
// WriteFloat64 writes a float64.
|
||||
func (w *Writer) WriteFloat64(f float64) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
bits := math.Float64bits(f)
|
||||
binary.LittleEndian.PutUint64(w.b[:8], bits)
|
||||
w.wr.Write(w.b[:8])
|
||||
}
|
||||
|
||||
// WriteString writes a string.
|
||||
func (w *Writer) WriteString(s string) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
w.wr.WriteString(s)
|
||||
}
|
||||
|
||||
// WriteCesu8 writes an UTF-8 byte slice as CESU-8 and returns the CESU-8 bytes written.
|
||||
func (w *Writer) WriteCesu8(p []byte) int {
|
||||
if w.err != nil {
|
||||
return 0
|
||||
}
|
||||
w.tr.Reset()
|
||||
cnt := 0
|
||||
i := 0
|
||||
for i < len(p) {
|
||||
m, n, err := w.tr.Transform(w.b, p[i:], true)
|
||||
if err != nil && err != transform.ErrShortDst {
|
||||
w.err = err
|
||||
return cnt
|
||||
}
|
||||
if m == 0 {
|
||||
w.err = transform.ErrShortDst
|
||||
return cnt
|
||||
}
|
||||
o, _ := w.wr.Write(w.b[:m])
|
||||
cnt += o
|
||||
i += n
|
||||
}
|
||||
return cnt
|
||||
}
|
||||
|
||||
// WriteStringCesu8 is like WriteCesu8 with an UTF-8 string as parameter.
|
||||
func (w *Writer) WriteStringCesu8(s string) int {
|
||||
return w.WriteCesu8([]byte(s))
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
type clientID []byte
|
||||
|
||||
func newClientID() clientID {
|
||||
if h, err := os.Hostname(); err == nil {
|
||||
return clientID(strings.Join([]string{strconv.Itoa(os.Getpid()), h}, "@"))
|
||||
}
|
||||
return clientID(strconv.Itoa(os.Getpid()))
|
||||
}
|
||||
|
||||
func (id clientID) kind() partKind {
|
||||
return partKind(pkClientID)
|
||||
}
|
||||
|
||||
func (id clientID) size() (int, error) {
|
||||
return len(id), nil
|
||||
}
|
||||
|
||||
func (id clientID) numArg() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (id clientID) write(wr *bufio.Writer) error {
|
||||
wr.Write(id)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("client id: %s", id)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
"github.com/SAP/go-hdb/internal/unicode/cesu8"
|
||||
)
|
||||
|
||||
// cesu8 command
|
||||
type command []byte
|
||||
|
||||
func (c command) kind() partKind {
|
||||
return pkCommand
|
||||
}
|
||||
|
||||
func (c command) size() (int, error) {
|
||||
return cesu8.Size(c), nil
|
||||
}
|
||||
|
||||
func (c command) numArg() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (c command) write(wr *bufio.Writer) error {
|
||||
wr.WriteCesu8(c)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("command: %s", c)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=connectOption
|
||||
|
||||
type connectOption int8
|
||||
|
||||
const (
|
||||
coConnectionID connectOption = 1
|
||||
coCompleteArrayExecution connectOption = 2
|
||||
coClientLocale connectOption = 3
|
||||
coSupportsLargeBulkOperations connectOption = 4
|
||||
// duplicate in docu: coDataFormatVersion2 connectOption = 5
|
||||
// 6-9 reserved: do not use
|
||||
coLargeNumberOfParameterSupport connectOption = 10
|
||||
coSystemID connectOption = 11
|
||||
// 12 reserved: do not use
|
||||
coAbapVarcharMode connectOption = 13
|
||||
coSelectForUpdateSupported connectOption = 14
|
||||
coClientDistributionMode connectOption = 15
|
||||
coEngineDataFormatVersion connectOption = 16
|
||||
coDistributionProtocolVersion connectOption = 17
|
||||
coSplitBatchCommands connectOption = 18
|
||||
coUseTransactionFlagsOnly connectOption = 19
|
||||
//coRowAndColumnOptimizedFormat connectOption = 20 reserved: do not use
|
||||
coIgnoreUnknownParts connectOption = 21
|
||||
coTableOutputParameter connectOption = 22
|
||||
coDataFormatVersion2 connectOption = 23
|
||||
coItabParameter connectOption = 24
|
||||
coDescribeTableOutputParameter connectOption = 25
|
||||
coColumnarResultset connectOption = 26
|
||||
coScrollablResultSet connectOption = 27
|
||||
coClientInfoNullValueSupported connectOption = 28
|
||||
coAssociatedConnectionID connectOption = 29
|
||||
coNoTransactionalPrepare connectOption = 30
|
||||
coFDAEnabled connectOption = 31
|
||||
coOSUser connectOption = 32
|
||||
coRowslotImageResult connectOption = 33
|
||||
coEndianess connectOption = 34
|
||||
// 35, 36 reserved: do not use
|
||||
coImplicitLobStreaming connectOption = 37
|
||||
)
|
41
vendor/github.com/SAP/go-hdb/internal/protocol/connectoption_string.go
generated
vendored
Normal file
41
vendor/github.com/SAP/go-hdb/internal/protocol/connectoption_string.go
generated
vendored
Normal file
|
@ -0,0 +1,41 @@
|
|||
// Code generated by "stringer -type=connectOption"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
_connectOption_name_0 = "coConnectionIDcoCompleteArrayExecutioncoClientLocalecoSupportsLargeBulkOperations"
|
||||
_connectOption_name_1 = "coLargeNumberOfParameterSupportcoSystemID"
|
||||
_connectOption_name_2 = "coAbapVarcharModecoSelectForUpdateSupportedcoClientDistributionModecoEngineDataFormatVersioncoDistributionProtocolVersioncoSplitBatchCommandscoUseTransactionFlagsOnly"
|
||||
_connectOption_name_3 = "coIgnoreUnknownPartscoTableOutputParametercoDataFormatVersion2coItabParametercoDescribeTableOutputParametercoColumnarResultsetcoScrollablResultSetcoClientInfoNullValueSupportedcoAssociatedConnectionIDcoNoTransactionalPreparecoFDAEnabledcoOSUsercoRowslotImageResultcoEndianess"
|
||||
_connectOption_name_4 = "coImplicitLobStreaming"
|
||||
)
|
||||
|
||||
var (
|
||||
_connectOption_index_0 = [...]uint8{0, 14, 38, 52, 81}
|
||||
_connectOption_index_1 = [...]uint8{0, 31, 41}
|
||||
_connectOption_index_2 = [...]uint8{0, 17, 43, 67, 92, 121, 141, 166}
|
||||
_connectOption_index_3 = [...]uint16{0, 20, 42, 62, 77, 107, 126, 146, 176, 200, 224, 236, 244, 264, 275}
|
||||
)
|
||||
|
||||
func (i connectOption) String() string {
|
||||
switch {
|
||||
case 1 <= i && i <= 4:
|
||||
i -= 1
|
||||
return _connectOption_name_0[_connectOption_index_0[i]:_connectOption_index_0[i+1]]
|
||||
case 10 <= i && i <= 11:
|
||||
i -= 10
|
||||
return _connectOption_name_1[_connectOption_index_1[i]:_connectOption_index_1[i+1]]
|
||||
case 13 <= i && i <= 19:
|
||||
i -= 13
|
||||
return _connectOption_name_2[_connectOption_index_2[i]:_connectOption_index_2[i+1]]
|
||||
case 21 <= i && i <= 34:
|
||||
i -= 21
|
||||
return _connectOption_name_3[_connectOption_index_3[i]:_connectOption_index_3[i+1]]
|
||||
case i == 37:
|
||||
return _connectOption_name_4
|
||||
default:
|
||||
return "connectOption(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
// data format version
|
||||
const (
|
||||
dfvBaseline intType = 1
|
||||
dfvDoNotUse intType = 3
|
||||
dfvSPS06 intType = 4 //see docu
|
||||
dfvBINTEXT intType = 6
|
||||
)
|
||||
|
||||
// client distribution mode
|
||||
const (
|
||||
cdmOff intType = 0
|
||||
cdmConnection = 1
|
||||
cdmStatement = 2
|
||||
cdmConnectionStatement = 3
|
||||
)
|
||||
|
||||
// distribution protocol version
|
||||
const (
|
||||
dpvBaseline = 0
|
||||
dpvClientHandlesStatementSequence = 1
|
||||
)
|
||||
|
||||
type connectOptions struct {
|
||||
po plainOptions
|
||||
_numArg int
|
||||
}
|
||||
|
||||
func newConnectOptions() *connectOptions {
|
||||
return &connectOptions{
|
||||
po: plainOptions{},
|
||||
}
|
||||
}
|
||||
|
||||
func (o *connectOptions) String() string {
|
||||
m := make(map[connectOption]interface{})
|
||||
for k, v := range o.po {
|
||||
m[connectOption(k)] = v
|
||||
}
|
||||
return fmt.Sprintf("%s", m)
|
||||
}
|
||||
|
||||
func (o *connectOptions) kind() partKind {
|
||||
return pkConnectOptions
|
||||
}
|
||||
|
||||
func (o *connectOptions) size() (int, error) {
|
||||
return o.po.size(), nil
|
||||
}
|
||||
|
||||
func (o *connectOptions) numArg() int {
|
||||
return len(o.po)
|
||||
}
|
||||
|
||||
func (o *connectOptions) setNumArg(numArg int) {
|
||||
o._numArg = numArg
|
||||
}
|
||||
|
||||
func (o *connectOptions) set(k connectOption, v interface{}) {
|
||||
o.po[int8(k)] = v
|
||||
}
|
||||
|
||||
func (o *connectOptions) get(k connectOption) (interface{}, bool) {
|
||||
v, ok := o.po[int8(k)]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (o *connectOptions) read(rd *bufio.Reader) error {
|
||||
o.po.read(rd, o._numArg)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("connect options: %v", o)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
func (o *connectOptions) write(wr *bufio.Writer) error {
|
||||
o.po.write(wr)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("connect options: %v", o)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=DataType
|
||||
|
||||
// DataType is the type definition for data types supported by this package.
|
||||
type DataType byte
|
||||
|
||||
// Data type constants.
|
||||
const (
|
||||
DtUnknown DataType = iota // unknown data type
|
||||
DtTinyint
|
||||
DtSmallint
|
||||
DtInteger
|
||||
DtBigint
|
||||
DtReal
|
||||
DtDouble
|
||||
DtDecimal
|
||||
DtTime
|
||||
DtString
|
||||
DtBytes
|
||||
DtLob
|
||||
)
|
|
@ -0,0 +1,16 @@
|
|||
// Code generated by "stringer -type=DataType"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _DataType_name = "DtUnknownDtTinyintDtSmallintDtIntegerDtBigintDtRealDtDoubleDtDecimalDtTimeDtStringDtBytesDtLob"
|
||||
|
||||
var _DataType_index = [...]uint8{0, 9, 18, 28, 37, 45, 51, 59, 68, 74, 82, 89, 94}
|
||||
|
||||
func (i DataType) String() string {
|
||||
if i >= DataType(len(_DataType_index)-1) {
|
||||
return "DataType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _DataType_name[_DataType_index[i]:_DataType_index[i+1]]
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package protocol implements the hdb command network protocol.
|
||||
//
|
||||
// http://help.sap.com/hana/SAP_HANA_SQL_Command_Network_Protocol_Reference_en.pdf
|
||||
package protocol
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=endianess
|
||||
|
||||
type endianess int8
|
||||
|
||||
const (
|
||||
bigEndian endianess = 0
|
||||
littleEndian endianess = 1
|
||||
)
|
16
vendor/github.com/SAP/go-hdb/internal/protocol/endianess_string.go
generated
vendored
Normal file
16
vendor/github.com/SAP/go-hdb/internal/protocol/endianess_string.go
generated
vendored
Normal file
|
@ -0,0 +1,16 @@
|
|||
// Code generated by "stringer -type=endianess"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _endianess_name = "bigEndianlittleEndian"
|
||||
|
||||
var _endianess_index = [...]uint8{0, 9, 21}
|
||||
|
||||
func (i endianess) String() string {
|
||||
if i < 0 || i >= endianess(len(_endianess_index)-1) {
|
||||
return "endianess(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _endianess_name[_endianess_index[i]:_endianess_index[i+1]]
|
||||
}
|
|
@ -0,0 +1,204 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
const (
|
||||
sqlStateSize = 5
|
||||
//bytes of fix length fields mod 8
|
||||
// - errorCode = 4, errorPosition = 4, errortextLength = 4, errorLevel = 1, sqlState = 5 => 18 bytes
|
||||
// - 18 mod 8 = 2
|
||||
fixLength = 2
|
||||
)
|
||||
|
||||
type sqlState [sqlStateSize]byte
|
||||
|
||||
type hdbError struct {
|
||||
errorCode int32
|
||||
errorPosition int32
|
||||
errorTextLength int32
|
||||
errorLevel errorLevel
|
||||
sqlState sqlState
|
||||
stmtNo int
|
||||
errorText []byte
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (e *hdbError) String() string {
|
||||
return fmt.Sprintf("errorCode %d, errorPosition %d, errorTextLength % d errorLevel %s, sqlState %s stmtNo %d errorText %s",
|
||||
e.errorCode,
|
||||
e.errorPosition,
|
||||
e.errorTextLength,
|
||||
e.errorLevel,
|
||||
e.sqlState,
|
||||
e.stmtNo,
|
||||
e.errorText,
|
||||
)
|
||||
}
|
||||
|
||||
// Error implements the Error interface.
|
||||
func (e *hdbError) Error() string {
|
||||
if e.stmtNo != -1 {
|
||||
return fmt.Sprintf("SQL %s %d - %s (statement no: %d)", e.errorLevel, e.errorCode, e.errorText, e.stmtNo)
|
||||
}
|
||||
return fmt.Sprintf("SQL %s %d - %s", e.errorLevel, e.errorCode, e.errorText)
|
||||
}
|
||||
|
||||
type hdbErrors struct {
|
||||
errors []*hdbError
|
||||
numArg int
|
||||
idx int
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (e *hdbErrors) String() string {
|
||||
return e.errors[e.idx].String()
|
||||
}
|
||||
|
||||
// Error implements the golang error interface.
|
||||
func (e *hdbErrors) Error() string {
|
||||
return e.errors[e.idx].Error()
|
||||
}
|
||||
|
||||
// NumError implements the driver.Error interface.
|
||||
func (e *hdbErrors) NumError() int {
|
||||
return e.numArg
|
||||
}
|
||||
|
||||
// SetIdx implements the driver.Error interface.
|
||||
func (e *hdbErrors) SetIdx(idx int) {
|
||||
switch {
|
||||
case idx < 0:
|
||||
e.idx = 0
|
||||
case idx >= e.numArg:
|
||||
e.idx = e.numArg - 1
|
||||
default:
|
||||
e.idx = idx
|
||||
}
|
||||
}
|
||||
|
||||
// StmtNo implements the driver.Error interface.
|
||||
func (e *hdbErrors) StmtNo() int {
|
||||
return e.errors[e.idx].stmtNo
|
||||
}
|
||||
|
||||
// Code implements the driver.Error interface.
|
||||
func (e *hdbErrors) Code() int {
|
||||
return int(e.errors[e.idx].errorCode)
|
||||
}
|
||||
|
||||
// Position implements the driver.Error interface.
|
||||
func (e *hdbErrors) Position() int {
|
||||
return int(e.errors[e.idx].errorPosition)
|
||||
}
|
||||
|
||||
// Level implements the driver.Error interface.
|
||||
func (e *hdbErrors) Level() int {
|
||||
return int(e.errors[e.idx].errorLevel)
|
||||
}
|
||||
|
||||
// Text implements the driver.Error interface.
|
||||
func (e *hdbErrors) Text() string {
|
||||
return string(e.errors[e.idx].errorText)
|
||||
}
|
||||
|
||||
// IsWarning implements the driver.Error interface.
|
||||
func (e *hdbErrors) IsWarning() bool {
|
||||
return e.errors[e.idx].errorLevel == errorLevelWarning
|
||||
}
|
||||
|
||||
// IsError implements the driver.Error interface.
|
||||
func (e *hdbErrors) IsError() bool {
|
||||
return e.errors[e.idx].errorLevel == errorLevelError
|
||||
}
|
||||
|
||||
// IsFatal implements the driver.Error interface.
|
||||
func (e *hdbErrors) IsFatal() bool {
|
||||
return e.errors[e.idx].errorLevel == errorLevelFatalError
|
||||
}
|
||||
|
||||
func (e *hdbErrors) setStmtNo(idx, no int) {
|
||||
if idx >= 0 && idx < e.numArg {
|
||||
e.errors[idx].stmtNo = no
|
||||
}
|
||||
}
|
||||
|
||||
func (e *hdbErrors) isWarnings() bool {
|
||||
for _, _error := range e.errors {
|
||||
if _error.errorLevel != errorLevelWarning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *hdbErrors) kind() partKind {
|
||||
return pkError
|
||||
}
|
||||
|
||||
func (e *hdbErrors) setNumArg(numArg int) {
|
||||
e.numArg = numArg
|
||||
}
|
||||
|
||||
func (e *hdbErrors) read(rd *bufio.Reader) error {
|
||||
e.idx = 0 // init error index
|
||||
|
||||
if e.errors == nil || e.numArg > cap(e.errors) {
|
||||
e.errors = make([]*hdbError, e.numArg)
|
||||
} else {
|
||||
e.errors = e.errors[:e.numArg]
|
||||
}
|
||||
|
||||
for i := 0; i < e.numArg; i++ {
|
||||
_error := e.errors[i]
|
||||
if _error == nil {
|
||||
_error = new(hdbError)
|
||||
e.errors[i] = _error
|
||||
}
|
||||
|
||||
_error.stmtNo = -1
|
||||
_error.errorCode = rd.ReadInt32()
|
||||
_error.errorPosition = rd.ReadInt32()
|
||||
_error.errorTextLength = rd.ReadInt32()
|
||||
_error.errorLevel = errorLevel(rd.ReadInt8())
|
||||
rd.ReadFull(_error.sqlState[:])
|
||||
|
||||
// read error text as ASCII data as some errors return invalid CESU-8 characters
|
||||
// e.g: SQL HdbError 7 - feature not supported: invalid character encoding: <invaid CESU-8 characters>
|
||||
// if e.errorText, err = rd.ReadCesu8(int(e.errorTextLength)); err != nil {
|
||||
// return err
|
||||
// }
|
||||
_error.errorText = make([]byte, int(_error.errorTextLength))
|
||||
rd.ReadFull(_error.errorText)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("error %d: %s", i, _error)
|
||||
}
|
||||
|
||||
pad := padBytes(int(fixLength + _error.errorTextLength))
|
||||
if pad != 0 {
|
||||
rd.Skip(pad)
|
||||
}
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
// ErrorLevel send from database server.
|
||||
type errorLevel int8
|
||||
|
||||
func (e errorLevel) String() string {
|
||||
switch e {
|
||||
case 0:
|
||||
return "Warning"
|
||||
case 1:
|
||||
return "Error"
|
||||
case 2:
|
||||
return "Fatal Error"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// HDB error level constants.
|
||||
const (
|
||||
errorLevelWarning errorLevel = 0
|
||||
errorLevelError errorLevel = 1
|
||||
errorLevelFatalError errorLevel = 2
|
||||
)
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
//fetch size
|
||||
type fetchsize int32
|
||||
|
||||
func (s fetchsize) kind() partKind {
|
||||
return pkFetchSize
|
||||
}
|
||||
|
||||
func (s fetchsize) size() (int, error) {
|
||||
return 4, nil
|
||||
}
|
||||
|
||||
func (s fetchsize) numArg() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s fetchsize) write(wr *bufio.Writer) error {
|
||||
wr.WriteInt32(int32(s))
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("fetchsize: %d", s)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,774 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
"github.com/SAP/go-hdb/internal/unicode/cesu8"
|
||||
)
|
||||
|
||||
var test uint32
|
||||
|
||||
const (
|
||||
realNullValue uint32 = ^uint32(0)
|
||||
doubleNullValue uint64 = ^uint64(0)
|
||||
)
|
||||
|
||||
const noFieldName uint32 = 0xFFFFFFFF
|
||||
|
||||
type uint32Slice []uint32
|
||||
|
||||
func (p uint32Slice) Len() int { return len(p) }
|
||||
func (p uint32Slice) Less(i, j int) bool { return p[i] < p[j] }
|
||||
func (p uint32Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
func (p uint32Slice) sort() { sort.Sort(p) }
|
||||
|
||||
type fieldNames map[uint32]string
|
||||
|
||||
func newFieldNames() fieldNames {
|
||||
return make(map[uint32]string)
|
||||
}
|
||||
|
||||
func (f fieldNames) addOffset(offset uint32) {
|
||||
if offset != noFieldName {
|
||||
f[offset] = ""
|
||||
}
|
||||
}
|
||||
|
||||
func (f fieldNames) name(offset uint32) string {
|
||||
if name, ok := f[offset]; ok {
|
||||
return name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (f fieldNames) setName(offset uint32, name string) {
|
||||
f[offset] = name
|
||||
}
|
||||
|
||||
func (f fieldNames) sortOffsets() []uint32 {
|
||||
offsets := make([]uint32, 0, len(f))
|
||||
for k := range f {
|
||||
offsets = append(offsets, k)
|
||||
}
|
||||
uint32Slice(offsets).sort()
|
||||
return offsets
|
||||
}
|
||||
|
||||
// FieldValues contains rows read from database.
|
||||
type FieldValues struct {
|
||||
rows int
|
||||
cols int
|
||||
values []driver.Value
|
||||
}
|
||||
|
||||
func newFieldValues() *FieldValues {
|
||||
return &FieldValues{}
|
||||
}
|
||||
|
||||
func (f *FieldValues) String() string {
|
||||
return fmt.Sprintf("rows %d columns %d", f.rows, f.cols)
|
||||
}
|
||||
|
||||
func (f *FieldValues) resize(rows, cols int) {
|
||||
f.rows, f.cols = rows, cols
|
||||
f.values = make([]driver.Value, rows*cols)
|
||||
}
|
||||
|
||||
// NumRow returns the number of rows available in FieldValues.
|
||||
func (f *FieldValues) NumRow() int {
|
||||
return f.rows
|
||||
}
|
||||
|
||||
// Row fills the dest value slice with row data at index idx.
|
||||
func (f *FieldValues) Row(idx int, dest []driver.Value) {
|
||||
copy(dest, f.values[idx*f.cols:(idx+1)*f.cols])
|
||||
}
|
||||
|
||||
const (
|
||||
tinyintFieldSize = 1
|
||||
smallintFieldSize = 2
|
||||
intFieldSize = 4
|
||||
bigintFieldSize = 8
|
||||
realFieldSize = 4
|
||||
doubleFieldSize = 8
|
||||
dateFieldSize = 4
|
||||
timeFieldSize = 4
|
||||
timestampFieldSize = dateFieldSize + timeFieldSize
|
||||
longdateFieldSize = 8
|
||||
seconddateFieldSize = 8
|
||||
daydateFieldSize = 4
|
||||
secondtimeFieldSize = 4
|
||||
decimalFieldSize = 16
|
||||
lobInputDescriptorSize = 9
|
||||
)
|
||||
|
||||
func fieldSize(tc TypeCode, arg driver.NamedValue) (int, error) {
|
||||
v := arg.Value
|
||||
|
||||
if v == nil { //HDB bug: secondtime null value --> see writeField
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
switch tc {
|
||||
case tcTinyint:
|
||||
return tinyintFieldSize, nil
|
||||
case tcSmallint:
|
||||
return smallintFieldSize, nil
|
||||
case tcInteger:
|
||||
return intFieldSize, nil
|
||||
case tcBigint:
|
||||
return bigintFieldSize, nil
|
||||
case tcReal:
|
||||
return realFieldSize, nil
|
||||
case tcDouble:
|
||||
return doubleFieldSize, nil
|
||||
case tcDate:
|
||||
return dateFieldSize, nil
|
||||
case tcTime:
|
||||
return timeFieldSize, nil
|
||||
case tcTimestamp:
|
||||
return timestampFieldSize, nil
|
||||
case tcLongdate:
|
||||
return longdateFieldSize, nil
|
||||
case tcSeconddate:
|
||||
return seconddateFieldSize, nil
|
||||
case tcDaydate:
|
||||
return daydateFieldSize, nil
|
||||
case tcSecondtime:
|
||||
return secondtimeFieldSize, nil
|
||||
case tcDecimal:
|
||||
return decimalFieldSize, nil
|
||||
case tcChar, tcVarchar, tcString:
|
||||
switch v := v.(type) {
|
||||
case []byte:
|
||||
return bytesSize(len(v))
|
||||
case string:
|
||||
return bytesSize(len(v))
|
||||
default:
|
||||
outLogger.Fatalf("data type %s mismatch %T", tc, v)
|
||||
}
|
||||
case tcNchar, tcNvarchar, tcNstring:
|
||||
switch v := v.(type) {
|
||||
case []byte:
|
||||
return bytesSize(cesu8.Size(v))
|
||||
case string:
|
||||
return bytesSize(cesu8.StringSize(v))
|
||||
default:
|
||||
outLogger.Fatalf("data type %s mismatch %T", tc, v)
|
||||
}
|
||||
case tcBinary, tcVarbinary:
|
||||
v, ok := v.([]byte)
|
||||
if !ok {
|
||||
outLogger.Fatalf("data type %s mismatch %T", tc, v)
|
||||
}
|
||||
return bytesSize(len(v))
|
||||
case tcBlob, tcClob, tcNclob:
|
||||
return lobInputDescriptorSize, nil
|
||||
}
|
||||
outLogger.Fatalf("data type %s not implemented", tc)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func readField(session *Session, rd *bufio.Reader, tc TypeCode) (interface{}, error) {
|
||||
|
||||
switch tc {
|
||||
|
||||
case tcTinyint, tcSmallint, tcInteger, tcBigint:
|
||||
|
||||
if !rd.ReadBool() { //null value
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch tc {
|
||||
case tcTinyint:
|
||||
return int64(rd.ReadB()), nil
|
||||
case tcSmallint:
|
||||
return int64(rd.ReadInt16()), nil
|
||||
case tcInteger:
|
||||
return int64(rd.ReadInt32()), nil
|
||||
case tcBigint:
|
||||
return rd.ReadInt64(), nil
|
||||
}
|
||||
|
||||
case tcReal:
|
||||
v := rd.ReadUint32()
|
||||
if v == realNullValue {
|
||||
return nil, nil
|
||||
}
|
||||
return float64(math.Float32frombits(v)), nil
|
||||
|
||||
case tcDouble:
|
||||
v := rd.ReadUint64()
|
||||
if v == doubleNullValue {
|
||||
return nil, nil
|
||||
}
|
||||
return math.Float64frombits(v), nil
|
||||
|
||||
case tcDate:
|
||||
year, month, day, null := readDate(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return time.Date(year, month, day, 0, 0, 0, 0, time.UTC), nil
|
||||
|
||||
// time read gives only seconds (cut), no milliseconds
|
||||
case tcTime:
|
||||
hour, minute, nanosecs, null := readTime(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return time.Date(1, 1, 1, hour, minute, 0, nanosecs, time.UTC), nil
|
||||
|
||||
case tcTimestamp:
|
||||
year, month, day, dateNull := readDate(rd)
|
||||
hour, minute, nanosecs, timeNull := readTime(rd)
|
||||
if dateNull || timeNull {
|
||||
return nil, nil
|
||||
}
|
||||
return time.Date(year, month, day, hour, minute, 0, nanosecs, time.UTC), nil
|
||||
|
||||
case tcLongdate:
|
||||
time, null := readLongdate(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return time, nil
|
||||
|
||||
case tcSeconddate:
|
||||
time, null := readSeconddate(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return time, nil
|
||||
|
||||
case tcDaydate:
|
||||
time, null := readDaydate(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return time, nil
|
||||
|
||||
case tcSecondtime:
|
||||
time, null := readSecondtime(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return time, nil
|
||||
|
||||
case tcDecimal:
|
||||
b, null := readDecimal(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return b, nil
|
||||
|
||||
case tcChar, tcVarchar:
|
||||
value, null := readBytes(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return value, nil
|
||||
|
||||
case tcNchar, tcNvarchar:
|
||||
value, null := readUtf8(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return value, nil
|
||||
|
||||
case tcBinary, tcVarbinary:
|
||||
value, null := readBytes(rd)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return value, nil
|
||||
|
||||
case tcBlob, tcClob, tcNclob:
|
||||
null, writer, err := readLob(session, rd, tc)
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return writer, err
|
||||
}
|
||||
|
||||
outLogger.Fatalf("read field: type code %s not implemented", tc)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func writeField(wr *bufio.Writer, tc TypeCode, arg driver.NamedValue) error {
|
||||
v := arg.Value
|
||||
//HDB bug: secondtime null value cannot be set by setting high byte
|
||||
// trying so, gives
|
||||
// SQL HdbError 1033 - error while parsing protocol: no such data type: type_code=192, index=2
|
||||
|
||||
// null value
|
||||
//if v == nil && tc != tcSecondtime
|
||||
if v == nil {
|
||||
wr.WriteB(byte(tc) | 0x80) //set high bit
|
||||
return nil
|
||||
}
|
||||
|
||||
// type code
|
||||
wr.WriteB(byte(tc))
|
||||
|
||||
switch tc {
|
||||
|
||||
default:
|
||||
outLogger.Fatalf("write field: type code %s not implemented", tc)
|
||||
|
||||
case tcTinyint, tcSmallint, tcInteger, tcBigint:
|
||||
var i64 int64
|
||||
|
||||
switch v := v.(type) {
|
||||
default:
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
|
||||
case bool:
|
||||
if v {
|
||||
i64 = 1
|
||||
} else {
|
||||
i64 = 0
|
||||
}
|
||||
case int64:
|
||||
i64 = v
|
||||
}
|
||||
|
||||
switch tc {
|
||||
case tcTinyint:
|
||||
wr.WriteB(byte(i64))
|
||||
case tcSmallint:
|
||||
wr.WriteInt16(int16(i64))
|
||||
case tcInteger:
|
||||
wr.WriteInt32(int32(i64))
|
||||
case tcBigint:
|
||||
wr.WriteInt64(i64)
|
||||
}
|
||||
|
||||
case tcReal:
|
||||
|
||||
f64, ok := v.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
wr.WriteFloat32(float32(f64))
|
||||
|
||||
case tcDouble:
|
||||
|
||||
f64, ok := v.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
wr.WriteFloat64(f64)
|
||||
|
||||
case tcDate:
|
||||
t, ok := v.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
writeDate(wr, t)
|
||||
|
||||
case tcTime:
|
||||
t, ok := v.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
writeTime(wr, t)
|
||||
|
||||
case tcTimestamp:
|
||||
t, ok := v.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
writeDate(wr, t)
|
||||
writeTime(wr, t)
|
||||
|
||||
case tcLongdate:
|
||||
t, ok := v.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
writeLongdate(wr, t)
|
||||
|
||||
case tcSeconddate:
|
||||
t, ok := v.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
writeSeconddate(wr, t)
|
||||
|
||||
case tcDaydate:
|
||||
t, ok := v.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
writeDaydate(wr, t)
|
||||
|
||||
case tcSecondtime:
|
||||
// HDB bug: write null value explicite
|
||||
if v == nil {
|
||||
wr.WriteInt32(86401)
|
||||
return nil
|
||||
}
|
||||
t, ok := v.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
writeSecondtime(wr, t)
|
||||
|
||||
case tcDecimal:
|
||||
b, ok := v.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
if len(b) != 16 {
|
||||
return fmt.Errorf("invalid argument length %d of type %T - expected %d", len(b), v, 16)
|
||||
}
|
||||
wr.Write(b)
|
||||
|
||||
case tcChar, tcVarchar, tcString:
|
||||
switch v := v.(type) {
|
||||
case []byte:
|
||||
writeBytes(wr, v)
|
||||
case string:
|
||||
writeString(wr, v)
|
||||
default:
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
|
||||
case tcNchar, tcNvarchar, tcNstring:
|
||||
switch v := v.(type) {
|
||||
case []byte:
|
||||
writeUtf8Bytes(wr, v)
|
||||
case string:
|
||||
writeUtf8String(wr, v)
|
||||
default:
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
|
||||
case tcBinary, tcVarbinary:
|
||||
v, ok := v.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid argument type %T", v)
|
||||
}
|
||||
writeBytes(wr, v)
|
||||
|
||||
case tcBlob, tcClob, tcNclob:
|
||||
writeLob(wr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// null values: most sig bit unset
|
||||
// year: unset second most sig bit (subtract 2^15)
|
||||
// --> read year as unsigned
|
||||
// month is 0-based
|
||||
// day is 1 byte
|
||||
func readDate(rd *bufio.Reader) (int, time.Month, int, bool) {
|
||||
year := rd.ReadUint16()
|
||||
null := ((year & 0x8000) == 0) //null value
|
||||
year &= 0x3fff
|
||||
month := rd.ReadInt8()
|
||||
month++
|
||||
day := rd.ReadInt8()
|
||||
return int(year), time.Month(month), int(day), null
|
||||
}
|
||||
|
||||
// year: set most sig bit
|
||||
// month 0 based
|
||||
func writeDate(wr *bufio.Writer, t time.Time) {
|
||||
//store in utc
|
||||
utc := t.In(time.UTC)
|
||||
|
||||
year, month, day := utc.Date()
|
||||
|
||||
wr.WriteUint16(uint16(year) | 0x8000)
|
||||
wr.WriteInt8(int8(month) - 1)
|
||||
wr.WriteInt8(int8(day))
|
||||
}
|
||||
|
||||
func readTime(rd *bufio.Reader) (int, int, int, bool) {
|
||||
hour := rd.ReadB()
|
||||
null := (hour & 0x80) == 0 //null value
|
||||
hour &= 0x7f
|
||||
minute := rd.ReadInt8()
|
||||
millisecs := rd.ReadUint16()
|
||||
nanosecs := int(millisecs) * 1000000
|
||||
return int(hour), int(minute), nanosecs, null
|
||||
}
|
||||
|
||||
func writeTime(wr *bufio.Writer, t time.Time) {
|
||||
//store in utc
|
||||
utc := t.UTC()
|
||||
|
||||
wr.WriteB(byte(utc.Hour()) | 0x80)
|
||||
wr.WriteInt8(int8(utc.Minute()))
|
||||
millisecs := utc.Second()*1000 + utc.Round(time.Millisecond).Nanosecond()/1000000
|
||||
wr.WriteUint16(uint16(millisecs))
|
||||
}
|
||||
|
||||
var zeroTime = time.Date(1, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
func readLongdate(rd *bufio.Reader) (time.Time, bool) {
|
||||
longdate := rd.ReadInt64()
|
||||
if longdate == 3155380704000000001 { // null value
|
||||
return zeroTime, true
|
||||
}
|
||||
return convertLongdateToTime(longdate), false
|
||||
}
|
||||
|
||||
func writeLongdate(wr *bufio.Writer, t time.Time) {
|
||||
wr.WriteInt64(convertTimeToLongdate(t))
|
||||
}
|
||||
|
||||
func readSeconddate(rd *bufio.Reader) (time.Time, bool) {
|
||||
seconddate := rd.ReadInt64()
|
||||
if seconddate == 315538070401 { // null value
|
||||
return zeroTime, true
|
||||
}
|
||||
return convertSeconddateToTime(seconddate), false
|
||||
}
|
||||
|
||||
func writeSeconddate(wr *bufio.Writer, t time.Time) {
|
||||
wr.WriteInt64(convertTimeToSeconddate(t))
|
||||
}
|
||||
|
||||
func readDaydate(rd *bufio.Reader) (time.Time, bool) {
|
||||
daydate := rd.ReadInt32()
|
||||
if daydate == 3652062 { // null value
|
||||
return zeroTime, true
|
||||
}
|
||||
return convertDaydateToTime(int64(daydate)), false
|
||||
}
|
||||
|
||||
func writeDaydate(wr *bufio.Writer, t time.Time) {
|
||||
wr.WriteInt32(int32(convertTimeToDayDate(t)))
|
||||
}
|
||||
|
||||
func readSecondtime(rd *bufio.Reader) (time.Time, bool) {
|
||||
secondtime := rd.ReadInt32()
|
||||
if secondtime == 86401 { // null value
|
||||
return zeroTime, true
|
||||
}
|
||||
return convertSecondtimeToTime(int(secondtime)), false
|
||||
}
|
||||
|
||||
func writeSecondtime(wr *bufio.Writer, t time.Time) {
|
||||
wr.WriteInt32(int32(convertTimeToSecondtime(t)))
|
||||
}
|
||||
|
||||
// nanosecond: HDB - 7 digits precision (not 9 digits)
|
||||
func convertTimeToLongdate(t time.Time) int64 {
|
||||
t = t.UTC()
|
||||
return (((((((int64(convertTimeToDayDate(t))-1)*24)+int64(t.Hour()))*60)+int64(t.Minute()))*60)+int64(t.Second()))*10000000 + int64(t.Nanosecond()/100) + 1
|
||||
}
|
||||
|
||||
func convertLongdateToTime(longdate int64) time.Time {
|
||||
const dayfactor = 10000000 * 24 * 60 * 60
|
||||
longdate--
|
||||
d := (longdate % dayfactor) * 100
|
||||
t := convertDaydateToTime((longdate / dayfactor) + 1)
|
||||
return t.Add(time.Duration(d))
|
||||
}
|
||||
|
||||
func convertTimeToSeconddate(t time.Time) int64 {
|
||||
t = t.UTC()
|
||||
return (((((int64(convertTimeToDayDate(t))-1)*24)+int64(t.Hour()))*60)+int64(t.Minute()))*60 + int64(t.Second()) + 1
|
||||
}
|
||||
|
||||
func convertSeconddateToTime(seconddate int64) time.Time {
|
||||
const dayfactor = 24 * 60 * 60
|
||||
seconddate--
|
||||
d := (seconddate % dayfactor) * 1000000000
|
||||
t := convertDaydateToTime((seconddate / dayfactor) + 1)
|
||||
return t.Add(time.Duration(d))
|
||||
}
|
||||
|
||||
const julianHdb = 1721423 // 1 January 0001 00:00:00 (1721424) - 1
|
||||
|
||||
func convertTimeToDayDate(t time.Time) int64 {
|
||||
return int64(timeToJulianDay(t) - julianHdb)
|
||||
}
|
||||
|
||||
func convertDaydateToTime(daydate int64) time.Time {
|
||||
return julianDayToTime(int(daydate) + julianHdb)
|
||||
}
|
||||
|
||||
func convertTimeToSecondtime(t time.Time) int {
|
||||
t = t.UTC()
|
||||
return (t.Hour()*60+t.Minute())*60 + t.Second() + 1
|
||||
}
|
||||
|
||||
func convertSecondtimeToTime(secondtime int) time.Time {
|
||||
return time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC).Add(time.Duration(int64(secondtime-1) * 1000000000))
|
||||
}
|
||||
|
||||
func readDecimal(rd *bufio.Reader) ([]byte, bool) {
|
||||
b := make([]byte, 16)
|
||||
rd.ReadFull(b)
|
||||
if (b[15] & 0x70) == 0x70 { //null value (bit 4,5,6 set)
|
||||
return nil, true
|
||||
}
|
||||
return b, false
|
||||
}
|
||||
|
||||
// string / binary length indicators
|
||||
const (
|
||||
bytesLenIndNullValue byte = 255
|
||||
bytesLenIndSmall byte = 245
|
||||
bytesLenIndMedium byte = 246
|
||||
bytesLenIndBig byte = 247
|
||||
)
|
||||
|
||||
func bytesSize(size int) (int, error) { //size + length indicator
|
||||
switch {
|
||||
default:
|
||||
return 0, fmt.Errorf("max string length %d exceeded %d", math.MaxInt32, size)
|
||||
case size <= int(bytesLenIndSmall):
|
||||
return size + 1, nil
|
||||
case size <= math.MaxInt16:
|
||||
return size + 3, nil
|
||||
case size <= math.MaxInt32:
|
||||
return size + 5, nil
|
||||
}
|
||||
}
|
||||
|
||||
func readBytesSize(rd *bufio.Reader) (int, bool) {
|
||||
|
||||
ind := rd.ReadB() //length indicator
|
||||
|
||||
switch {
|
||||
|
||||
default:
|
||||
return 0, false
|
||||
|
||||
case ind == bytesLenIndNullValue:
|
||||
return 0, true
|
||||
|
||||
case ind <= bytesLenIndSmall:
|
||||
return int(ind), false
|
||||
|
||||
case ind == bytesLenIndMedium:
|
||||
return int(rd.ReadInt16()), false
|
||||
|
||||
case ind == bytesLenIndBig:
|
||||
return int(rd.ReadInt32()), false
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func writeBytesSize(wr *bufio.Writer, size int) error {
|
||||
switch {
|
||||
|
||||
default:
|
||||
return fmt.Errorf("max argument length %d of string exceeded", size)
|
||||
|
||||
case size <= int(bytesLenIndSmall):
|
||||
wr.WriteB(byte(size))
|
||||
case size <= math.MaxInt16:
|
||||
wr.WriteB(bytesLenIndMedium)
|
||||
wr.WriteInt16(int16(size))
|
||||
case size <= math.MaxInt32:
|
||||
wr.WriteB(bytesLenIndBig)
|
||||
wr.WriteInt32(int32(size))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readBytes(rd *bufio.Reader) ([]byte, bool) {
|
||||
size, null := readBytesSize(rd)
|
||||
if null {
|
||||
return nil, true
|
||||
}
|
||||
b := make([]byte, size)
|
||||
rd.ReadFull(b)
|
||||
return b, false
|
||||
}
|
||||
|
||||
func readUtf8(rd *bufio.Reader) ([]byte, bool) {
|
||||
size, null := readBytesSize(rd)
|
||||
if null {
|
||||
return nil, true
|
||||
}
|
||||
b := rd.ReadCesu8(size)
|
||||
return b, false
|
||||
}
|
||||
|
||||
// strings with one byte length
|
||||
func readShortUtf8(rd *bufio.Reader) ([]byte, int) {
|
||||
size := rd.ReadB()
|
||||
b := rd.ReadCesu8(int(size))
|
||||
return b, int(size)
|
||||
}
|
||||
|
||||
func writeBytes(wr *bufio.Writer, b []byte) {
|
||||
writeBytesSize(wr, len(b))
|
||||
wr.Write(b)
|
||||
}
|
||||
|
||||
func writeString(wr *bufio.Writer, s string) {
|
||||
writeBytesSize(wr, len(s))
|
||||
wr.WriteString(s)
|
||||
}
|
||||
|
||||
func writeUtf8Bytes(wr *bufio.Writer, b []byte) {
|
||||
size := cesu8.Size(b)
|
||||
writeBytesSize(wr, size)
|
||||
wr.WriteCesu8(b)
|
||||
}
|
||||
|
||||
func writeUtf8String(wr *bufio.Writer, s string) {
|
||||
size := cesu8.StringSize(s)
|
||||
writeBytesSize(wr, size)
|
||||
wr.WriteStringCesu8(s)
|
||||
}
|
||||
|
||||
func readLob(s *Session, rd *bufio.Reader, tc TypeCode) (bool, lobChunkWriter, error) {
|
||||
rd.ReadInt8() // type code (is int here)
|
||||
opt := rd.ReadInt8()
|
||||
null := (lobOptions(opt) & loNullindicator) != 0
|
||||
if null {
|
||||
return true, nil, nil
|
||||
}
|
||||
eof := (lobOptions(opt) & loLastdata) != 0
|
||||
rd.Skip(2)
|
||||
|
||||
charLen := rd.ReadInt64()
|
||||
byteLen := rd.ReadInt64()
|
||||
id := rd.ReadUint64()
|
||||
chunkLen := rd.ReadInt32()
|
||||
|
||||
lobChunkWriter := newLobChunkWriter(tc.isCharBased(), s, locatorID(id), charLen, byteLen)
|
||||
if err := lobChunkWriter.write(rd, int(chunkLen), eof); err != nil {
|
||||
return null, lobChunkWriter, err
|
||||
}
|
||||
return null, lobChunkWriter, nil
|
||||
}
|
||||
|
||||
// TODO: first write: add content? - actually no data transferred
|
||||
func writeLob(wr *bufio.Writer) {
|
||||
wr.WriteB(0)
|
||||
wr.WriteInt32(0)
|
||||
wr.WriteInt32(0)
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=functionCode
|
||||
|
||||
type functionCode int16
|
||||
|
||||
const (
|
||||
fcNil functionCode = 0
|
||||
fcDDL functionCode = 1
|
||||
fcInsert functionCode = 2
|
||||
fcUpdate functionCode = 3
|
||||
fcDelete functionCode = 4
|
||||
fcSelect functionCode = 5
|
||||
fcSelectForUpdate functionCode = 6
|
||||
fcExplain functionCode = 7
|
||||
fcDBProcedureCall functionCode = 8
|
||||
fcDBProcedureCallWithResult functionCode = 9
|
||||
fcFetch functionCode = 10
|
||||
fcCommit functionCode = 11
|
||||
fcRollback functionCode = 12
|
||||
fcSavepoint functionCode = 13
|
||||
fcConnect functionCode = 14
|
||||
fcWriteLob functionCode = 15
|
||||
fcReadLob functionCode = 16
|
||||
fcPing functionCode = 17 //reserved: do not use
|
||||
fcDisconnect functionCode = 18
|
||||
fcCloseCursor functionCode = 19
|
||||
fcFindLob functionCode = 20
|
||||
fcAbapStream functionCode = 21
|
||||
fcXAStart functionCode = 22
|
||||
fcXAJoin functionCode = 23
|
||||
)
|
||||
|
||||
func (k functionCode) queryType() QueryType {
|
||||
|
||||
switch k {
|
||||
default:
|
||||
return QtNone
|
||||
case fcSelect, fcSelectForUpdate:
|
||||
return QtSelect
|
||||
case fcDBProcedureCall:
|
||||
return QtProcedureCall
|
||||
}
|
||||
}
|
16
vendor/github.com/SAP/go-hdb/internal/protocol/functioncode_string.go
generated
vendored
Normal file
16
vendor/github.com/SAP/go-hdb/internal/protocol/functioncode_string.go
generated
vendored
Normal file
|
@ -0,0 +1,16 @@
|
|||
// Code generated by "stringer -type=functionCode"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _functionCode_name = "fcNilfcDDLfcInsertfcUpdatefcDeletefcSelectfcSelectForUpdatefcExplainfcDBProcedureCallfcDBProcedureCallWithResultfcFetchfcCommitfcRollbackfcSavepointfcConnectfcWriteLobfcReadLobfcPingfcDisconnectfcCloseCursorfcFindLobfcAbapStreamfcXAStartfcXAJoin"
|
||||
|
||||
var _functionCode_index = [...]uint8{0, 5, 10, 18, 26, 34, 42, 59, 68, 85, 112, 119, 127, 137, 148, 157, 167, 176, 182, 194, 207, 216, 228, 237, 245}
|
||||
|
||||
func (i functionCode) String() string {
|
||||
if i < 0 || i >= functionCode(len(_functionCode_index)-1) {
|
||||
return "functionCode(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _functionCode_name[_functionCode_index[i]:_functionCode_index[i+1]]
|
||||
}
|
|
@ -0,0 +1,198 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
const (
|
||||
okEndianess int8 = 1
|
||||
)
|
||||
|
||||
const (
|
||||
initRequestFillerSize = 4
|
||||
)
|
||||
|
||||
var initRequestFiller uint32 = 0xffffffff
|
||||
|
||||
type productVersion struct {
|
||||
major int8
|
||||
minor int16
|
||||
}
|
||||
|
||||
func (v *productVersion) String() string {
|
||||
return fmt.Sprintf("%d.%d", v.major, v.minor)
|
||||
}
|
||||
|
||||
type protocolVersion struct {
|
||||
major int8
|
||||
minor int16
|
||||
}
|
||||
|
||||
func (v *protocolVersion) String() string {
|
||||
return fmt.Sprintf("%d.%d", v.major, v.minor)
|
||||
}
|
||||
|
||||
type version struct {
|
||||
major int8
|
||||
minor int16
|
||||
}
|
||||
|
||||
func (v *version) String() string {
|
||||
return fmt.Sprintf("%d.%d", v.major, v.minor)
|
||||
}
|
||||
|
||||
type initRequest struct {
|
||||
product *version
|
||||
protocol *version
|
||||
numOptions int8
|
||||
endianess endianess
|
||||
}
|
||||
|
||||
func newInitRequest() *initRequest {
|
||||
return &initRequest{
|
||||
product: new(version),
|
||||
protocol: new(version),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *initRequest) String() string {
|
||||
switch r.numOptions {
|
||||
default:
|
||||
return fmt.Sprintf("init request: product version %s protocol version %s", r.product, r.protocol)
|
||||
case 1:
|
||||
return fmt.Sprintf("init request: product version %s protocol version %s endianess %s", r.product, r.protocol, r.endianess)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *initRequest) read(rd *bufio.Reader) error {
|
||||
rd.Skip(initRequestFillerSize) //filler
|
||||
r.product.major = rd.ReadInt8()
|
||||
r.product.minor = rd.ReadInt16()
|
||||
r.protocol.major = rd.ReadInt8()
|
||||
r.protocol.minor = rd.ReadInt16()
|
||||
rd.Skip(1) //reserved filler
|
||||
r.numOptions = rd.ReadInt8()
|
||||
|
||||
switch r.numOptions {
|
||||
default:
|
||||
outLogger.Fatalf("invalid number of options %d", r.numOptions)
|
||||
|
||||
case 0:
|
||||
rd.Skip(2)
|
||||
|
||||
case 1:
|
||||
cnt := rd.ReadInt8()
|
||||
if cnt != 1 {
|
||||
outLogger.Fatalf("endianess %d - 1 expected", cnt)
|
||||
}
|
||||
r.endianess = endianess(rd.ReadInt8())
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("read %s", r)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
func (r *initRequest) write(wr *bufio.Writer) error {
|
||||
wr.WriteUint32(initRequestFiller)
|
||||
wr.WriteInt8(r.product.major)
|
||||
wr.WriteInt16(r.product.minor)
|
||||
wr.WriteInt8(r.protocol.major)
|
||||
wr.WriteInt16(r.protocol.minor)
|
||||
|
||||
switch r.numOptions {
|
||||
default:
|
||||
outLogger.Fatalf("invalid number of options %d", r.numOptions)
|
||||
|
||||
case 0:
|
||||
wr.WriteZeroes(4)
|
||||
|
||||
case 1:
|
||||
// reserved
|
||||
wr.WriteZeroes(1)
|
||||
wr.WriteInt8(r.numOptions)
|
||||
wr.WriteInt8(int8(okEndianess))
|
||||
wr.WriteInt8(int8(r.endianess))
|
||||
|
||||
}
|
||||
|
||||
// flush
|
||||
if err := wr.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("write %s", r)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type initReply struct {
|
||||
product *version
|
||||
protocol *version
|
||||
}
|
||||
|
||||
func newInitReply() *initReply {
|
||||
return &initReply{
|
||||
product: new(version),
|
||||
protocol: new(version),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *initReply) String() string {
|
||||
return fmt.Sprintf("init reply: product version %s protocol version %s", r.product, r.protocol)
|
||||
}
|
||||
|
||||
func (r *initReply) read(rd *bufio.Reader) error {
|
||||
r.product.major = rd.ReadInt8()
|
||||
r.product.minor = rd.ReadInt16()
|
||||
r.protocol.major = rd.ReadInt8()
|
||||
r.protocol.minor = rd.ReadInt16()
|
||||
rd.Skip(2) //commitInitReplySize
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("read %s", r)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
func (r *initReply) write(wr *bufio.Writer) error {
|
||||
wr.WriteInt8(r.product.major)
|
||||
wr.WriteInt16(r.product.minor)
|
||||
wr.WriteInt8(r.product.major)
|
||||
wr.WriteInt16(r.protocol.minor)
|
||||
wr.WriteZeroes(2) // commitInitReplySize
|
||||
|
||||
// flush
|
||||
if err := wr.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("write %s", r)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// +build amd64 386 arm arm64 ppc64le mipsle mips64le
|
||||
|
||||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//amd64, 386 architectures: little endian
|
||||
//arm, arm64: go supports little endian only
|
||||
var archEndian = littleEndian
|
|
@ -0,0 +1,507 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/text/transform"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
"github.com/SAP/go-hdb/internal/unicode"
|
||||
"github.com/SAP/go-hdb/internal/unicode/cesu8"
|
||||
)
|
||||
|
||||
const (
|
||||
locatorIDSize = 8
|
||||
writeLobRequestHeaderSize = 21
|
||||
readLobRequestSize = 24
|
||||
)
|
||||
|
||||
// variable (unit testing)
|
||||
//var lobChunkSize = 1 << 14 //TODO: check size
|
||||
var lobChunkSize int32 = 4096 //TODO: check size
|
||||
|
||||
//lob options
|
||||
type lobOptions int8
|
||||
|
||||
const (
|
||||
loNullindicator lobOptions = 0x01
|
||||
loDataincluded lobOptions = 0x02
|
||||
loLastdata lobOptions = 0x04
|
||||
)
|
||||
|
||||
var lobOptionsText = map[lobOptions]string{
|
||||
loNullindicator: "null indicator",
|
||||
loDataincluded: "data included",
|
||||
loLastdata: "last data",
|
||||
}
|
||||
|
||||
func (k lobOptions) String() string {
|
||||
t := make([]string, 0, len(lobOptionsText))
|
||||
|
||||
for option, text := range lobOptionsText {
|
||||
if (k & option) != 0 {
|
||||
t = append(t, text)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%v", t)
|
||||
}
|
||||
|
||||
type locatorID uint64 // byte[locatorIdSize]
|
||||
|
||||
// write lob reply
|
||||
type writeLobReply struct {
|
||||
ids []locatorID
|
||||
numArg int
|
||||
}
|
||||
|
||||
func (r *writeLobReply) String() string {
|
||||
return fmt.Sprintf("write lob reply: %v", r.ids)
|
||||
}
|
||||
|
||||
func (r *writeLobReply) kind() partKind {
|
||||
return pkWriteLobReply
|
||||
}
|
||||
|
||||
func (r *writeLobReply) setNumArg(numArg int) {
|
||||
r.numArg = numArg
|
||||
}
|
||||
|
||||
func (r *writeLobReply) read(rd *bufio.Reader) error {
|
||||
|
||||
//resize ids
|
||||
if r.ids == nil || cap(r.ids) < r.numArg {
|
||||
r.ids = make([]locatorID, r.numArg)
|
||||
} else {
|
||||
r.ids = r.ids[:r.numArg]
|
||||
}
|
||||
|
||||
for i := 0; i < r.numArg; i++ {
|
||||
r.ids[i] = locatorID(rd.ReadUint64())
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
//write lob request
|
||||
type writeLobRequest struct {
|
||||
lobPrmFields []*ParameterField
|
||||
}
|
||||
|
||||
func (r *writeLobRequest) kind() partKind {
|
||||
return pkWriteLobRequest
|
||||
}
|
||||
|
||||
func (r *writeLobRequest) size() (int, error) {
|
||||
|
||||
// TODO: check size limit
|
||||
|
||||
size := 0
|
||||
for _, prmField := range r.lobPrmFields {
|
||||
cr := prmField.chunkReader
|
||||
if cr.done() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := cr.fill(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
size += writeLobRequestHeaderSize
|
||||
size += cr.size()
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (r *writeLobRequest) numArg() int {
|
||||
n := 0
|
||||
for _, prmField := range r.lobPrmFields {
|
||||
cr := prmField.chunkReader
|
||||
if !cr.done() {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *writeLobRequest) write(wr *bufio.Writer) error {
|
||||
for _, prmField := range r.lobPrmFields {
|
||||
cr := prmField.chunkReader
|
||||
if !cr.done() {
|
||||
|
||||
wr.WriteUint64(uint64(prmField.lobLocatorID))
|
||||
|
||||
opt := int8(0x02) // data included
|
||||
if cr.eof() {
|
||||
opt |= 0x04 // last data
|
||||
}
|
||||
|
||||
wr.WriteInt8(opt)
|
||||
wr.WriteInt64(-1) //offset (-1 := append)
|
||||
wr.WriteInt32(int32(cr.size())) // size
|
||||
wr.Write(cr.bytes())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//read lob request
|
||||
type readLobRequest struct {
|
||||
w lobChunkWriter
|
||||
}
|
||||
|
||||
func (r *readLobRequest) kind() partKind {
|
||||
return pkReadLobRequest
|
||||
}
|
||||
|
||||
func (r *readLobRequest) size() (int, error) {
|
||||
return readLobRequestSize, nil
|
||||
}
|
||||
|
||||
func (r *readLobRequest) numArg() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (r *readLobRequest) write(wr *bufio.Writer) error {
|
||||
wr.WriteUint64(uint64(r.w.id()))
|
||||
|
||||
readOfs, readLen := r.w.readOfsLen()
|
||||
|
||||
wr.WriteInt64(readOfs + 1) //1-based
|
||||
wr.WriteInt32(readLen)
|
||||
wr.WriteZeroes(4)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// read lob reply
|
||||
// - seems like readLobreply gives only an result for one lob - even if more then one is requested
|
||||
// --> read single lobs
|
||||
type readLobReply struct {
|
||||
w lobChunkWriter
|
||||
}
|
||||
|
||||
func (r *readLobReply) kind() partKind {
|
||||
return pkReadLobReply
|
||||
}
|
||||
|
||||
func (r *readLobReply) setNumArg(numArg int) {
|
||||
if numArg != 1 {
|
||||
panic("numArg == 1 expected")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *readLobReply) read(rd *bufio.Reader) error {
|
||||
id := rd.ReadUint64()
|
||||
|
||||
if r.w.id() != locatorID(id) {
|
||||
return fmt.Errorf("internal error: invalid lob locator %d - expected %d", id, r.w.id())
|
||||
}
|
||||
|
||||
opt := rd.ReadInt8()
|
||||
chunkLen := rd.ReadInt32()
|
||||
rd.Skip(3)
|
||||
eof := (lobOptions(opt) & loLastdata) != 0
|
||||
|
||||
if err := r.w.write(rd, int(chunkLen), eof); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
// lobChunkReader reads lob field io.Reader in chunks for writing to db.
|
||||
type lobChunkReader interface {
|
||||
fill() error
|
||||
size() int
|
||||
bytes() []byte
|
||||
eof() bool
|
||||
done() bool
|
||||
}
|
||||
|
||||
func newLobChunkReader(isCharBased bool, r io.Reader) lobChunkReader {
|
||||
if isCharBased {
|
||||
return &charLobChunkReader{r: r}
|
||||
}
|
||||
return &binaryLobChunkReader{r: r}
|
||||
}
|
||||
|
||||
// binaryLobChunkReader (byte based chunks).
|
||||
type binaryLobChunkReader struct {
|
||||
r io.Reader
|
||||
_size int
|
||||
_eof bool
|
||||
_done bool
|
||||
b []byte
|
||||
}
|
||||
|
||||
func (l *binaryLobChunkReader) eof() bool { return l._eof }
|
||||
func (l *binaryLobChunkReader) done() bool { return l._done }
|
||||
func (l *binaryLobChunkReader) size() int { return l._size }
|
||||
|
||||
func (l *binaryLobChunkReader) bytes() []byte {
|
||||
l._done = l._eof
|
||||
return l.b[:l._size]
|
||||
}
|
||||
|
||||
func (l *binaryLobChunkReader) fill() error {
|
||||
if l._eof {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
l.b = resizeBuffer(l.b, int(lobChunkSize))
|
||||
l._size, err = l.r.Read(l.b)
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
l._eof = err == io.EOF
|
||||
return nil
|
||||
}
|
||||
|
||||
// charLobChunkReader (cesu8 character based chunks).
|
||||
type charLobChunkReader struct {
|
||||
r io.Reader
|
||||
_size int
|
||||
_eof bool
|
||||
_done bool
|
||||
b []byte
|
||||
c []byte
|
||||
ofs int
|
||||
}
|
||||
|
||||
func (l *charLobChunkReader) eof() bool { return l._eof }
|
||||
func (l *charLobChunkReader) done() bool { return l._done }
|
||||
func (l *charLobChunkReader) size() int { return l._size }
|
||||
|
||||
func (l *charLobChunkReader) bytes() []byte {
|
||||
l._done = l._eof
|
||||
return l.b[:l._size]
|
||||
}
|
||||
|
||||
func (l *charLobChunkReader) fill() error {
|
||||
if l._eof {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
l.c = resizeBuffer(l.c, int(lobChunkSize)+l.ofs)
|
||||
n, err := l.r.Read(l.c[l.ofs:])
|
||||
size := n + l.ofs
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
l._eof = err == io.EOF
|
||||
if l._eof && size == 0 {
|
||||
l._size = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
l.b = resizeBuffer(l.b, cesu8.Size(l.c[:size])) // last rune might be incomplete, so size is one greater than needed
|
||||
nDst, nSrc, err := unicode.Utf8ToCesu8Transformer.Transform(l.b, l.c[:size], l._eof)
|
||||
if err != nil && err != transform.ErrShortSrc {
|
||||
return err
|
||||
}
|
||||
|
||||
if l._eof && err == transform.ErrShortSrc {
|
||||
return unicode.ErrInvalidUtf8
|
||||
}
|
||||
|
||||
l._size = nDst
|
||||
l.ofs = size - nSrc
|
||||
|
||||
if l.ofs > 0 {
|
||||
copy(l.c, l.c[nSrc:size]) // copy rest to buffer beginn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// lobChunkWriter reads db lob chunks and writes them into lob field io.Writer.
|
||||
type lobChunkWriter interface {
|
||||
SetWriter(w io.Writer) error // gets called by driver.Lob.Scan
|
||||
|
||||
id() locatorID
|
||||
write(rd *bufio.Reader, size int, eof bool) error
|
||||
readOfsLen() (int64, int32)
|
||||
eof() bool
|
||||
}
|
||||
|
||||
func newLobChunkWriter(isCharBased bool, s *Session, id locatorID, charLen, byteLen int64) lobChunkWriter {
|
||||
if isCharBased {
|
||||
return &charLobChunkWriter{s: s, _id: id, charLen: charLen, byteLen: byteLen}
|
||||
}
|
||||
return &binaryLobChunkWriter{s: s, _id: id, charLen: charLen, byteLen: byteLen}
|
||||
}
|
||||
|
||||
// binaryLobChunkWriter (byte based lobs).
|
||||
type binaryLobChunkWriter struct {
|
||||
s *Session
|
||||
|
||||
_id locatorID
|
||||
charLen int64
|
||||
byteLen int64
|
||||
|
||||
readOfs int64
|
||||
_eof bool
|
||||
|
||||
ofs int
|
||||
|
||||
wr io.Writer
|
||||
|
||||
b []byte
|
||||
}
|
||||
|
||||
func (l *binaryLobChunkWriter) id() locatorID { return l._id }
|
||||
func (l *binaryLobChunkWriter) eof() bool { return l._eof }
|
||||
|
||||
func (l *binaryLobChunkWriter) SetWriter(wr io.Writer) error {
|
||||
l.wr = wr
|
||||
if err := l.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
return l.s.readLobStream(l)
|
||||
}
|
||||
|
||||
func (l *binaryLobChunkWriter) write(rd *bufio.Reader, size int, eof bool) error {
|
||||
l._eof = eof // store eof
|
||||
|
||||
if size == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.b = resizeBuffer(l.b, size+l.ofs)
|
||||
rd.ReadFull(l.b[l.ofs:])
|
||||
if l.wr != nil {
|
||||
return l.flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *binaryLobChunkWriter) readOfsLen() (int64, int32) {
|
||||
readLen := l.charLen - l.readOfs
|
||||
if readLen > int64(math.MaxInt32) || readLen > int64(lobChunkSize) {
|
||||
return l.readOfs, lobChunkSize
|
||||
}
|
||||
return l.readOfs, int32(readLen)
|
||||
}
|
||||
|
||||
func (l *binaryLobChunkWriter) flush() error {
|
||||
if _, err := l.wr.Write(l.b); err != nil {
|
||||
return err
|
||||
}
|
||||
l.readOfs += int64(len(l.b))
|
||||
return nil
|
||||
}
|
||||
|
||||
type charLobChunkWriter struct {
|
||||
s *Session
|
||||
|
||||
_id locatorID
|
||||
charLen int64
|
||||
byteLen int64
|
||||
|
||||
readOfs int64
|
||||
_eof bool
|
||||
|
||||
ofs int
|
||||
|
||||
wr io.Writer
|
||||
|
||||
b []byte
|
||||
}
|
||||
|
||||
func (l *charLobChunkWriter) id() locatorID { return l._id }
|
||||
func (l *charLobChunkWriter) eof() bool { return l._eof }
|
||||
|
||||
func (l *charLobChunkWriter) SetWriter(wr io.Writer) error {
|
||||
l.wr = wr
|
||||
if err := l.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
return l.s.readLobStream(l)
|
||||
}
|
||||
|
||||
func (l *charLobChunkWriter) write(rd *bufio.Reader, size int, eof bool) error {
|
||||
l._eof = eof // store eof
|
||||
|
||||
if size == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.b = resizeBuffer(l.b, size+l.ofs)
|
||||
rd.ReadFull(l.b[l.ofs:])
|
||||
if l.wr != nil {
|
||||
return l.flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *charLobChunkWriter) readOfsLen() (int64, int32) {
|
||||
readLen := l.charLen - l.readOfs
|
||||
if readLen > int64(math.MaxInt32) || readLen > int64(lobChunkSize) {
|
||||
return l.readOfs, lobChunkSize
|
||||
}
|
||||
return l.readOfs, int32(readLen)
|
||||
}
|
||||
|
||||
func (l *charLobChunkWriter) flush() error {
|
||||
nDst, nSrc, err := unicode.Cesu8ToUtf8Transformer.Transform(l.b, l.b, true) // inline cesu8 to utf8 transformation
|
||||
if err != nil && err != transform.ErrShortSrc {
|
||||
return err
|
||||
}
|
||||
if _, err := l.wr.Write(l.b[:nDst]); err != nil {
|
||||
return err
|
||||
}
|
||||
l.ofs = len(l.b) - nSrc
|
||||
if l.ofs != 0 && l.ofs != cesu8.CESUMax/2 { // assert remaining bytes
|
||||
return unicode.ErrInvalidCesu8
|
||||
}
|
||||
l.readOfs += int64(l.runeCount(l.b[:nDst]))
|
||||
if l.ofs != 0 {
|
||||
l.readOfs++ // add half encoding
|
||||
copy(l.b, l.b[nSrc:len(l.b)]) // move half encoding to buffer begin
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Caution: hdb counts 4 byte utf-8 encodings (cesu-8 6 bytes) as 2 (3 byte) chars
|
||||
func (l *charLobChunkWriter) runeCount(b []byte) int {
|
||||
numChars := 0
|
||||
for len(b) > 0 {
|
||||
_, size := utf8.DecodeRune(b)
|
||||
b = b[size:]
|
||||
numChars++
|
||||
if size == utf8.UTFMax {
|
||||
numChars++
|
||||
}
|
||||
}
|
||||
return numChars
|
||||
}
|
||||
|
||||
// helper
|
||||
func resizeBuffer(b1 []byte, size int) []byte {
|
||||
if b1 == nil || cap(b1) < size {
|
||||
b2 := make([]byte, size)
|
||||
copy(b2, b1) // !!!
|
||||
return b2
|
||||
}
|
||||
return b1[:size]
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
const (
|
||||
messageHeaderSize = 32
|
||||
)
|
||||
|
||||
//message header
|
||||
type messageHeader struct {
|
||||
sessionID int64
|
||||
packetCount int32
|
||||
varPartLength uint32
|
||||
varPartSize uint32
|
||||
noOfSegm int16
|
||||
}
|
||||
|
||||
func (h *messageHeader) String() string {
|
||||
return fmt.Sprintf("session id %d packetCount %d varPartLength %d, varPartSize %d noOfSegm %d",
|
||||
h.sessionID,
|
||||
h.packetCount,
|
||||
h.varPartLength,
|
||||
h.varPartSize,
|
||||
h.noOfSegm)
|
||||
}
|
||||
|
||||
func (h *messageHeader) write(wr *bufio.Writer) error {
|
||||
wr.WriteInt64(h.sessionID)
|
||||
wr.WriteInt32(h.packetCount)
|
||||
wr.WriteUint32(h.varPartLength)
|
||||
wr.WriteUint32(h.varPartSize)
|
||||
wr.WriteInt16(h.noOfSegm)
|
||||
wr.WriteZeroes(10) //messageHeaderSize
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("write message header: %s", h)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *messageHeader) read(rd *bufio.Reader) error {
|
||||
h.sessionID = rd.ReadInt64()
|
||||
h.packetCount = rd.ReadInt32()
|
||||
h.varPartLength = rd.ReadUint32()
|
||||
h.varPartSize = rd.ReadUint32()
|
||||
h.noOfSegm = rd.ReadInt16()
|
||||
rd.Skip(10) //messageHeaderSize
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("read message header: %s", h)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=messageType
|
||||
|
||||
type messageType int8
|
||||
|
||||
const (
|
||||
mtNil messageType = 0
|
||||
mtExecuteDirect messageType = 2
|
||||
mtPrepare messageType = 3
|
||||
mtAbapStream messageType = 4
|
||||
mtXAStart messageType = 5
|
||||
mtXAJoin messageType = 6
|
||||
mtExecute messageType = 13
|
||||
mtWriteLob messageType = 16
|
||||
mtReadLob messageType = 17
|
||||
mtFindLob messageType = 18
|
||||
mtAuthenticate messageType = 65
|
||||
mtConnect messageType = 66
|
||||
mtCommit messageType = 67
|
||||
mtRollback messageType = 68
|
||||
mtCloseResultset messageType = 69
|
||||
mtDropStatementID messageType = 70
|
||||
mtFetchNext messageType = 71
|
||||
mtFetchAbsolute messageType = 72
|
||||
mtFetchRelative messageType = 73
|
||||
mtFetchFirst messageType = 74
|
||||
mtFetchLast messageType = 75
|
||||
mtDisconnect messageType = 77
|
||||
mtExecuteITab messageType = 78
|
||||
mtFetchNextITab messageType = 79
|
||||
mtInsertNextITab messageType = 80
|
||||
)
|
44
vendor/github.com/SAP/go-hdb/internal/protocol/messagetype_string.go
generated
vendored
Normal file
44
vendor/github.com/SAP/go-hdb/internal/protocol/messagetype_string.go
generated
vendored
Normal file
|
@ -0,0 +1,44 @@
|
|||
// Code generated by "stringer -type=messageType"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
_messageType_name_0 = "mtNil"
|
||||
_messageType_name_1 = "mtExecuteDirectmtPreparemtAbapStreammtXAStartmtXAJoin"
|
||||
_messageType_name_2 = "mtExecute"
|
||||
_messageType_name_3 = "mtWriteLobmtReadLobmtFindLob"
|
||||
_messageType_name_4 = "mtAuthenticatemtConnectmtCommitmtRollbackmtCloseResultsetmtDropStatementIDmtFetchNextmtFetchAbsolutemtFetchRelativemtFetchFirstmtFetchLast"
|
||||
_messageType_name_5 = "mtDisconnectmtExecuteITabmtFetchNextITabmtInsertNextITab"
|
||||
)
|
||||
|
||||
var (
|
||||
_messageType_index_1 = [...]uint8{0, 15, 24, 36, 45, 53}
|
||||
_messageType_index_3 = [...]uint8{0, 10, 19, 28}
|
||||
_messageType_index_4 = [...]uint8{0, 14, 23, 31, 41, 57, 74, 85, 100, 115, 127, 138}
|
||||
_messageType_index_5 = [...]uint8{0, 12, 25, 40, 56}
|
||||
)
|
||||
|
||||
func (i messageType) String() string {
|
||||
switch {
|
||||
case i == 0:
|
||||
return _messageType_name_0
|
||||
case 2 <= i && i <= 6:
|
||||
i -= 2
|
||||
return _messageType_name_1[_messageType_index_1[i]:_messageType_index_1[i+1]]
|
||||
case i == 13:
|
||||
return _messageType_name_2
|
||||
case 16 <= i && i <= 18:
|
||||
i -= 16
|
||||
return _messageType_name_3[_messageType_index_3[i]:_messageType_index_3[i+1]]
|
||||
case 65 <= i && i <= 75:
|
||||
i -= 65
|
||||
return _messageType_name_4[_messageType_index_4[i]:_messageType_index_4[i+1]]
|
||||
case 77 <= i && i <= 80:
|
||||
i -= 77
|
||||
return _messageType_name_5[_messageType_index_5[i]:_messageType_index_5[i+1]]
|
||||
default:
|
||||
return "messageType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,188 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
type booleanType bool
|
||||
|
||||
func (t booleanType) String() string {
|
||||
return fmt.Sprintf("%t", t)
|
||||
}
|
||||
|
||||
type intType int32
|
||||
|
||||
func (t intType) String() string {
|
||||
return fmt.Sprintf("%d", t)
|
||||
}
|
||||
|
||||
type bigintType int64
|
||||
|
||||
func (t bigintType) String() string {
|
||||
return fmt.Sprintf("%d", t)
|
||||
}
|
||||
|
||||
type doubleType float64
|
||||
|
||||
func (t doubleType) String() string {
|
||||
return fmt.Sprintf("%g", t)
|
||||
}
|
||||
|
||||
type stringType []byte
|
||||
|
||||
type binaryStringType []byte
|
||||
|
||||
func (t binaryStringType) String() string {
|
||||
return fmt.Sprintf("%v", []byte(t))
|
||||
}
|
||||
|
||||
//multi line options (number of lines in part header argumentCount)
|
||||
type multiLineOptions []plainOptions
|
||||
|
||||
func (o multiLineOptions) size() int {
|
||||
size := 0
|
||||
for _, m := range o {
|
||||
size += m.size()
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
//pointer: append multiLineOptions itself
|
||||
func (o *multiLineOptions) read(rd *bufio.Reader, lineCnt int) {
|
||||
for i := 0; i < lineCnt; i++ {
|
||||
m := plainOptions{}
|
||||
cnt := rd.ReadInt16()
|
||||
m.read(rd, int(cnt))
|
||||
*o = append(*o, m)
|
||||
}
|
||||
}
|
||||
|
||||
func (o multiLineOptions) write(wr *bufio.Writer) {
|
||||
for _, m := range o {
|
||||
wr.WriteInt16(int16(len(m)))
|
||||
m.write(wr)
|
||||
}
|
||||
}
|
||||
|
||||
type plainOptions map[int8]interface{}
|
||||
|
||||
func (o plainOptions) size() int {
|
||||
size := 2 * len(o) //option + type
|
||||
for _, v := range o {
|
||||
switch v := v.(type) {
|
||||
default:
|
||||
outLogger.Fatalf("type %T not implemented", v)
|
||||
case booleanType:
|
||||
size++
|
||||
case intType:
|
||||
size += 4
|
||||
case bigintType:
|
||||
size += 8
|
||||
case doubleType:
|
||||
size += 8
|
||||
case stringType:
|
||||
size += (2 + len(v)) //length int16 + string length
|
||||
case binaryStringType:
|
||||
size += (2 + len(v)) //length int16 + string length
|
||||
}
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (o plainOptions) read(rd *bufio.Reader, cnt int) {
|
||||
|
||||
for i := 0; i < cnt; i++ {
|
||||
|
||||
k := rd.ReadInt8()
|
||||
tc := rd.ReadB()
|
||||
|
||||
switch TypeCode(tc) {
|
||||
|
||||
default:
|
||||
outLogger.Fatalf("type code %s not implemented", TypeCode(tc))
|
||||
|
||||
case tcBoolean:
|
||||
o[k] = booleanType(rd.ReadBool())
|
||||
|
||||
case tcInteger:
|
||||
o[k] = intType(rd.ReadInt32())
|
||||
|
||||
case tcBigint:
|
||||
o[k] = bigintType(rd.ReadInt64())
|
||||
|
||||
case tcDouble:
|
||||
o[k] = doubleType(rd.ReadFloat64())
|
||||
|
||||
case tcString:
|
||||
size := rd.ReadInt16()
|
||||
v := make([]byte, size)
|
||||
rd.ReadFull(v)
|
||||
o[k] = stringType(v)
|
||||
|
||||
case tcBstring:
|
||||
size := rd.ReadInt16()
|
||||
v := make([]byte, size)
|
||||
rd.ReadFull(v)
|
||||
o[k] = binaryStringType(v)
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (o plainOptions) write(wr *bufio.Writer) {
|
||||
|
||||
for k, v := range o {
|
||||
|
||||
wr.WriteInt8(k)
|
||||
|
||||
switch v := v.(type) {
|
||||
|
||||
default:
|
||||
outLogger.Fatalf("type %T not implemented", v)
|
||||
|
||||
case booleanType:
|
||||
wr.WriteInt8(int8(tcBoolean))
|
||||
wr.WriteBool(bool(v))
|
||||
|
||||
case intType:
|
||||
wr.WriteInt8(int8(tcInteger))
|
||||
wr.WriteInt32(int32(v))
|
||||
|
||||
case bigintType:
|
||||
wr.WriteInt8(int8(tcBigint))
|
||||
wr.WriteInt64(int64(v))
|
||||
|
||||
case doubleType:
|
||||
wr.WriteInt8(int8(tcDouble))
|
||||
wr.WriteFloat64(float64(v))
|
||||
|
||||
case stringType:
|
||||
wr.WriteInt8(int8(tcString))
|
||||
wr.WriteInt16(int16(len(v)))
|
||||
wr.Write(v)
|
||||
|
||||
case binaryStringType:
|
||||
wr.WriteInt8(int8(tcBstring))
|
||||
wr.WriteInt16(int16(len(v)))
|
||||
wr.Write(v)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,389 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
type parameterOptions int8
|
||||
|
||||
const (
|
||||
poMandatory parameterOptions = 0x01
|
||||
poOptional parameterOptions = 0x02
|
||||
poDefault parameterOptions = 0x04
|
||||
)
|
||||
|
||||
var parameterOptionsText = map[parameterOptions]string{
|
||||
poMandatory: "mandatory",
|
||||
poOptional: "optional",
|
||||
poDefault: "default",
|
||||
}
|
||||
|
||||
func (k parameterOptions) String() string {
|
||||
t := make([]string, 0, len(parameterOptionsText))
|
||||
|
||||
for option, text := range parameterOptionsText {
|
||||
if (k & option) != 0 {
|
||||
t = append(t, text)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%v", t)
|
||||
}
|
||||
|
||||
type parameterMode int8
|
||||
|
||||
const (
|
||||
pmIn parameterMode = 0x01
|
||||
pmInout parameterMode = 0x02
|
||||
pmOut parameterMode = 0x04
|
||||
)
|
||||
|
||||
var parameterModeText = map[parameterMode]string{
|
||||
pmIn: "in",
|
||||
pmInout: "inout",
|
||||
pmOut: "out",
|
||||
}
|
||||
|
||||
func (k parameterMode) String() string {
|
||||
t := make([]string, 0, len(parameterModeText))
|
||||
|
||||
for mode, text := range parameterModeText {
|
||||
if (k & mode) != 0 {
|
||||
t = append(t, text)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%v", t)
|
||||
}
|
||||
|
||||
// ParameterFieldSet contains database field metadata for parameters.
|
||||
type ParameterFieldSet struct {
|
||||
fields []*ParameterField
|
||||
_inputFields []*ParameterField
|
||||
_outputFields []*ParameterField
|
||||
names fieldNames
|
||||
}
|
||||
|
||||
func newParameterFieldSet(size int) *ParameterFieldSet {
|
||||
return &ParameterFieldSet{
|
||||
fields: make([]*ParameterField, size),
|
||||
_inputFields: make([]*ParameterField, 0, size),
|
||||
_outputFields: make([]*ParameterField, 0, size),
|
||||
names: newFieldNames(),
|
||||
}
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (f *ParameterFieldSet) String() string {
|
||||
a := make([]string, len(f.fields))
|
||||
for i, f := range f.fields {
|
||||
a[i] = f.String()
|
||||
}
|
||||
return fmt.Sprintf("%v", a)
|
||||
}
|
||||
|
||||
func (f *ParameterFieldSet) read(rd *bufio.Reader) {
|
||||
for i := 0; i < len(f.fields); i++ {
|
||||
field := newParameterField(f.names)
|
||||
field.read(rd)
|
||||
f.fields[i] = field
|
||||
if field.In() {
|
||||
f._inputFields = append(f._inputFields, field)
|
||||
}
|
||||
if field.Out() {
|
||||
f._outputFields = append(f._outputFields, field)
|
||||
}
|
||||
}
|
||||
|
||||
pos := uint32(0)
|
||||
for _, offset := range f.names.sortOffsets() {
|
||||
if diff := int(offset - pos); diff > 0 {
|
||||
rd.Skip(diff)
|
||||
}
|
||||
b, size := readShortUtf8(rd)
|
||||
f.names.setName(offset, string(b))
|
||||
pos += uint32(1 + size)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *ParameterFieldSet) inputFields() []*ParameterField {
|
||||
return f._inputFields
|
||||
}
|
||||
|
||||
func (f *ParameterFieldSet) outputFields() []*ParameterField {
|
||||
return f._outputFields
|
||||
}
|
||||
|
||||
// NumInputField returns the number of input fields in a database statement.
|
||||
func (f *ParameterFieldSet) NumInputField() int {
|
||||
return len(f._inputFields)
|
||||
}
|
||||
|
||||
// NumOutputField returns the number of output fields of a query or stored procedure.
|
||||
func (f *ParameterFieldSet) NumOutputField() int {
|
||||
return len(f._outputFields)
|
||||
}
|
||||
|
||||
// Field returns the field at index idx.
|
||||
func (f *ParameterFieldSet) Field(idx int) *ParameterField {
|
||||
return f.fields[idx]
|
||||
}
|
||||
|
||||
// OutputField returns the output field at index idx.
|
||||
func (f *ParameterFieldSet) OutputField(idx int) *ParameterField {
|
||||
return f._outputFields[idx]
|
||||
}
|
||||
|
||||
// ParameterField contains database field attributes for parameters.
|
||||
type ParameterField struct {
|
||||
fieldNames fieldNames
|
||||
parameterOptions parameterOptions
|
||||
tc TypeCode
|
||||
mode parameterMode
|
||||
fraction int16
|
||||
length int16
|
||||
offset uint32
|
||||
chunkReader lobChunkReader
|
||||
lobLocatorID locatorID
|
||||
}
|
||||
|
||||
func newParameterField(fieldNames fieldNames) *ParameterField {
|
||||
return &ParameterField{fieldNames: fieldNames}
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (f *ParameterField) String() string {
|
||||
return fmt.Sprintf("parameterOptions %s typeCode %s mode %s fraction %d length %d name %s",
|
||||
f.parameterOptions,
|
||||
f.tc,
|
||||
f.mode,
|
||||
f.fraction,
|
||||
f.length,
|
||||
f.Name(),
|
||||
)
|
||||
}
|
||||
|
||||
// TypeCode returns the type code of the field.
|
||||
func (f *ParameterField) TypeCode() TypeCode {
|
||||
return f.tc
|
||||
}
|
||||
|
||||
// TypeLength returns the type length of the field.
|
||||
// see https://golang.org/pkg/database/sql/driver/#RowsColumnTypeLength
|
||||
func (f *ParameterField) TypeLength() (int64, bool) {
|
||||
if f.tc.isVariableLength() {
|
||||
return int64(f.length), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// TypePrecisionScale returns the type precision and scale (decimal types) of the field.
|
||||
// see https://golang.org/pkg/database/sql/driver/#RowsColumnTypePrecisionScale
|
||||
func (f *ParameterField) TypePrecisionScale() (int64, int64, bool) {
|
||||
if f.tc.isDecimalType() {
|
||||
return int64(f.length), int64(f.fraction), true
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// Nullable returns true if the field may be null, false otherwise.
|
||||
// see https://golang.org/pkg/database/sql/driver/#RowsColumnTypeNullable
|
||||
func (f *ParameterField) Nullable() bool {
|
||||
return f.parameterOptions == poOptional
|
||||
}
|
||||
|
||||
// In returns true if the parameter field is an input field.
|
||||
func (f *ParameterField) In() bool {
|
||||
return f.mode == pmInout || f.mode == pmIn
|
||||
}
|
||||
|
||||
// Out returns true if the parameter field is an output field.
|
||||
func (f *ParameterField) Out() bool {
|
||||
return f.mode == pmInout || f.mode == pmOut
|
||||
}
|
||||
|
||||
// Name returns the parameter field name.
|
||||
func (f *ParameterField) Name() string {
|
||||
return f.fieldNames.name(f.offset)
|
||||
}
|
||||
|
||||
// SetLobReader sets the io.Reader if a Lob parameter field.
|
||||
func (f *ParameterField) SetLobReader(rd io.Reader) error {
|
||||
f.chunkReader = newLobChunkReader(f.TypeCode().isCharBased(), rd)
|
||||
return nil
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
func (f *ParameterField) read(rd *bufio.Reader) {
|
||||
f.parameterOptions = parameterOptions(rd.ReadInt8())
|
||||
f.tc = TypeCode(rd.ReadInt8())
|
||||
f.mode = parameterMode(rd.ReadInt8())
|
||||
rd.Skip(1) //filler
|
||||
f.offset = rd.ReadUint32()
|
||||
f.fieldNames.addOffset(f.offset)
|
||||
f.length = rd.ReadInt16()
|
||||
f.fraction = rd.ReadInt16()
|
||||
rd.Skip(4) //filler
|
||||
}
|
||||
|
||||
// parameter metadata
|
||||
type parameterMetadata struct {
|
||||
prmFieldSet *ParameterFieldSet
|
||||
numArg int
|
||||
}
|
||||
|
||||
func (m *parameterMetadata) String() string {
|
||||
return fmt.Sprintf("parameter metadata: %s", m.prmFieldSet.fields)
|
||||
}
|
||||
|
||||
func (m *parameterMetadata) kind() partKind {
|
||||
return pkParameterMetadata
|
||||
}
|
||||
|
||||
func (m *parameterMetadata) setNumArg(numArg int) {
|
||||
m.numArg = numArg
|
||||
}
|
||||
|
||||
func (m *parameterMetadata) read(rd *bufio.Reader) error {
|
||||
|
||||
m.prmFieldSet.read(rd)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("read %s", m)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
// input parameters
|
||||
type inputParameters struct {
|
||||
inputFields []*ParameterField
|
||||
args []driver.NamedValue
|
||||
}
|
||||
|
||||
func newInputParameters(inputFields []*ParameterField, args []driver.NamedValue) *inputParameters {
|
||||
return &inputParameters{inputFields: inputFields, args: args}
|
||||
}
|
||||
|
||||
func (p *inputParameters) String() string {
|
||||
return fmt.Sprintf("input parameters: %v", p.args)
|
||||
}
|
||||
|
||||
func (p *inputParameters) kind() partKind {
|
||||
return pkParameters
|
||||
}
|
||||
|
||||
func (p *inputParameters) size() (int, error) {
|
||||
|
||||
size := len(p.args)
|
||||
cnt := len(p.inputFields)
|
||||
|
||||
for i, arg := range p.args {
|
||||
|
||||
if arg.Value == nil { // null value
|
||||
continue
|
||||
}
|
||||
|
||||
// mass insert
|
||||
field := p.inputFields[i%cnt]
|
||||
|
||||
fieldSize, err := fieldSize(field.TypeCode(), arg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
size += fieldSize
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (p *inputParameters) numArg() int {
|
||||
cnt := len(p.inputFields)
|
||||
|
||||
if cnt == 0 { // avoid divide-by-zero (e.g. prepare without parameters)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(p.args) / cnt
|
||||
}
|
||||
|
||||
func (p *inputParameters) write(wr *bufio.Writer) error {
|
||||
|
||||
cnt := len(p.inputFields)
|
||||
|
||||
for i, arg := range p.args {
|
||||
|
||||
//mass insert
|
||||
field := p.inputFields[i%cnt]
|
||||
|
||||
if err := writeField(wr, field.TypeCode(), arg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("input parameters: %s", p)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// output parameter
|
||||
type outputParameters struct {
|
||||
numArg int
|
||||
s *Session
|
||||
outputFields []*ParameterField
|
||||
fieldValues *FieldValues
|
||||
}
|
||||
|
||||
func (p *outputParameters) String() string {
|
||||
return fmt.Sprintf("output parameters: %v", p.fieldValues)
|
||||
}
|
||||
|
||||
func (p *outputParameters) kind() partKind {
|
||||
return pkOutputParameters
|
||||
}
|
||||
|
||||
func (p *outputParameters) setNumArg(numArg int) {
|
||||
p.numArg = numArg // should always be 1
|
||||
}
|
||||
|
||||
func (p *outputParameters) read(rd *bufio.Reader) error {
|
||||
|
||||
cols := len(p.outputFields)
|
||||
p.fieldValues.resize(p.numArg, cols)
|
||||
|
||||
for i := 0; i < p.numArg; i++ {
|
||||
for j, field := range p.outputFields {
|
||||
var err error
|
||||
if p.fieldValues.values[i*cols+j], err = readField(p.s, rd, field.TypeCode()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("read %s", p)
|
||||
}
|
||||
return rd.GetError()
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
const (
|
||||
partHeaderSize = 16
|
||||
)
|
||||
|
||||
type requestPart interface {
|
||||
kind() partKind
|
||||
size() (int, error)
|
||||
numArg() int
|
||||
write(*bufio.Writer) error
|
||||
}
|
||||
|
||||
type replyPart interface {
|
||||
//kind() partKind
|
||||
setNumArg(int)
|
||||
read(*bufio.Reader) error
|
||||
}
|
||||
|
||||
// PartAttributes is an interface defining methods for reading query resultset parts.
|
||||
type PartAttributes interface {
|
||||
ResultsetClosed() bool
|
||||
LastPacket() bool
|
||||
NoRows() bool
|
||||
}
|
||||
|
||||
type partAttributes int8
|
||||
|
||||
const (
|
||||
paLastPacket partAttributes = 0x01
|
||||
paNextPacket partAttributes = 0x02
|
||||
paFirstPacket partAttributes = 0x04
|
||||
paRowNotFound partAttributes = 0x08
|
||||
paResultsetClosed partAttributes = 0x10
|
||||
)
|
||||
|
||||
var partAttributesText = map[partAttributes]string{
|
||||
paLastPacket: "lastPacket",
|
||||
paNextPacket: "nextPacket",
|
||||
paFirstPacket: "firstPacket",
|
||||
paRowNotFound: "rowNotFound",
|
||||
paResultsetClosed: "resultsetClosed",
|
||||
}
|
||||
|
||||
func (k partAttributes) String() string {
|
||||
t := make([]string, 0, len(partAttributesText))
|
||||
|
||||
for attr, text := range partAttributesText {
|
||||
if (k & attr) != 0 {
|
||||
t = append(t, text)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%v", t)
|
||||
}
|
||||
|
||||
func (k partAttributes) ResultsetClosed() bool {
|
||||
return (k & paResultsetClosed) == paResultsetClosed
|
||||
}
|
||||
|
||||
func (k partAttributes) LastPacket() bool {
|
||||
return (k & paLastPacket) == paLastPacket
|
||||
}
|
||||
|
||||
func (k partAttributes) NoRows() bool {
|
||||
attrs := paLastPacket | paRowNotFound
|
||||
return (k & attrs) == attrs
|
||||
}
|
||||
|
||||
// part header
|
||||
type partHeader struct {
|
||||
partKind partKind
|
||||
partAttributes partAttributes
|
||||
argumentCount int16
|
||||
bigArgumentCount int32
|
||||
bufferLength int32
|
||||
bufferSize int32
|
||||
}
|
||||
|
||||
func (h *partHeader) String() string {
|
||||
return fmt.Sprintf("part kind %s partAttributes %s argumentCount %d bigArgumentCount %d bufferLength %d bufferSize %d",
|
||||
h.partKind,
|
||||
h.partAttributes,
|
||||
h.argumentCount,
|
||||
h.bigArgumentCount,
|
||||
h.bufferLength,
|
||||
h.bufferSize,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *partHeader) write(wr *bufio.Writer) error {
|
||||
wr.WriteInt8(int8(h.partKind))
|
||||
wr.WriteInt8(int8(h.partAttributes))
|
||||
wr.WriteInt16(h.argumentCount)
|
||||
wr.WriteInt32(h.bigArgumentCount)
|
||||
wr.WriteInt32(h.bufferLength)
|
||||
wr.WriteInt32(h.bufferSize)
|
||||
|
||||
//no filler
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("write part header: %s", h)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *partHeader) read(rd *bufio.Reader) error {
|
||||
h.partKind = partKind(rd.ReadInt8())
|
||||
h.partAttributes = partAttributes(rd.ReadInt8())
|
||||
h.argumentCount = rd.ReadInt16()
|
||||
h.bigArgumentCount = rd.ReadInt32()
|
||||
h.bufferLength = rd.ReadInt32()
|
||||
h.bufferSize = rd.ReadInt32()
|
||||
|
||||
// no filler
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("read part header: %s", h)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=partKind
|
||||
|
||||
type partKind int8
|
||||
|
||||
const (
|
||||
pkNil partKind = 0
|
||||
pkCommand partKind = 3
|
||||
pkResultset partKind = 5
|
||||
pkError partKind = 6
|
||||
pkStatementID partKind = 10
|
||||
pkTransactionID partKind = 11
|
||||
pkRowsAffected partKind = 12
|
||||
pkResultsetID partKind = 13
|
||||
pkTopologyInformation partKind = 15
|
||||
pkTableLocation partKind = 16
|
||||
pkReadLobRequest partKind = 17
|
||||
pkReadLobReply partKind = 18
|
||||
pkAbapIStream partKind = 25
|
||||
pkAbapOStream partKind = 26
|
||||
pkCommandInfo partKind = 27
|
||||
pkWriteLobRequest partKind = 28
|
||||
pkClientContext partKind = 29
|
||||
pkWriteLobReply partKind = 30
|
||||
pkParameters partKind = 32
|
||||
pkAuthentication partKind = 33
|
||||
pkSessionContext partKind = 34
|
||||
pkClientID partKind = 35
|
||||
pkProfile partKind = 38
|
||||
pkStatementContext partKind = 39
|
||||
pkPartitionInformation partKind = 40
|
||||
pkOutputParameters partKind = 41
|
||||
pkConnectOptions partKind = 42
|
||||
pkCommitOptions partKind = 43
|
||||
pkFetchOptions partKind = 44
|
||||
pkFetchSize partKind = 45
|
||||
pkParameterMetadata partKind = 47
|
||||
pkResultMetadata partKind = 48
|
||||
pkFindLobRequest partKind = 49
|
||||
pkFindLobReply partKind = 50
|
||||
pkItabSHM partKind = 51
|
||||
pkItabChunkMetadata partKind = 53
|
||||
pkItabMetadata partKind = 55
|
||||
pkItabResultChunk partKind = 56
|
||||
pkClientInfo partKind = 57
|
||||
pkStreamData partKind = 58
|
||||
pkOStreamResult partKind = 59
|
||||
pkFDARequestMetadata partKind = 60
|
||||
pkFDAReplyMetadata partKind = 61
|
||||
pkBatchPrepare partKind = 62 //Reserved: do not use
|
||||
pkBatchExecute partKind = 63 //Reserved: do not use
|
||||
pkTransactionFlags partKind = 64
|
||||
pkRowSlotImageParamMetadata partKind = 65 //Reserved: do not use
|
||||
pkRowSlotImageResultset partKind = 66 //Reserved: do not use
|
||||
pkDBConnectInfo partKind = 67
|
||||
pkLobFlags partKind = 68
|
||||
pkResultsetOptions partKind = 69
|
||||
pkXATransactionInfo partKind = 70
|
||||
pkSessionVariable partKind = 71
|
||||
pkWorkLoadReplayContext partKind = 72
|
||||
pkSQLReplyOptions partKind = 73
|
||||
)
|
|
@ -0,0 +1,72 @@
|
|||
// Code generated by "stringer -type=partKind"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _partKind_name = "pkNilpkCommandpkResultsetpkErrorpkStatementIDpkTransactionIDpkRowsAffectedpkResultsetIDpkTopologyInformationpkTableLocationpkReadLobRequestpkReadLobReplypkAbapIStreampkAbapOStreampkCommandInfopkWriteLobRequestpkClientContextpkWriteLobReplypkParameterspkAuthenticationpkSessionContextpkClientIDpkProfilepkStatementContextpkPartitionInformationpkOutputParameterspkConnectOptionspkCommitOptionspkFetchOptionspkFetchSizepkParameterMetadatapkResultMetadatapkFindLobRequestpkFindLobReplypkItabSHMpkItabChunkMetadatapkItabMetadatapkItabResultChunkpkClientInfopkStreamDatapkOStreamResultpkFDARequestMetadatapkFDAReplyMetadatapkBatchPreparepkBatchExecutepkTransactionFlagspkRowSlotImageParamMetadatapkRowSlotImageResultsetpkDBConnectInfopkLobFlagspkResultsetOptionspkXATransactionInfopkSessionVariablepkWorkLoadReplayContextpkSQLReplyOptions"
|
||||
|
||||
var _partKind_map = map[partKind]string{
|
||||
0: _partKind_name[0:5],
|
||||
3: _partKind_name[5:14],
|
||||
5: _partKind_name[14:25],
|
||||
6: _partKind_name[25:32],
|
||||
10: _partKind_name[32:45],
|
||||
11: _partKind_name[45:60],
|
||||
12: _partKind_name[60:74],
|
||||
13: _partKind_name[74:87],
|
||||
15: _partKind_name[87:108],
|
||||
16: _partKind_name[108:123],
|
||||
17: _partKind_name[123:139],
|
||||
18: _partKind_name[139:153],
|
||||
25: _partKind_name[153:166],
|
||||
26: _partKind_name[166:179],
|
||||
27: _partKind_name[179:192],
|
||||
28: _partKind_name[192:209],
|
||||
29: _partKind_name[209:224],
|
||||
30: _partKind_name[224:239],
|
||||
32: _partKind_name[239:251],
|
||||
33: _partKind_name[251:267],
|
||||
34: _partKind_name[267:283],
|
||||
35: _partKind_name[283:293],
|
||||
38: _partKind_name[293:302],
|
||||
39: _partKind_name[302:320],
|
||||
40: _partKind_name[320:342],
|
||||
41: _partKind_name[342:360],
|
||||
42: _partKind_name[360:376],
|
||||
43: _partKind_name[376:391],
|
||||
44: _partKind_name[391:405],
|
||||
45: _partKind_name[405:416],
|
||||
47: _partKind_name[416:435],
|
||||
48: _partKind_name[435:451],
|
||||
49: _partKind_name[451:467],
|
||||
50: _partKind_name[467:481],
|
||||
51: _partKind_name[481:490],
|
||||
53: _partKind_name[490:509],
|
||||
55: _partKind_name[509:523],
|
||||
56: _partKind_name[523:540],
|
||||
57: _partKind_name[540:552],
|
||||
58: _partKind_name[552:564],
|
||||
59: _partKind_name[564:579],
|
||||
60: _partKind_name[579:599],
|
||||
61: _partKind_name[599:617],
|
||||
62: _partKind_name[617:631],
|
||||
63: _partKind_name[631:645],
|
||||
64: _partKind_name[645:663],
|
||||
65: _partKind_name[663:690],
|
||||
66: _partKind_name[690:713],
|
||||
67: _partKind_name[713:728],
|
||||
68: _partKind_name[728:738],
|
||||
69: _partKind_name[738:756],
|
||||
70: _partKind_name[756:775],
|
||||
71: _partKind_name[775:792],
|
||||
72: _partKind_name[792:815],
|
||||
73: _partKind_name[815:832],
|
||||
}
|
||||
|
||||
func (i partKind) String() string {
|
||||
if str, ok := _partKind_map[i]; ok {
|
||||
return str
|
||||
}
|
||||
return "partKind(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=QueryType
|
||||
|
||||
// QueryType is the type definition for query types supported by this package.
|
||||
type QueryType byte
|
||||
|
||||
// Query type constants.
|
||||
const (
|
||||
QtNone QueryType = iota
|
||||
QtSelect
|
||||
QtProcedureCall
|
||||
)
|
16
vendor/github.com/SAP/go-hdb/internal/protocol/querytype_string.go
generated
vendored
Normal file
16
vendor/github.com/SAP/go-hdb/internal/protocol/querytype_string.go
generated
vendored
Normal file
|
@ -0,0 +1,16 @@
|
|||
// Code generated by "stringer -type=QueryType"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _QueryType_name = "QtNoneQtSelectQtProcedureCall"
|
||||
|
||||
var _QueryType_index = [...]uint8{0, 6, 14, 29}
|
||||
|
||||
func (i QueryType) String() string {
|
||||
if i >= QueryType(len(_QueryType_index)-1) {
|
||||
return "QueryType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _QueryType_name[_QueryType_index[i]:_QueryType_index[i+1]]
|
||||
}
|
|
@ -0,0 +1,294 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
const (
|
||||
resultsetIDSize = 8
|
||||
)
|
||||
|
||||
type columnOptions int8
|
||||
|
||||
const (
|
||||
coMandatory columnOptions = 0x01
|
||||
coOptional columnOptions = 0x02
|
||||
)
|
||||
|
||||
var columnOptionsText = map[columnOptions]string{
|
||||
coMandatory: "mandatory",
|
||||
coOptional: "optional",
|
||||
}
|
||||
|
||||
func (k columnOptions) String() string {
|
||||
t := make([]string, 0, len(columnOptionsText))
|
||||
|
||||
for option, text := range columnOptionsText {
|
||||
if (k & option) != 0 {
|
||||
t = append(t, text)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%v", t)
|
||||
}
|
||||
|
||||
//resultset id
|
||||
type resultsetID struct {
|
||||
id *uint64
|
||||
}
|
||||
|
||||
func (id *resultsetID) kind() partKind {
|
||||
return pkResultsetID
|
||||
}
|
||||
|
||||
func (id *resultsetID) size() (int, error) {
|
||||
return resultsetIDSize, nil
|
||||
}
|
||||
|
||||
func (id *resultsetID) numArg() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (id *resultsetID) setNumArg(int) {
|
||||
//ignore - always 1
|
||||
}
|
||||
|
||||
func (id *resultsetID) read(rd *bufio.Reader) error {
|
||||
_id := rd.ReadUint64()
|
||||
*id.id = _id
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("resultset id: %d", *id.id)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
func (id *resultsetID) write(wr *bufio.Writer) error {
|
||||
wr.WriteUint64(*id.id)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("resultset id: %d", *id.id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResultFieldSet contains database field metadata for result fields.
|
||||
type ResultFieldSet struct {
|
||||
fields []*ResultField
|
||||
names fieldNames
|
||||
}
|
||||
|
||||
func newResultFieldSet(size int) *ResultFieldSet {
|
||||
return &ResultFieldSet{
|
||||
fields: make([]*ResultField, size),
|
||||
names: newFieldNames(),
|
||||
}
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (f *ResultFieldSet) String() string {
|
||||
a := make([]string, len(f.fields))
|
||||
for i, f := range f.fields {
|
||||
a[i] = f.String()
|
||||
}
|
||||
return fmt.Sprintf("%v", a)
|
||||
}
|
||||
|
||||
func (f *ResultFieldSet) read(rd *bufio.Reader) {
|
||||
for i := 0; i < len(f.fields); i++ {
|
||||
field := newResultField(f.names)
|
||||
field.read(rd)
|
||||
f.fields[i] = field
|
||||
}
|
||||
|
||||
pos := uint32(0)
|
||||
for _, offset := range f.names.sortOffsets() {
|
||||
if diff := int(offset - pos); diff > 0 {
|
||||
rd.Skip(diff)
|
||||
}
|
||||
b, size := readShortUtf8(rd)
|
||||
f.names.setName(offset, string(b))
|
||||
pos += uint32(1 + size)
|
||||
}
|
||||
}
|
||||
|
||||
// NumField returns the number of fields of a query.
|
||||
func (f *ResultFieldSet) NumField() int {
|
||||
return len(f.fields)
|
||||
}
|
||||
|
||||
// Field returns the field at index idx.
|
||||
func (f *ResultFieldSet) Field(idx int) *ResultField {
|
||||
return f.fields[idx]
|
||||
}
|
||||
|
||||
const (
|
||||
tableName = iota
|
||||
schemaName
|
||||
columnName
|
||||
columnDisplayName
|
||||
maxNames
|
||||
)
|
||||
|
||||
// ResultField contains database field attributes for result fields.
|
||||
type ResultField struct {
|
||||
fieldNames fieldNames
|
||||
columnOptions columnOptions
|
||||
tc TypeCode
|
||||
fraction int16
|
||||
length int16
|
||||
offsets [maxNames]uint32
|
||||
}
|
||||
|
||||
func newResultField(fieldNames fieldNames) *ResultField {
|
||||
return &ResultField{fieldNames: fieldNames}
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (f *ResultField) String() string {
|
||||
return fmt.Sprintf("columnsOptions %s typeCode %s fraction %d length %d tablename %s schemaname %s columnname %s columnDisplayname %s",
|
||||
f.columnOptions,
|
||||
f.tc,
|
||||
f.fraction,
|
||||
f.length,
|
||||
f.fieldNames.name(f.offsets[tableName]),
|
||||
f.fieldNames.name(f.offsets[schemaName]),
|
||||
f.fieldNames.name(f.offsets[columnName]),
|
||||
f.fieldNames.name(f.offsets[columnDisplayName]),
|
||||
)
|
||||
}
|
||||
|
||||
// TypeCode returns the type code of the field.
|
||||
func (f *ResultField) TypeCode() TypeCode {
|
||||
return f.tc
|
||||
}
|
||||
|
||||
// TypeLength returns the type length of the field.
|
||||
// see https://golang.org/pkg/database/sql/driver/#RowsColumnTypeLength
|
||||
func (f *ResultField) TypeLength() (int64, bool) {
|
||||
if f.tc.isVariableLength() {
|
||||
return int64(f.length), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// TypePrecisionScale returns the type precision and scale (decimal types) of the field.
|
||||
// see https://golang.org/pkg/database/sql/driver/#RowsColumnTypePrecisionScale
|
||||
func (f *ResultField) TypePrecisionScale() (int64, int64, bool) {
|
||||
if f.tc.isDecimalType() {
|
||||
return int64(f.length), int64(f.fraction), true
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// Nullable returns true if the field may be null, false otherwise.
|
||||
// see https://golang.org/pkg/database/sql/driver/#RowsColumnTypeNullable
|
||||
func (f *ResultField) Nullable() bool {
|
||||
return f.columnOptions == coOptional
|
||||
}
|
||||
|
||||
// Name returns the result field name.
|
||||
func (f *ResultField) Name() string {
|
||||
return f.fieldNames.name(f.offsets[columnDisplayName])
|
||||
}
|
||||
|
||||
func (f *ResultField) read(rd *bufio.Reader) {
|
||||
f.columnOptions = columnOptions(rd.ReadInt8())
|
||||
f.tc = TypeCode(rd.ReadInt8())
|
||||
f.fraction = rd.ReadInt16()
|
||||
f.length = rd.ReadInt16()
|
||||
rd.Skip(2) //filler
|
||||
for i := 0; i < maxNames; i++ {
|
||||
offset := rd.ReadUint32()
|
||||
f.offsets[i] = offset
|
||||
f.fieldNames.addOffset(offset)
|
||||
}
|
||||
}
|
||||
|
||||
//resultset metadata
|
||||
type resultMetadata struct {
|
||||
resultFieldSet *ResultFieldSet
|
||||
numArg int
|
||||
}
|
||||
|
||||
func (r *resultMetadata) String() string {
|
||||
return fmt.Sprintf("result metadata: %s", r.resultFieldSet.fields)
|
||||
}
|
||||
|
||||
func (r *resultMetadata) kind() partKind {
|
||||
return pkResultMetadata
|
||||
}
|
||||
|
||||
func (r *resultMetadata) setNumArg(numArg int) {
|
||||
r.numArg = numArg
|
||||
}
|
||||
|
||||
func (r *resultMetadata) read(rd *bufio.Reader) error {
|
||||
|
||||
r.resultFieldSet.read(rd)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("read %s", r)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
//resultset
|
||||
type resultset struct {
|
||||
numArg int
|
||||
s *Session
|
||||
resultFieldSet *ResultFieldSet
|
||||
fieldValues *FieldValues
|
||||
}
|
||||
|
||||
func (r *resultset) String() string {
|
||||
return fmt.Sprintf("resultset: %s", r.fieldValues)
|
||||
}
|
||||
|
||||
func (r *resultset) kind() partKind {
|
||||
return pkResultset
|
||||
}
|
||||
|
||||
func (r *resultset) setNumArg(numArg int) {
|
||||
r.numArg = numArg
|
||||
}
|
||||
|
||||
func (r *resultset) read(rd *bufio.Reader) error {
|
||||
|
||||
cols := len(r.resultFieldSet.fields)
|
||||
r.fieldValues.resize(r.numArg, cols)
|
||||
|
||||
for i := 0; i < r.numArg; i++ {
|
||||
for j, field := range r.resultFieldSet.fields {
|
||||
var err error
|
||||
if r.fieldValues.values[i*cols+j], err = readField(r.s, rd, field.TypeCode()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("read %s", r)
|
||||
}
|
||||
return rd.GetError()
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
//rows affected
|
||||
const (
|
||||
raSuccessNoInfo = -2
|
||||
raExecutionFailed = -3
|
||||
)
|
||||
|
||||
//rows affected
|
||||
type rowsAffected struct {
|
||||
rows []int32
|
||||
_numArg int
|
||||
}
|
||||
|
||||
func (r *rowsAffected) kind() partKind {
|
||||
return pkRowsAffected
|
||||
}
|
||||
|
||||
func (r *rowsAffected) setNumArg(numArg int) {
|
||||
r._numArg = numArg
|
||||
}
|
||||
|
||||
func (r *rowsAffected) read(rd *bufio.Reader) error {
|
||||
if r.rows == nil || r._numArg > cap(r.rows) {
|
||||
r.rows = make([]int32, r._numArg)
|
||||
} else {
|
||||
r.rows = r.rows[:r._numArg]
|
||||
}
|
||||
|
||||
for i := 0; i < r._numArg; i++ {
|
||||
r.rows[i] = rd.ReadInt32()
|
||||
if trace {
|
||||
outLogger.Printf("rows affected %d: %d", i, r.rows[i])
|
||||
}
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
func (r *rowsAffected) total() int64 {
|
||||
if r.rows == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
total := int64(0)
|
||||
for _, rows := range r.rows {
|
||||
if rows > 0 {
|
||||
total += int64(rows)
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
|
@ -0,0 +1,265 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//Salted Challenge Response Authentication Mechanism (SCRAM)
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
const (
|
||||
clientChallengeSize = 64
|
||||
serverChallengeDataSize = 68
|
||||
clientProofDataSize = 35
|
||||
clientProofSize = 32
|
||||
)
|
||||
|
||||
type scramsha256InitialRequest struct {
|
||||
username []byte
|
||||
clientChallenge []byte
|
||||
}
|
||||
|
||||
func (r *scramsha256InitialRequest) kind() partKind {
|
||||
return pkAuthentication
|
||||
}
|
||||
|
||||
func (r *scramsha256InitialRequest) size() (int, error) {
|
||||
return 2 + authFieldSize(r.username) + authFieldSize([]byte(mnSCRAMSHA256)) + authFieldSize(r.clientChallenge), nil
|
||||
}
|
||||
|
||||
func (r *scramsha256InitialRequest) numArg() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (r *scramsha256InitialRequest) write(wr *bufio.Writer) error {
|
||||
wr.WriteInt16(3)
|
||||
writeAuthField(wr, r.username)
|
||||
writeAuthField(wr, []byte(mnSCRAMSHA256))
|
||||
writeAuthField(wr, r.clientChallenge)
|
||||
return nil
|
||||
}
|
||||
|
||||
type scramsha256InitialReply struct {
|
||||
salt []byte
|
||||
serverChallenge []byte
|
||||
}
|
||||
|
||||
func (r *scramsha256InitialReply) kind() partKind {
|
||||
return pkAuthentication
|
||||
}
|
||||
|
||||
func (r *scramsha256InitialReply) setNumArg(int) {
|
||||
//not needed
|
||||
}
|
||||
|
||||
func (r *scramsha256InitialReply) read(rd *bufio.Reader) error {
|
||||
cnt := rd.ReadInt16()
|
||||
if err := readMethodName(rd); err != nil {
|
||||
return err
|
||||
}
|
||||
size := rd.ReadB()
|
||||
if size != serverChallengeDataSize {
|
||||
return fmt.Errorf("invalid server challenge data size %d - %d expected", size, serverChallengeDataSize)
|
||||
}
|
||||
|
||||
//server challenge data
|
||||
|
||||
cnt = rd.ReadInt16()
|
||||
if cnt != 2 {
|
||||
return fmt.Errorf("invalid server challenge data field count %d - %d expected", cnt, 2)
|
||||
}
|
||||
|
||||
size = rd.ReadB()
|
||||
if trace {
|
||||
outLogger.Printf("salt size %d", size)
|
||||
}
|
||||
|
||||
r.salt = make([]byte, size)
|
||||
rd.ReadFull(r.salt)
|
||||
if trace {
|
||||
outLogger.Printf("salt %v", r.salt)
|
||||
}
|
||||
|
||||
size = rd.ReadB()
|
||||
r.serverChallenge = make([]byte, size)
|
||||
rd.ReadFull(r.serverChallenge)
|
||||
if trace {
|
||||
outLogger.Printf("server challenge %v", r.serverChallenge)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
type scramsha256FinalRequest struct {
|
||||
username []byte
|
||||
clientProof []byte
|
||||
}
|
||||
|
||||
func newScramsha256FinalRequest() *scramsha256FinalRequest {
|
||||
return &scramsha256FinalRequest{}
|
||||
}
|
||||
|
||||
func (r *scramsha256FinalRequest) kind() partKind {
|
||||
return pkAuthentication
|
||||
}
|
||||
|
||||
func (r *scramsha256FinalRequest) size() (int, error) {
|
||||
return 2 + authFieldSize(r.username) + authFieldSize([]byte(mnSCRAMSHA256)) + authFieldSize(r.clientProof), nil
|
||||
}
|
||||
|
||||
func (r *scramsha256FinalRequest) numArg() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (r *scramsha256FinalRequest) write(wr *bufio.Writer) error {
|
||||
wr.WriteInt16(3)
|
||||
writeAuthField(wr, r.username)
|
||||
writeAuthField(wr, []byte(mnSCRAMSHA256))
|
||||
writeAuthField(wr, r.clientProof)
|
||||
return nil
|
||||
}
|
||||
|
||||
type scramsha256FinalReply struct {
|
||||
serverProof []byte
|
||||
}
|
||||
|
||||
func newScramsha256FinalReply() *scramsha256FinalReply {
|
||||
return &scramsha256FinalReply{}
|
||||
}
|
||||
|
||||
func (r *scramsha256FinalReply) kind() partKind {
|
||||
return pkAuthentication
|
||||
}
|
||||
|
||||
func (r *scramsha256FinalReply) setNumArg(int) {
|
||||
//not needed
|
||||
}
|
||||
|
||||
func (r *scramsha256FinalReply) read(rd *bufio.Reader) error {
|
||||
cnt := rd.ReadInt16()
|
||||
if cnt != 2 {
|
||||
return fmt.Errorf("invalid final reply field count %d - %d expected", cnt, 2)
|
||||
}
|
||||
if err := readMethodName(rd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//serverProof
|
||||
size := rd.ReadB()
|
||||
|
||||
serverProof := make([]byte, size)
|
||||
rd.ReadFull(serverProof)
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
//helper
|
||||
func authFieldSize(f []byte) int {
|
||||
size := len(f)
|
||||
if size >= 250 {
|
||||
// - different indicators compared to db field handling
|
||||
// - 1-5 bytes? but only 1 resp 3 bytes explained
|
||||
panic("not implemented error")
|
||||
}
|
||||
return size + 1 //length indicator size := 1
|
||||
}
|
||||
|
||||
func writeAuthField(wr *bufio.Writer, f []byte) {
|
||||
size := len(f)
|
||||
if size >= 250 {
|
||||
// - different indicators compared to db field handling
|
||||
// - 1-5 bytes? but only 1 resp 3 bytes explained
|
||||
panic("not implemented error")
|
||||
}
|
||||
|
||||
wr.WriteB(byte(size))
|
||||
wr.Write(f)
|
||||
}
|
||||
|
||||
func readMethodName(rd *bufio.Reader) error {
|
||||
size := rd.ReadB()
|
||||
methodName := make([]byte, size)
|
||||
rd.ReadFull(methodName)
|
||||
if string(methodName) != mnSCRAMSHA256 {
|
||||
return fmt.Errorf("invalid authentication method %s - %s expected", methodName, mnSCRAMSHA256)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func clientChallenge() []byte {
|
||||
r := make([]byte, clientChallengeSize)
|
||||
if _, err := rand.Read(r); err != nil {
|
||||
outLogger.Fatal("client challenge fatal error")
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func clientProof(salt, serverChallenge, clientChallenge, password []byte) []byte {
|
||||
|
||||
clientProof := make([]byte, clientProofDataSize)
|
||||
|
||||
buf := make([]byte, 0, len(salt)+len(serverChallenge)+len(clientChallenge))
|
||||
buf = append(buf, salt...)
|
||||
buf = append(buf, serverChallenge...)
|
||||
buf = append(buf, clientChallenge...)
|
||||
|
||||
key := _sha256(_hmac(password, salt))
|
||||
sig := _hmac(_sha256(key), buf)
|
||||
|
||||
proof := xor(sig, key)
|
||||
//actual implementation: only one salt value?
|
||||
clientProof[0] = 0
|
||||
clientProof[1] = 1
|
||||
clientProof[2] = clientProofSize
|
||||
copy(clientProof[3:], proof)
|
||||
return clientProof
|
||||
}
|
||||
|
||||
func _sha256(p []byte) []byte {
|
||||
hash := sha256.New()
|
||||
hash.Write(p)
|
||||
s := hash.Sum(nil)
|
||||
if trace {
|
||||
outLogger.Printf("sha length %d value %v", len(s), s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func _hmac(key, p []byte) []byte {
|
||||
hash := hmac.New(sha256.New, key)
|
||||
hash.Write(p)
|
||||
s := hash.Sum(nil)
|
||||
if trace {
|
||||
outLogger.Printf("hmac length %d value %v", len(s), s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func xor(sig, key []byte) []byte {
|
||||
r := make([]byte, len(sig))
|
||||
|
||||
for i, v := range sig {
|
||||
r[i] = v ^ key[i]
|
||||
}
|
||||
return r
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
const (
|
||||
segmentHeaderSize = 24
|
||||
)
|
||||
|
||||
type commandOptions int8
|
||||
|
||||
const (
|
||||
coNil commandOptions = 0x00
|
||||
coSelfetchOff commandOptions = 0x01
|
||||
coScrollableCursorOn commandOptions = 0x02
|
||||
coNoResultsetCloseNeeded commandOptions = 0x04
|
||||
coHoldCursorOverCommtit commandOptions = 0x08
|
||||
coExecuteLocally commandOptions = 0x10
|
||||
)
|
||||
|
||||
var commandOptionsText = map[commandOptions]string{
|
||||
coSelfetchOff: "selfetchOff",
|
||||
coScrollableCursorOn: "scrollabeCursorOn",
|
||||
coNoResultsetCloseNeeded: "noResltsetCloseNeeded",
|
||||
coHoldCursorOverCommtit: "holdCursorOverCommit",
|
||||
coExecuteLocally: "executLocally",
|
||||
}
|
||||
|
||||
func (k commandOptions) String() string {
|
||||
t := make([]string, 0, len(commandOptionsText))
|
||||
|
||||
for option, text := range commandOptionsText {
|
||||
if (k & option) != 0 {
|
||||
t = append(t, text)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%v", t)
|
||||
}
|
||||
|
||||
//segment header
|
||||
type segmentHeader struct {
|
||||
segmentLength int32
|
||||
segmentOfs int32
|
||||
noOfParts int16
|
||||
segmentNo int16
|
||||
segmentKind segmentKind
|
||||
messageType messageType
|
||||
commit bool
|
||||
commandOptions commandOptions
|
||||
functionCode functionCode
|
||||
}
|
||||
|
||||
func (h *segmentHeader) String() string {
|
||||
switch h.segmentKind {
|
||||
|
||||
default: //error
|
||||
return fmt.Sprintf(
|
||||
"segment length %d segment ofs %d noOfParts %d, segmentNo %d segmentKind %s",
|
||||
h.segmentLength,
|
||||
h.segmentOfs,
|
||||
h.noOfParts,
|
||||
h.segmentNo,
|
||||
h.segmentKind,
|
||||
)
|
||||
case skRequest:
|
||||
return fmt.Sprintf(
|
||||
"segment length %d segment ofs %d noOfParts %d, segmentNo %d segmentKind %s messageType %s commit %t commandOptions %s",
|
||||
h.segmentLength,
|
||||
h.segmentOfs,
|
||||
h.noOfParts,
|
||||
h.segmentNo,
|
||||
h.segmentKind,
|
||||
h.messageType,
|
||||
h.commit,
|
||||
h.commandOptions,
|
||||
)
|
||||
case skReply:
|
||||
return fmt.Sprintf(
|
||||
"segment length %d segment ofs %d noOfParts %d, segmentNo %d segmentKind %s functionCode %s",
|
||||
h.segmentLength,
|
||||
h.segmentOfs,
|
||||
h.noOfParts,
|
||||
h.segmentNo,
|
||||
h.segmentKind,
|
||||
h.functionCode,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// request
|
||||
func (h *segmentHeader) write(wr *bufio.Writer) error {
|
||||
wr.WriteInt32(h.segmentLength)
|
||||
wr.WriteInt32(h.segmentOfs)
|
||||
wr.WriteInt16(h.noOfParts)
|
||||
wr.WriteInt16(h.segmentNo)
|
||||
wr.WriteInt8(int8(h.segmentKind))
|
||||
|
||||
switch h.segmentKind {
|
||||
|
||||
default: //error
|
||||
wr.WriteZeroes(11) //segmentHeaderLength
|
||||
|
||||
case skRequest:
|
||||
wr.WriteInt8(int8(h.messageType))
|
||||
wr.WriteBool(h.commit)
|
||||
wr.WriteInt8(int8(h.commandOptions))
|
||||
wr.WriteZeroes(8) //segmentHeaderSize
|
||||
|
||||
case skReply:
|
||||
|
||||
wr.WriteZeroes(1) //reserved
|
||||
wr.WriteInt16(int16(h.functionCode))
|
||||
wr.WriteZeroes(8) //segmentHeaderSize
|
||||
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("write segment header: %s", h)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reply || error
|
||||
func (h *segmentHeader) read(rd *bufio.Reader) error {
|
||||
h.segmentLength = rd.ReadInt32()
|
||||
h.segmentOfs = rd.ReadInt32()
|
||||
h.noOfParts = rd.ReadInt16()
|
||||
h.segmentNo = rd.ReadInt16()
|
||||
h.segmentKind = segmentKind(rd.ReadInt8())
|
||||
|
||||
switch h.segmentKind {
|
||||
|
||||
default: //error
|
||||
rd.Skip(11) //segmentHeaderLength
|
||||
|
||||
case skRequest:
|
||||
h.messageType = messageType(rd.ReadInt8())
|
||||
h.commit = rd.ReadBool()
|
||||
h.commandOptions = commandOptions(rd.ReadInt8())
|
||||
rd.Skip(8) //segmentHeaderLength
|
||||
|
||||
case skReply:
|
||||
rd.Skip(1) //reserved
|
||||
h.functionCode = functionCode(rd.ReadInt16())
|
||||
rd.Skip(8) //segmentHeaderLength
|
||||
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("read segment header: %s", h)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=segmentKind
|
||||
|
||||
type segmentKind int8
|
||||
|
||||
const (
|
||||
skInvalid segmentKind = 0
|
||||
skRequest segmentKind = 1
|
||||
skReply segmentKind = 2
|
||||
skError segmentKind = 5
|
||||
)
|
25
vendor/github.com/SAP/go-hdb/internal/protocol/segmentkind_string.go
generated
vendored
Normal file
25
vendor/github.com/SAP/go-hdb/internal/protocol/segmentkind_string.go
generated
vendored
Normal file
|
@ -0,0 +1,25 @@
|
|||
// Code generated by "stringer -type=segmentKind"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
_segmentKind_name_0 = "skInvalidskRequestskReply"
|
||||
_segmentKind_name_1 = "skError"
|
||||
)
|
||||
|
||||
var (
|
||||
_segmentKind_index_0 = [...]uint8{0, 9, 18, 25}
|
||||
)
|
||||
|
||||
func (i segmentKind) String() string {
|
||||
switch {
|
||||
case 0 <= i && i <= 2:
|
||||
return _segmentKind_name_0[_segmentKind_index_0[i]:_segmentKind_index_0[i+1]]
|
||||
case i == 5:
|
||||
return _segmentKind_name_1
|
||||
default:
|
||||
return "segmentKind(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,975 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"database/sql/driver"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
"github.com/SAP/go-hdb/internal/unicode"
|
||||
"github.com/SAP/go-hdb/internal/unicode/cesu8"
|
||||
|
||||
"github.com/SAP/go-hdb/driver/sqltrace"
|
||||
)
|
||||
|
||||
const (
|
||||
mnSCRAMSHA256 = "SCRAMSHA256"
|
||||
mnGSS = "GSS"
|
||||
mnSAML = "SAML"
|
||||
)
|
||||
|
||||
var trace bool
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&trace, "hdb.protocol.trace", false, "enabling hdb protocol trace")
|
||||
}
|
||||
|
||||
var (
|
||||
outLogger = log.New(os.Stdout, "hdb.protocol ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
errLogger = log.New(os.Stderr, "hdb.protocol ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
)
|
||||
|
||||
//padding
|
||||
const (
|
||||
padding = 8
|
||||
)
|
||||
|
||||
func padBytes(size int) int {
|
||||
if r := size % padding; r != 0 {
|
||||
return padding - r
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// SessionConn wraps the database tcp connection. It sets timeouts and handles driver ErrBadConn behavior.
|
||||
type sessionConn struct {
|
||||
addr string
|
||||
timeout time.Duration
|
||||
conn net.Conn
|
||||
isBad bool // bad connection
|
||||
badError error // error cause for session bad state
|
||||
inTx bool // in transaction
|
||||
}
|
||||
|
||||
func newSessionConn(ctx context.Context, addr string, timeoutSec int, config *tls.Config) (*sessionConn, error) {
|
||||
timeout := time.Duration(timeoutSec) * time.Second
|
||||
dialer := net.Dialer{Timeout: timeout}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// is TLS connection requested?
|
||||
if config != nil {
|
||||
conn = tls.Client(conn, config)
|
||||
}
|
||||
|
||||
return &sessionConn{addr: addr, timeout: timeout, conn: conn}, nil
|
||||
}
|
||||
|
||||
func (c *sessionConn) close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface.
|
||||
func (c *sessionConn) Read(b []byte) (int, error) {
|
||||
//set timeout
|
||||
if err := c.conn.SetReadDeadline(time.Now().Add(c.timeout)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err := c.conn.Read(b)
|
||||
if err != nil {
|
||||
errLogger.Printf("Connection read error local address %s remote address %s: %s", c.conn.LocalAddr(), c.conn.RemoteAddr(), err)
|
||||
c.isBad = true
|
||||
c.badError = err
|
||||
return n, driver.ErrBadConn
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Write implements the io.Writer interface.
|
||||
func (c *sessionConn) Write(b []byte) (int, error) {
|
||||
//set timeout
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err := c.conn.Write(b)
|
||||
if err != nil {
|
||||
errLogger.Printf("Connection write error local address %s remote address %s: %s", c.conn.LocalAddr(), c.conn.RemoteAddr(), err)
|
||||
c.isBad = true
|
||||
c.badError = err
|
||||
return n, driver.ErrBadConn
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type beforeRead func(p replyPart)
|
||||
|
||||
// session parameter
|
||||
type sessionPrm interface {
|
||||
Host() string
|
||||
Username() string
|
||||
Password() string
|
||||
Locale() string
|
||||
FetchSize() int
|
||||
Timeout() int
|
||||
TLSConfig() *tls.Config
|
||||
}
|
||||
|
||||
// Session represents a HDB session.
|
||||
type Session struct {
|
||||
prm sessionPrm
|
||||
|
||||
conn *sessionConn
|
||||
rd *bufio.Reader
|
||||
wr *bufio.Writer
|
||||
|
||||
// reuse header
|
||||
mh *messageHeader
|
||||
sh *segmentHeader
|
||||
ph *partHeader
|
||||
|
||||
//reuse request / reply parts
|
||||
scramsha256InitialRequest *scramsha256InitialRequest
|
||||
scramsha256InitialReply *scramsha256InitialReply
|
||||
scramsha256FinalRequest *scramsha256FinalRequest
|
||||
scramsha256FinalReply *scramsha256FinalReply
|
||||
topologyInformation *topologyInformation
|
||||
connectOptions *connectOptions
|
||||
rowsAffected *rowsAffected
|
||||
statementID *statementID
|
||||
resultMetadata *resultMetadata
|
||||
resultsetID *resultsetID
|
||||
resultset *resultset
|
||||
parameterMetadata *parameterMetadata
|
||||
outputParameters *outputParameters
|
||||
writeLobRequest *writeLobRequest
|
||||
readLobRequest *readLobRequest
|
||||
writeLobReply *writeLobReply
|
||||
readLobReply *readLobReply
|
||||
|
||||
//standard replies
|
||||
stmtCtx *statementContext
|
||||
txFlags *transactionFlags
|
||||
lastError *hdbErrors
|
||||
|
||||
//serialize write request - read reply
|
||||
//supports calling session methods in go routines (driver methods with context cancellation)
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewSession creates a new database session.
|
||||
func NewSession(ctx context.Context, prm sessionPrm) (*Session, error) {
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("%s", prm)
|
||||
}
|
||||
|
||||
conn, err := newSessionConn(ctx, prm.Host(), prm.Timeout(), prm.TLSConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rd := bufio.NewReader(conn)
|
||||
wr := bufio.NewWriter(conn)
|
||||
|
||||
s := &Session{
|
||||
prm: prm,
|
||||
conn: conn,
|
||||
rd: rd,
|
||||
wr: wr,
|
||||
mh: new(messageHeader),
|
||||
sh: new(segmentHeader),
|
||||
ph: new(partHeader),
|
||||
scramsha256InitialRequest: new(scramsha256InitialRequest),
|
||||
scramsha256InitialReply: new(scramsha256InitialReply),
|
||||
scramsha256FinalRequest: new(scramsha256FinalRequest),
|
||||
scramsha256FinalReply: new(scramsha256FinalReply),
|
||||
topologyInformation: newTopologyInformation(),
|
||||
connectOptions: newConnectOptions(),
|
||||
rowsAffected: new(rowsAffected),
|
||||
statementID: new(statementID),
|
||||
resultMetadata: new(resultMetadata),
|
||||
resultsetID: new(resultsetID),
|
||||
resultset: new(resultset),
|
||||
parameterMetadata: new(parameterMetadata),
|
||||
outputParameters: new(outputParameters),
|
||||
writeLobRequest: new(writeLobRequest),
|
||||
readLobRequest: new(readLobRequest),
|
||||
writeLobReply: new(writeLobReply),
|
||||
readLobReply: new(readLobReply),
|
||||
stmtCtx: newStatementContext(),
|
||||
txFlags: newTransactionFlags(),
|
||||
lastError: new(hdbErrors),
|
||||
}
|
||||
|
||||
if err = s.init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Close closes the session.
|
||||
func (s *Session) Close() error {
|
||||
return s.conn.close()
|
||||
}
|
||||
|
||||
func (s *Session) sessionID() int64 {
|
||||
return s.mh.sessionID
|
||||
}
|
||||
|
||||
// InTx indicates, if the session is in transaction mode.
|
||||
func (s *Session) InTx() bool {
|
||||
return s.conn.inTx
|
||||
}
|
||||
|
||||
// SetInTx sets session in transaction mode.
|
||||
func (s *Session) SetInTx(v bool) {
|
||||
s.conn.inTx = v
|
||||
}
|
||||
|
||||
// IsBad indicates, that the session is in bad state.
|
||||
func (s *Session) IsBad() bool {
|
||||
return s.conn.isBad
|
||||
}
|
||||
|
||||
// BadErr returns the error, that caused the bad session state.
|
||||
func (s *Session) BadErr() error {
|
||||
return s.conn.badError
|
||||
}
|
||||
|
||||
func (s *Session) init() error {
|
||||
|
||||
if err := s.initRequest(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: detect authentication method
|
||||
// - actually only basic authetication supported
|
||||
|
||||
authentication := mnSCRAMSHA256
|
||||
|
||||
switch authentication {
|
||||
default:
|
||||
return fmt.Errorf("invalid authentication %s", authentication)
|
||||
|
||||
case mnSCRAMSHA256:
|
||||
if err := s.authenticateScramsha256(); err != nil {
|
||||
return err
|
||||
}
|
||||
case mnGSS:
|
||||
panic("not implemented error")
|
||||
case mnSAML:
|
||||
panic("not implemented error")
|
||||
}
|
||||
|
||||
id := s.sessionID()
|
||||
if id <= 0 {
|
||||
return fmt.Errorf("invalid session id %d", id)
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("sessionId %d", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) authenticateScramsha256() error {
|
||||
tr := unicode.Utf8ToCesu8Transformer
|
||||
tr.Reset()
|
||||
|
||||
username := make([]byte, cesu8.StringSize(s.prm.Username()))
|
||||
if _, _, err := tr.Transform(username, []byte(s.prm.Username()), true); err != nil {
|
||||
return err // should never happen
|
||||
}
|
||||
|
||||
password := make([]byte, cesu8.StringSize(s.prm.Password()))
|
||||
if _, _, err := tr.Transform(password, []byte(s.prm.Password()), true); err != nil {
|
||||
return err //should never happen
|
||||
}
|
||||
|
||||
clientChallenge := clientChallenge()
|
||||
|
||||
//initial request
|
||||
s.scramsha256InitialRequest.username = username
|
||||
s.scramsha256InitialRequest.clientChallenge = clientChallenge
|
||||
|
||||
if err := s.writeRequest(mtAuthenticate, false, s.scramsha256InitialRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.readReply(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//final request
|
||||
s.scramsha256FinalRequest.username = username
|
||||
s.scramsha256FinalRequest.clientProof = clientProof(s.scramsha256InitialReply.salt, s.scramsha256InitialReply.serverChallenge, clientChallenge, password)
|
||||
|
||||
s.scramsha256InitialReply = nil // !!! next time readReply uses FinalReply
|
||||
|
||||
id := newClientID()
|
||||
|
||||
co := newConnectOptions()
|
||||
co.set(coDistributionProtocolVersion, booleanType(false))
|
||||
co.set(coSelectForUpdateSupported, booleanType(false))
|
||||
co.set(coSplitBatchCommands, booleanType(true))
|
||||
// cannot use due to HDB protocol error with secondtime datatype
|
||||
//co.set(coDataFormatVersion2, dfvSPS06)
|
||||
co.set(coDataFormatVersion2, dfvBaseline)
|
||||
co.set(coCompleteArrayExecution, booleanType(true))
|
||||
if s.prm.Locale() != "" {
|
||||
co.set(coClientLocale, stringType(s.prm.Locale()))
|
||||
}
|
||||
co.set(coClientDistributionMode, cdmOff)
|
||||
// setting this option has no effect
|
||||
//co.set(coImplicitLobStreaming, booleanType(true))
|
||||
|
||||
if err := s.writeRequest(mtConnect, false, s.scramsha256FinalRequest, id, co); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.readReply(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryDirect executes a query without query parameters.
|
||||
func (s *Session) QueryDirect(query string) (uint64, *ResultFieldSet, *FieldValues, PartAttributes, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.writeRequest(mtExecuteDirect, false, command(query)); err != nil {
|
||||
return 0, nil, nil, nil, err
|
||||
}
|
||||
|
||||
var id uint64
|
||||
var resultFieldSet *ResultFieldSet
|
||||
fieldValues := newFieldValues()
|
||||
|
||||
f := func(p replyPart) {
|
||||
|
||||
switch p := p.(type) {
|
||||
|
||||
case *resultsetID:
|
||||
p.id = &id
|
||||
case *resultMetadata:
|
||||
resultFieldSet = newResultFieldSet(p.numArg)
|
||||
p.resultFieldSet = resultFieldSet
|
||||
case *resultset:
|
||||
p.s = s
|
||||
p.resultFieldSet = resultFieldSet
|
||||
p.fieldValues = fieldValues
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.readReply(f); err != nil {
|
||||
return 0, nil, nil, nil, err
|
||||
}
|
||||
|
||||
return id, resultFieldSet, fieldValues, s.ph.partAttributes, nil
|
||||
}
|
||||
|
||||
// ExecDirect executes a sql statement without statement parameters.
|
||||
func (s *Session) ExecDirect(query string) (driver.Result, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.writeRequest(mtExecuteDirect, !s.conn.inTx, command(query)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.readReply(nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.sh.functionCode == fcDDL {
|
||||
return driver.ResultNoRows, nil
|
||||
}
|
||||
return driver.RowsAffected(s.rowsAffected.total()), nil
|
||||
}
|
||||
|
||||
// Prepare prepares a sql statement.
|
||||
func (s *Session) Prepare(query string) (QueryType, uint64, *ParameterFieldSet, *ResultFieldSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.writeRequest(mtPrepare, false, command(query)); err != nil {
|
||||
return QtNone, 0, nil, nil, err
|
||||
}
|
||||
|
||||
var id uint64
|
||||
var prmFieldSet *ParameterFieldSet
|
||||
var resultFieldSet *ResultFieldSet
|
||||
|
||||
f := func(p replyPart) {
|
||||
|
||||
switch p := p.(type) {
|
||||
|
||||
case *statementID:
|
||||
p.id = &id
|
||||
case *parameterMetadata:
|
||||
prmFieldSet = newParameterFieldSet(p.numArg)
|
||||
p.prmFieldSet = prmFieldSet
|
||||
case *resultMetadata:
|
||||
resultFieldSet = newResultFieldSet(p.numArg)
|
||||
p.resultFieldSet = resultFieldSet
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.readReply(f); err != nil {
|
||||
return QtNone, 0, nil, nil, err
|
||||
}
|
||||
|
||||
return s.sh.functionCode.queryType(), id, prmFieldSet, resultFieldSet, nil
|
||||
}
|
||||
|
||||
// Exec executes a sql statement.
|
||||
func (s *Session) Exec(id uint64, prmFieldSet *ParameterFieldSet, args []driver.NamedValue) (driver.Result, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.statementID.id = &id
|
||||
if err := s.writeRequest(mtExecute, !s.conn.inTx, s.statementID, newInputParameters(prmFieldSet.inputFields(), args)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.readReply(nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result driver.Result
|
||||
if s.sh.functionCode == fcDDL {
|
||||
result = driver.ResultNoRows
|
||||
} else {
|
||||
result = driver.RowsAffected(s.rowsAffected.total())
|
||||
}
|
||||
|
||||
if err := s.writeLobStream(prmFieldSet, nil, args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DropStatementID releases the hdb statement handle.
|
||||
func (s *Session) DropStatementID(id uint64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.statementID.id = &id
|
||||
if err := s.writeRequest(mtDropStatementID, false, s.statementID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.readReply(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call executes a stored procedure.
|
||||
func (s *Session) Call(id uint64, prmFieldSet *ParameterFieldSet, args []driver.NamedValue) (*FieldValues, []*TableResult, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.statementID.id = &id
|
||||
if err := s.writeRequest(mtExecute, false, s.statementID, newInputParameters(prmFieldSet.inputFields(), args)); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
prmFieldValues := newFieldValues()
|
||||
var tableResults []*TableResult
|
||||
var tableResult *TableResult
|
||||
|
||||
f := func(p replyPart) {
|
||||
|
||||
switch p := p.(type) {
|
||||
|
||||
case *outputParameters:
|
||||
p.s = s
|
||||
p.outputFields = prmFieldSet.outputFields()
|
||||
p.fieldValues = prmFieldValues
|
||||
|
||||
// table output parameters: meta, id, result (only first param?)
|
||||
case *resultMetadata:
|
||||
tableResult = newTableResult(s, p.numArg)
|
||||
tableResults = append(tableResults, tableResult)
|
||||
p.resultFieldSet = tableResult.resultFieldSet
|
||||
case *resultsetID:
|
||||
p.id = &(tableResult.id)
|
||||
case *resultset:
|
||||
p.s = s
|
||||
tableResult.attrs = s.ph.partAttributes
|
||||
p.resultFieldSet = tableResult.resultFieldSet
|
||||
p.fieldValues = tableResult.fieldValues
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.readReply(f); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := s.writeLobStream(prmFieldSet, prmFieldValues, args); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return prmFieldValues, tableResults, nil
|
||||
}
|
||||
|
||||
// Query executes a query.
|
||||
func (s *Session) Query(stmtID uint64, prmFieldSet *ParameterFieldSet, resultFieldSet *ResultFieldSet, args []driver.NamedValue) (uint64, *FieldValues, PartAttributes, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.statementID.id = &stmtID
|
||||
if err := s.writeRequest(mtExecute, false, s.statementID, newInputParameters(prmFieldSet.inputFields(), args)); err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
|
||||
var rsetID uint64
|
||||
fieldValues := newFieldValues()
|
||||
|
||||
f := func(p replyPart) {
|
||||
|
||||
switch p := p.(type) {
|
||||
|
||||
case *resultsetID:
|
||||
p.id = &rsetID
|
||||
case *resultset:
|
||||
p.s = s
|
||||
p.resultFieldSet = resultFieldSet
|
||||
p.fieldValues = fieldValues
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.readReply(f); err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
|
||||
return rsetID, fieldValues, s.ph.partAttributes, nil
|
||||
}
|
||||
|
||||
// FetchNext fetches next chunk in query result set.
|
||||
func (s *Session) FetchNext(id uint64, resultFieldSet *ResultFieldSet, fieldValues *FieldValues) (PartAttributes, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.resultsetID.id = &id
|
||||
if err := s.writeRequest(mtFetchNext, false, s.resultsetID, fetchsize(s.prm.FetchSize())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := func(p replyPart) {
|
||||
|
||||
switch p := p.(type) {
|
||||
|
||||
case *resultset:
|
||||
p.s = s
|
||||
p.resultFieldSet = resultFieldSet
|
||||
p.fieldValues = fieldValues
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.readReply(f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.ph.partAttributes, nil
|
||||
}
|
||||
|
||||
// CloseResultsetID releases the hdb resultset handle.
|
||||
func (s *Session) CloseResultsetID(id uint64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.resultsetID.id = &id
|
||||
if err := s.writeRequest(mtCloseResultset, false, s.resultsetID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.readReply(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Commit executes a database commit.
|
||||
func (s *Session) Commit() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.writeRequest(mtCommit, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.readReply(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("transaction flags: %s", s.txFlags)
|
||||
}
|
||||
|
||||
s.conn.inTx = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback executes a database rollback.
|
||||
func (s *Session) Rollback() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.writeRequest(mtRollback, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.readReply(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("transaction flags: %s", s.txFlags)
|
||||
}
|
||||
|
||||
s.conn.inTx = false
|
||||
return nil
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
func (s *Session) readLobStream(w lobChunkWriter) error {
|
||||
|
||||
s.readLobRequest.w = w
|
||||
s.readLobReply.w = w
|
||||
|
||||
for !w.eof() {
|
||||
|
||||
if err := s.writeRequest(mtWriteLob, false, s.readLobRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.readReply(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) writeLobStream(prmFieldSet *ParameterFieldSet, prmFieldValues *FieldValues, args []driver.NamedValue) error {
|
||||
|
||||
if s.writeLobReply.numArg == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
lobPrmFields := make([]*ParameterField, s.writeLobReply.numArg)
|
||||
|
||||
j := 0
|
||||
for _, f := range prmFieldSet.fields {
|
||||
if f.TypeCode().isLob() && f.In() && f.chunkReader != nil {
|
||||
f.lobLocatorID = s.writeLobReply.ids[j]
|
||||
lobPrmFields[j] = f
|
||||
j++
|
||||
}
|
||||
}
|
||||
if j != s.writeLobReply.numArg {
|
||||
return fmt.Errorf("protocol error: invalid number of lob parameter ids %d - expected %d", j, s.writeLobReply.numArg)
|
||||
}
|
||||
|
||||
s.writeLobRequest.lobPrmFields = lobPrmFields
|
||||
|
||||
f := func(p replyPart) {
|
||||
if p, ok := p.(*outputParameters); ok {
|
||||
p.s = s
|
||||
p.outputFields = prmFieldSet.outputFields()
|
||||
p.fieldValues = prmFieldValues
|
||||
}
|
||||
}
|
||||
|
||||
for s.writeLobReply.numArg != 0 {
|
||||
if err := s.writeRequest(mtReadLob, false, s.writeLobRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.readReply(f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
func (s *Session) initRequest() error {
|
||||
|
||||
// init
|
||||
s.mh.sessionID = -1
|
||||
|
||||
// handshake
|
||||
req := newInitRequest()
|
||||
// TODO: constants
|
||||
req.product.major = 4
|
||||
req.product.minor = 20
|
||||
req.protocol.major = 4
|
||||
req.protocol.minor = 1
|
||||
req.numOptions = 1
|
||||
req.endianess = archEndian
|
||||
if err := req.write(s.wr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rep := newInitReply()
|
||||
if err := rep.read(s.rd); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) writeRequest(messageType messageType, commit bool, requests ...requestPart) error {
|
||||
|
||||
partSize := make([]int, len(requests))
|
||||
|
||||
size := int64(segmentHeaderSize + len(requests)*partHeaderSize) //int64 to hold MaxUInt32 in 32bit OS
|
||||
|
||||
for i, part := range requests {
|
||||
s, err := part.size()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
size += int64(s + padBytes(s))
|
||||
partSize[i] = s // buffer size (expensive calculation)
|
||||
}
|
||||
|
||||
if size > math.MaxUint32 {
|
||||
return fmt.Errorf("message size %d exceeds maximum message header value %d", size, int64(math.MaxUint32)) //int64: without cast overflow error in 32bit OS
|
||||
}
|
||||
|
||||
bufferSize := size
|
||||
|
||||
s.mh.varPartLength = uint32(size)
|
||||
s.mh.varPartSize = uint32(bufferSize)
|
||||
s.mh.noOfSegm = 1
|
||||
|
||||
if err := s.mh.write(s.wr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size > math.MaxInt32 {
|
||||
return fmt.Errorf("message size %d exceeds maximum part header value %d", size, math.MaxInt32)
|
||||
}
|
||||
|
||||
s.sh.messageType = messageType
|
||||
s.sh.commit = commit
|
||||
s.sh.segmentKind = skRequest
|
||||
s.sh.segmentLength = int32(size)
|
||||
s.sh.segmentOfs = 0
|
||||
s.sh.noOfParts = int16(len(requests))
|
||||
s.sh.segmentNo = 1
|
||||
|
||||
if err := s.sh.write(s.wr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bufferSize -= segmentHeaderSize
|
||||
|
||||
for i, part := range requests {
|
||||
|
||||
size := partSize[i]
|
||||
pad := padBytes(size)
|
||||
|
||||
s.ph.partKind = part.kind()
|
||||
numArg := part.numArg()
|
||||
switch {
|
||||
default:
|
||||
return fmt.Errorf("maximum number of arguments %d exceeded", numArg)
|
||||
case numArg <= math.MaxInt16:
|
||||
s.ph.argumentCount = int16(numArg)
|
||||
s.ph.bigArgumentCount = 0
|
||||
|
||||
// TODO: seems not to work: see bulk insert test
|
||||
case numArg <= math.MaxInt32:
|
||||
s.ph.argumentCount = 0
|
||||
s.ph.bigArgumentCount = int32(numArg)
|
||||
}
|
||||
|
||||
s.ph.bufferLength = int32(size)
|
||||
s.ph.bufferSize = int32(bufferSize)
|
||||
|
||||
if err := s.ph.write(s.wr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := part.write(s.wr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.wr.WriteZeroes(pad)
|
||||
|
||||
bufferSize -= int64(partHeaderSize + size + pad)
|
||||
|
||||
}
|
||||
|
||||
return s.wr.Flush()
|
||||
|
||||
}
|
||||
|
||||
func (s *Session) readReply(beforeRead beforeRead) error {
|
||||
|
||||
replyRowsAffected := false
|
||||
replyError := false
|
||||
|
||||
if err := s.mh.read(s.rd); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.mh.noOfSegm != 1 {
|
||||
return fmt.Errorf("simple message: no of segments %d - expected 1", s.mh.noOfSegm)
|
||||
}
|
||||
if err := s.sh.read(s.rd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: protocol error (sps 82)?: message header varPartLength < segment header segmentLength (*1)
|
||||
diff := int(s.mh.varPartLength) - int(s.sh.segmentLength)
|
||||
if trace && diff != 0 {
|
||||
outLogger.Printf("+++++diff %d", diff)
|
||||
}
|
||||
|
||||
noOfParts := int(s.sh.noOfParts)
|
||||
lastPart := noOfParts - 1
|
||||
|
||||
for i := 0; i < noOfParts; i++ {
|
||||
|
||||
if err := s.ph.read(s.rd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
numArg := int(s.ph.argumentCount)
|
||||
|
||||
var part replyPart
|
||||
|
||||
switch s.ph.partKind {
|
||||
|
||||
case pkAuthentication:
|
||||
if s.scramsha256InitialReply != nil { // first call: initial reply
|
||||
part = s.scramsha256InitialReply
|
||||
} else { // second call: final reply
|
||||
part = s.scramsha256FinalReply
|
||||
}
|
||||
case pkTopologyInformation:
|
||||
part = s.topologyInformation
|
||||
case pkConnectOptions:
|
||||
part = s.connectOptions
|
||||
case pkStatementID:
|
||||
part = s.statementID
|
||||
case pkResultMetadata:
|
||||
part = s.resultMetadata
|
||||
case pkResultsetID:
|
||||
part = s.resultsetID
|
||||
case pkResultset:
|
||||
part = s.resultset
|
||||
case pkParameterMetadata:
|
||||
part = s.parameterMetadata
|
||||
case pkOutputParameters:
|
||||
part = s.outputParameters
|
||||
case pkError:
|
||||
replyError = true
|
||||
part = s.lastError
|
||||
case pkStatementContext:
|
||||
part = s.stmtCtx
|
||||
case pkTransactionFlags:
|
||||
part = s.txFlags
|
||||
case pkRowsAffected:
|
||||
replyRowsAffected = true
|
||||
part = s.rowsAffected
|
||||
case pkReadLobReply:
|
||||
part = s.readLobReply
|
||||
case pkWriteLobReply:
|
||||
part = s.writeLobReply
|
||||
default:
|
||||
return fmt.Errorf("read not expected part kind %s", s.ph.partKind)
|
||||
}
|
||||
|
||||
part.setNumArg(numArg)
|
||||
|
||||
if beforeRead != nil {
|
||||
beforeRead(part)
|
||||
}
|
||||
|
||||
if err := part.read(s.rd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if i != lastPart { // not last part
|
||||
// Error padding (protocol error?)
|
||||
// driver test TestHDBWarning
|
||||
// --> 18 bytes fix error bytes + 103 bytes error text => 121 bytes (7 bytes padding needed)
|
||||
// but s.ph.bufferLength = 122 (standard padding would only consume 6 bytes instead of 7)
|
||||
// driver test TestBulkInsertDuplicates
|
||||
// --> returns 3 errors (number of total bytes matches s.ph.bufferLength)
|
||||
// ==> hdbErrors take care about padding
|
||||
if s.ph.partKind != pkError {
|
||||
s.rd.Skip(padBytes(int(s.ph.bufferLength)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// last part
|
||||
// TODO: workaround (see *)
|
||||
if diff == 0 {
|
||||
s.rd.Skip(padBytes(int(s.ph.bufferLength)))
|
||||
}
|
||||
|
||||
if err := s.rd.GetError(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if replyError {
|
||||
if replyRowsAffected { //link statement to error
|
||||
j := 0
|
||||
for i, rows := range s.rowsAffected.rows {
|
||||
if rows == raExecutionFailed {
|
||||
s.lastError.setStmtNo(j, i)
|
||||
j++
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.lastError.isWarnings() {
|
||||
for _, _error := range s.lastError.errors {
|
||||
sqltrace.Traceln(_error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return s.lastError
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,203 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
type dir bool
|
||||
|
||||
const (
|
||||
maxBinarySize = 128
|
||||
)
|
||||
|
||||
type fragment interface {
|
||||
read(rd *bufio.Reader) error
|
||||
write(wr *bufio.Writer) error
|
||||
}
|
||||
|
||||
func (d dir) String() string {
|
||||
if d {
|
||||
return "->"
|
||||
}
|
||||
return "<-"
|
||||
}
|
||||
|
||||
// A Sniffer is a simple proxy for logging hdb protocol requests and responses.
|
||||
type Sniffer struct {
|
||||
conn net.Conn
|
||||
dbAddr string
|
||||
dbConn net.Conn
|
||||
|
||||
//client
|
||||
clRd *bufio.Reader
|
||||
clWr *bufio.Writer
|
||||
//database
|
||||
dbRd *bufio.Reader
|
||||
dbWr *bufio.Writer
|
||||
|
||||
mh *messageHeader
|
||||
sh *segmentHeader
|
||||
ph *partHeader
|
||||
|
||||
buf []byte
|
||||
}
|
||||
|
||||
// NewSniffer creates a new sniffer instance. The conn parameter is the net.Conn connection, where the Sniffer
|
||||
// is listening for hdb protocol calls. The dbAddr is the hdb host port address in "host:port" format.
|
||||
func NewSniffer(conn net.Conn, dbAddr string) (*Sniffer, error) {
|
||||
s := &Sniffer{
|
||||
conn: conn,
|
||||
dbAddr: dbAddr,
|
||||
clRd: bufio.NewReader(conn),
|
||||
clWr: bufio.NewWriter(conn),
|
||||
mh: &messageHeader{},
|
||||
sh: &segmentHeader{},
|
||||
ph: &partHeader{},
|
||||
buf: make([]byte, 0),
|
||||
}
|
||||
|
||||
dbConn, err := net.Dial("tcp", s.dbAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.dbRd = bufio.NewReader(dbConn)
|
||||
s.dbWr = bufio.NewWriter(dbConn)
|
||||
s.dbConn = dbConn
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Sniffer) getBuffer(size int) []byte {
|
||||
if cap(s.buf) < size {
|
||||
s.buf = make([]byte, size)
|
||||
}
|
||||
return s.buf[:size]
|
||||
}
|
||||
|
||||
// Go starts the protocol request and response logging.
|
||||
func (s *Sniffer) Go() {
|
||||
defer s.dbConn.Close()
|
||||
defer s.conn.Close()
|
||||
|
||||
req := newInitRequest()
|
||||
if err := s.streamFragment(dir(true), s.clRd, s.dbWr, req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rep := newInitReply()
|
||||
if err := s.streamFragment(dir(false), s.dbRd, s.clWr, rep); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
//up stream
|
||||
if err := s.stream(dir(true), s.clRd, s.dbWr); err != nil {
|
||||
return
|
||||
}
|
||||
//down stream
|
||||
if err := s.stream(dir(false), s.dbRd, s.clWr); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sniffer) stream(d dir, from *bufio.Reader, to *bufio.Writer) error {
|
||||
|
||||
if err := s.streamFragment(d, from, to, s.mh); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
size := int(s.mh.varPartLength)
|
||||
|
||||
for i := 0; i < int(s.mh.noOfSegm); i++ {
|
||||
|
||||
if err := s.streamFragment(d, from, to, s.sh); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
size -= int(s.sh.segmentLength)
|
||||
|
||||
for j := 0; j < int(s.sh.noOfParts); j++ {
|
||||
|
||||
if err := s.streamFragment(d, from, to, s.ph); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// protocol error workaraound
|
||||
padding := (size == 0) || (j != (int(s.sh.noOfParts) - 1))
|
||||
|
||||
if err := s.streamPart(d, from, to, s.ph, padding); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return to.Flush()
|
||||
}
|
||||
|
||||
func (s *Sniffer) streamPart(d dir, from *bufio.Reader, to *bufio.Writer, ph *partHeader, padding bool) error {
|
||||
|
||||
switch ph.partKind {
|
||||
|
||||
default:
|
||||
return s.streamBinary(d, from, to, int(ph.bufferLength), padding)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sniffer) streamBinary(d dir, from *bufio.Reader, to *bufio.Writer, size int, padding bool) error {
|
||||
var b []byte
|
||||
|
||||
//protocol error workaraound
|
||||
if padding {
|
||||
pad := padBytes(size)
|
||||
b = s.getBuffer(size + pad)
|
||||
} else {
|
||||
b = s.getBuffer(size)
|
||||
}
|
||||
|
||||
from.ReadFull(b)
|
||||
err := from.GetError()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if size > maxBinarySize {
|
||||
log.Printf("%s %v", d, b[:maxBinarySize])
|
||||
} else {
|
||||
log.Printf("%s %v", d, b[:size])
|
||||
}
|
||||
to.Write(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Sniffer) streamFragment(d dir, from *bufio.Reader, to *bufio.Writer, f fragment) error {
|
||||
if err := f.read(from); err != nil {
|
||||
log.Print(err)
|
||||
return err
|
||||
}
|
||||
log.Printf("%s %s", d, f)
|
||||
if err := f.write(to); err != nil {
|
||||
log.Print(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
60
vendor/github.com/SAP/go-hdb/internal/protocol/statementcontext.go
generated
vendored
Normal file
60
vendor/github.com/SAP/go-hdb/internal/protocol/statementcontext.go
generated
vendored
Normal file
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
type statementContext struct {
|
||||
options plainOptions
|
||||
_numArg int
|
||||
}
|
||||
|
||||
func newStatementContext() *statementContext {
|
||||
return &statementContext{
|
||||
options: plainOptions{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *statementContext) String() string {
|
||||
typedSc := make(map[statementContextType]interface{})
|
||||
for k, v := range c.options {
|
||||
typedSc[statementContextType(k)] = v
|
||||
}
|
||||
return fmt.Sprintf("%s", typedSc)
|
||||
}
|
||||
|
||||
func (c *statementContext) kind() partKind {
|
||||
return pkStatementContext
|
||||
}
|
||||
|
||||
func (c *statementContext) setNumArg(numArg int) {
|
||||
c._numArg = numArg
|
||||
}
|
||||
|
||||
func (c *statementContext) read(rd *bufio.Reader) error {
|
||||
c.options.read(rd, c._numArg)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("statement context: %v", c)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
26
vendor/github.com/SAP/go-hdb/internal/protocol/statementcontexttype.go
generated
vendored
Normal file
26
vendor/github.com/SAP/go-hdb/internal/protocol/statementcontexttype.go
generated
vendored
Normal file
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=statementContextType
|
||||
|
||||
type statementContextType int8
|
||||
|
||||
const (
|
||||
scStatementSequenceInfo statementContextType = 1
|
||||
scServerExecutionTime statementContextType = 2
|
||||
)
|
17
vendor/github.com/SAP/go-hdb/internal/protocol/statementcontexttype_string.go
generated
vendored
Normal file
17
vendor/github.com/SAP/go-hdb/internal/protocol/statementcontexttype_string.go
generated
vendored
Normal file
|
@ -0,0 +1,17 @@
|
|||
// Code generated by "stringer -type=statementContextType"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _statementContextType_name = "scStatementSequenceInfoscServerExecutionTime"
|
||||
|
||||
var _statementContextType_index = [...]uint8{0, 23, 44}
|
||||
|
||||
func (i statementContextType) String() string {
|
||||
i -= 1
|
||||
if i < 0 || i >= statementContextType(len(_statementContextType_index)-1) {
|
||||
return "statementContextType(" + strconv.FormatInt(int64(i+1), 10) + ")"
|
||||
}
|
||||
return _statementContextType_name[_statementContextType_index[i]:_statementContextType_index[i+1]]
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
const (
|
||||
statementIDSize = 8
|
||||
)
|
||||
|
||||
type statementID struct {
|
||||
id *uint64
|
||||
}
|
||||
|
||||
func (id statementID) kind() partKind {
|
||||
return pkStatementID
|
||||
}
|
||||
|
||||
func (id statementID) size() (int, error) {
|
||||
return statementIDSize, nil
|
||||
}
|
||||
|
||||
func (id statementID) numArg() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (id statementID) setNumArg(int) {
|
||||
//ignore - always 1
|
||||
}
|
||||
|
||||
func (id *statementID) read(rd *bufio.Reader) error {
|
||||
_id := rd.ReadUint64()
|
||||
*id.id = _id
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("statement id: %d", *id.id)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
func (id statementID) write(wr *bufio.Writer) error {
|
||||
wr.WriteUint64(*id.id)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("statement id: %d", *id.id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
// TableResult is the package internal representation of a table like output parameter of a stored procedure.
|
||||
type TableResult struct {
|
||||
id uint64
|
||||
resultFieldSet *ResultFieldSet
|
||||
fieldValues *FieldValues
|
||||
attrs partAttributes
|
||||
}
|
||||
|
||||
func newTableResult(s *Session, size int) *TableResult {
|
||||
return &TableResult{
|
||||
resultFieldSet: newResultFieldSet(size),
|
||||
fieldValues: newFieldValues(),
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the resultset id.
|
||||
func (r *TableResult) ID() uint64 {
|
||||
return r.id
|
||||
}
|
||||
|
||||
// FieldSet returns the field metadata of the table.
|
||||
func (r *TableResult) FieldSet() *ResultFieldSet {
|
||||
return r.resultFieldSet
|
||||
}
|
||||
|
||||
// FieldValues returns the field values (fetched resultset part) of the table.
|
||||
func (r *TableResult) FieldValues() *FieldValues {
|
||||
return r.fieldValues
|
||||
}
|
||||
|
||||
// Attrs returns the PartAttributes interface of the fetched resultset part.
|
||||
func (r *TableResult) Attrs() PartAttributes {
|
||||
return r.attrs
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
Copyright 2017 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const gregorianDay = 2299161 // Start date of Gregorian Calendar as Julian Day Number
|
||||
var gregorianDate = julianDayToTime(gregorianDay) // Start date of Gregorian Calendar (1582-10-15)
|
||||
|
||||
// timeToJulianDay returns the Julian Date Number of time's date components.
|
||||
// The algorithm is taken from https://en.wikipedia.org/wiki/Julian_day.
|
||||
func timeToJulianDay(t time.Time) int {
|
||||
|
||||
t = t.UTC()
|
||||
|
||||
month := int(t.Month())
|
||||
|
||||
a := (14 - month) / 12
|
||||
y := t.Year() + 4800 - a
|
||||
m := month + (12 * a) - 3
|
||||
|
||||
if t.Before(gregorianDate) { // Julian Calendar
|
||||
return t.Day() + (153*m+2)/5 + 365*y + y/4 - 32083
|
||||
}
|
||||
// Gregorian Calendar
|
||||
return t.Day() + (153*m+2)/5 + 365*y + y/4 - y/100 + y/400 - 32045
|
||||
}
|
||||
|
||||
// JulianDayToTime returns the correcponding UTC date for a Julian Day Number.
|
||||
// The algorithm is taken from https://en.wikipedia.org/wiki/Julian_day.
|
||||
func julianDayToTime(jd int) time.Time {
|
||||
var f int
|
||||
|
||||
if jd < gregorianDay {
|
||||
f = jd + 1401
|
||||
} else {
|
||||
f = jd + 1401 + (((4*jd+274277)/146097)*3)/4 - 38
|
||||
}
|
||||
|
||||
e := 4*f + 3
|
||||
g := (e % 1461) / 4
|
||||
h := 5*g + 2
|
||||
day := (h%153)/5 + 1
|
||||
month := (h/153+2)%12 + 1
|
||||
year := (e / 1461) - 4716 + (12+2-month)/12
|
||||
|
||||
return time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
type topologyInformation struct {
|
||||
mlo multiLineOptions
|
||||
_numArg int
|
||||
}
|
||||
|
||||
func newTopologyInformation() *topologyInformation {
|
||||
return &topologyInformation{
|
||||
mlo: multiLineOptions{},
|
||||
}
|
||||
}
|
||||
|
||||
func (o *topologyInformation) String() string {
|
||||
mlo := make([]map[topologyOption]interface{}, len(o.mlo))
|
||||
for i, po := range o.mlo {
|
||||
typedPo := make(map[topologyOption]interface{})
|
||||
for k, v := range po {
|
||||
typedPo[topologyOption(k)] = v
|
||||
}
|
||||
mlo[i] = typedPo
|
||||
}
|
||||
return fmt.Sprintf("%s", mlo)
|
||||
}
|
||||
|
||||
func (o *topologyInformation) kind() partKind {
|
||||
return pkTopologyInformation
|
||||
}
|
||||
|
||||
func (o *topologyInformation) size() int {
|
||||
return o.mlo.size()
|
||||
}
|
||||
|
||||
func (o *topologyInformation) numArg() int {
|
||||
return len(o.mlo)
|
||||
}
|
||||
|
||||
func (o *topologyInformation) setNumArg(numArg int) {
|
||||
o._numArg = numArg
|
||||
}
|
||||
|
||||
func (o *topologyInformation) read(rd *bufio.Reader) error {
|
||||
o.mlo.read(rd, o._numArg)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("topology options: %v", o)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
||||
|
||||
func (o *topologyInformation) write(wr *bufio.Writer) error {
|
||||
for _, m := range o.mlo {
|
||||
wr.WriteInt16(int16(len(m)))
|
||||
o.mlo.write(wr)
|
||||
}
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("topology options: %v", o)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=topologyOption
|
||||
|
||||
type topologyOption int8
|
||||
|
||||
const (
|
||||
toHostName topologyOption = 1
|
||||
toHostPortnumber topologyOption = 2
|
||||
toTenantName topologyOption = 3
|
||||
toLoadfactor topologyOption = 4
|
||||
toVolumeID topologyOption = 5
|
||||
toIsMaster topologyOption = 6
|
||||
toIsCurrentSession topologyOption = 7
|
||||
toServiceType topologyOption = 8
|
||||
toNetworkDomain topologyOption = 9
|
||||
toIsStandby topologyOption = 10
|
||||
toAllIPAddresses topologyOption = 11
|
||||
toAllHostNames topologyOption = 12
|
||||
)
|
17
vendor/github.com/SAP/go-hdb/internal/protocol/topologyoption_string.go
generated
vendored
Normal file
17
vendor/github.com/SAP/go-hdb/internal/protocol/topologyoption_string.go
generated
vendored
Normal file
|
@ -0,0 +1,17 @@
|
|||
// Code generated by "stringer -type=topologyOption"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _topologyOption_name = "toHostNametoHostPortnumbertoTenantNametoLoadfactortoVolumeIDtoIsMastertoIsCurrentSessiontoServiceTypetoNetworkDomaintoIsStandbytoAllIPAddressestoAllHostNames"
|
||||
|
||||
var _topologyOption_index = [...]uint8{0, 10, 26, 38, 50, 60, 70, 88, 101, 116, 127, 143, 157}
|
||||
|
||||
func (i topologyOption) String() string {
|
||||
i -= 1
|
||||
if i < 0 || i >= topologyOption(len(_topologyOption_index)-1) {
|
||||
return "topologyOption(" + strconv.FormatInt(int64(i+1), 10) + ")"
|
||||
}
|
||||
return _topologyOption_name[_topologyOption_index[i]:_topologyOption_index[i+1]]
|
||||
}
|
60
vendor/github.com/SAP/go-hdb/internal/protocol/transactionflags.go
generated
vendored
Normal file
60
vendor/github.com/SAP/go-hdb/internal/protocol/transactionflags.go
generated
vendored
Normal file
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/bufio"
|
||||
)
|
||||
|
||||
type transactionFlags struct {
|
||||
options plainOptions
|
||||
_numArg int
|
||||
}
|
||||
|
||||
func newTransactionFlags() *transactionFlags {
|
||||
return &transactionFlags{
|
||||
options: plainOptions{},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *transactionFlags) String() string {
|
||||
typedSc := make(map[transactionFlagType]interface{})
|
||||
for k, v := range f.options {
|
||||
typedSc[transactionFlagType(k)] = v
|
||||
}
|
||||
return fmt.Sprintf("%s", typedSc)
|
||||
}
|
||||
|
||||
func (f *transactionFlags) kind() partKind {
|
||||
return pkTransactionFlags
|
||||
}
|
||||
|
||||
func (f *transactionFlags) setNumArg(numArg int) {
|
||||
f._numArg = numArg
|
||||
}
|
||||
|
||||
func (f *transactionFlags) read(rd *bufio.Reader) error {
|
||||
f.options.read(rd, f._numArg)
|
||||
|
||||
if trace {
|
||||
outLogger.Printf("transaction flags: %v", f)
|
||||
}
|
||||
|
||||
return rd.GetError()
|
||||
}
|
32
vendor/github.com/SAP/go-hdb/internal/protocol/transactionflagtype.go
generated
vendored
Normal file
32
vendor/github.com/SAP/go-hdb/internal/protocol/transactionflagtype.go
generated
vendored
Normal file
|
@ -0,0 +1,32 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
//go:generate stringer -type=transactionFlagType
|
||||
|
||||
//transaction flags
|
||||
type transactionFlagType int8
|
||||
|
||||
const (
|
||||
tfRolledback transactionFlagType = 0
|
||||
tfCommited transactionFlagType = 1
|
||||
tfNewIsolationLevel transactionFlagType = 2
|
||||
tfDDLCommitmodeChanged transactionFlagType = 3
|
||||
tfWriteTransactionStarted transactionFlagType = 4
|
||||
tfNowriteTransactionStarted transactionFlagType = 5
|
||||
tfSessionClosingTransactionError transactionFlagType = 6
|
||||
)
|
16
vendor/github.com/SAP/go-hdb/internal/protocol/transactionflagtype_string.go
generated
vendored
Normal file
16
vendor/github.com/SAP/go-hdb/internal/protocol/transactionflagtype_string.go
generated
vendored
Normal file
|
@ -0,0 +1,16 @@
|
|||
// Code generated by "stringer -type=transactionFlagType"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _transactionFlagType_name = "tfRolledbacktfCommitedtfNewIsolationLeveltfDDLCommitmodeChangedtfWriteTransactionStartedtfNowriteTransactionStartedtfSessionClosingTransactionError"
|
||||
|
||||
var _transactionFlagType_index = [...]uint8{0, 12, 22, 41, 63, 88, 115, 147}
|
||||
|
||||
func (i transactionFlagType) String() string {
|
||||
if i < 0 || i >= transactionFlagType(len(_transactionFlagType_index)-1) {
|
||||
return "transactionFlagType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _transactionFlagType_name[_transactionFlagType_index[i]:_transactionFlagType_index[i+1]]
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:generate stringer -type=TypeCode
|
||||
|
||||
// TypeCode identify the type of a field transferred to or from the database.
|
||||
type TypeCode byte
|
||||
|
||||
// null value indicator is high bit
|
||||
|
||||
const (
|
||||
tcNull TypeCode = 0
|
||||
tcTinyint TypeCode = 1
|
||||
tcSmallint TypeCode = 2
|
||||
tcInteger TypeCode = 3
|
||||
tcBigint TypeCode = 4
|
||||
tcDecimal TypeCode = 5
|
||||
tcReal TypeCode = 6
|
||||
tcDouble TypeCode = 7
|
||||
tcChar TypeCode = 8
|
||||
tcVarchar TypeCode = 9
|
||||
tcNchar TypeCode = 10
|
||||
tcNvarchar TypeCode = 11
|
||||
tcBinary TypeCode = 12
|
||||
tcVarbinary TypeCode = 13
|
||||
// deprecated with 3 (doku) - but table 'date' field uses it
|
||||
tcDate TypeCode = 14
|
||||
// deprecated with 3 (doku) - but table 'time' field uses it
|
||||
tcTime TypeCode = 15
|
||||
// deprecated with 3 (doku) - but table 'timestamp' field uses it
|
||||
tcTimestamp TypeCode = 16
|
||||
//tcTimetz TypeCode = 17 // reserved: do not use
|
||||
//tcTimeltz TypeCode = 18 // reserved: do not use
|
||||
//tcTimestamptz TypeCode = 19 // reserved: do not use
|
||||
//tcTimestampltz TypeCode = 20 // reserved: do not use
|
||||
//tcInvervalym TypeCode = 21 // reserved: do not use
|
||||
//tcInvervalds TypeCode = 22 // reserved: do not use
|
||||
//tcRowid TypeCode = 23 // reserved: do not use
|
||||
//tcUrowid TypeCode = 24 // reserved: do not use
|
||||
tcClob TypeCode = 25
|
||||
tcNclob TypeCode = 26
|
||||
tcBlob TypeCode = 27
|
||||
tcBoolean TypeCode = 28
|
||||
tcString TypeCode = 29
|
||||
tcNstring TypeCode = 30
|
||||
tcBlocator TypeCode = 31
|
||||
tcNlocator TypeCode = 32
|
||||
tcBstring TypeCode = 33
|
||||
//tcDecimaldigitarray TypeCode = 34 // reserved: do not use
|
||||
tcVarchar2 TypeCode = 35
|
||||
tcVarchar3 TypeCode = 36
|
||||
tcNvarchar3 TypeCode = 37
|
||||
tcVarbinary3 TypeCode = 38
|
||||
//tcVargroup TypeCode = 39 // reserved: do not use
|
||||
//tcTinyintnotnull TypeCode = 40 // reserved: do not use
|
||||
//tcSmallintnotnull TypeCode = 41 // reserved: do not use
|
||||
//tcIntnotnull TypeCode = 42 // reserved: do not use
|
||||
//tcBigintnotnull TypeCode = 43 // reserved: do not use
|
||||
//tcArgument TypeCode = 44 // reserved: do not use
|
||||
//tcTable TypeCode = 45 // reserved: do not use
|
||||
//tcCursor TypeCode = 46 // reserved: do not use
|
||||
tcSmalldecimal TypeCode = 47
|
||||
//tcAbapitab TypeCode = 48 // not supported by GO hdb driver
|
||||
//tcAbapstruct TypeCode = 49 // not supported by GO hdb driver
|
||||
tcArray TypeCode = 50
|
||||
tcText TypeCode = 51
|
||||
tcShorttext TypeCode = 52
|
||||
//tcFixedString TypeCode = 53 // reserved: do not use
|
||||
//tcFixedpointdecimal TypeCode = 54 // reserved: do not use
|
||||
tcAlphanum TypeCode = 55
|
||||
//tcTlocator TypeCode = 56 // reserved: do not use
|
||||
tcLongdate TypeCode = 61
|
||||
tcSeconddate TypeCode = 62
|
||||
tcDaydate TypeCode = 63
|
||||
tcSecondtime TypeCode = 64
|
||||
//tcCte TypeCode = 65 // reserved: do not use
|
||||
//tcCstimesda TypeCode = 66 // reserved: do not use
|
||||
//tcBlobdisk TypeCode = 71 // reserved: do not use
|
||||
//tcClobdisk TypeCode = 72 // reserved: do not use
|
||||
//tcNclobdisk TypeCode = 73 // reserved: do not use
|
||||
//tcGeometry TypeCode = 74 // reserved: do not use
|
||||
//tcPoint TypeCode = 75 // reserved: do not use
|
||||
//tcFixed16 TypeCode = 76 // reserved: do not use
|
||||
//tcBlobhybrid TypeCode = 77 // reserved: do not use
|
||||
//tcClobhybrid TypeCode = 78 // reserved: do not use
|
||||
//tcNclobhybrid TypeCode = 79 // reserved: do not use
|
||||
//tcPointz TypeCode = 80 // reserved: do not use
|
||||
)
|
||||
|
||||
func (k TypeCode) isLob() bool {
|
||||
return k == tcClob || k == tcNclob || k == tcBlob
|
||||
}
|
||||
|
||||
func (k TypeCode) isCharBased() bool {
|
||||
return k == tcNvarchar || k == tcNstring || k == tcNclob
|
||||
}
|
||||
|
||||
func (k TypeCode) isVariableLength() bool {
|
||||
return k == tcChar || k == tcNchar || k == tcVarchar || k == tcNvarchar || k == tcBinary || k == tcVarbinary || k == tcShorttext || k == tcAlphanum
|
||||
}
|
||||
|
||||
func (k TypeCode) isDecimalType() bool {
|
||||
return k == tcSmalldecimal || k == tcDecimal
|
||||
}
|
||||
|
||||
// DataType converts a type code into one of the supported data types by the driver.
|
||||
func (k TypeCode) DataType() DataType {
|
||||
switch k {
|
||||
default:
|
||||
return DtUnknown
|
||||
case tcTinyint:
|
||||
return DtTinyint
|
||||
case tcSmallint:
|
||||
return DtSmallint
|
||||
case tcInteger:
|
||||
return DtInteger
|
||||
case tcBigint:
|
||||
return DtBigint
|
||||
case tcReal:
|
||||
return DtReal
|
||||
case tcDouble:
|
||||
return DtDouble
|
||||
case tcDate, tcTime, tcTimestamp, tcLongdate, tcSeconddate, tcDaydate, tcSecondtime:
|
||||
return DtTime
|
||||
case tcDecimal:
|
||||
return DtDecimal
|
||||
case tcChar, tcVarchar, tcString, tcNchar, tcNvarchar, tcNstring:
|
||||
return DtString
|
||||
case tcBinary, tcVarbinary:
|
||||
return DtBytes
|
||||
case tcBlob, tcClob, tcNclob:
|
||||
return DtLob
|
||||
}
|
||||
}
|
||||
|
||||
// TypeName returns the database type name.
|
||||
// see https://golang.org/pkg/database/sql/driver/#RowsColumnTypeDatabaseTypeName
|
||||
func (k TypeCode) TypeName() string {
|
||||
return strings.ToUpper(k.String()[2:])
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
// Code generated by "stringer -type=TypeCode"; DO NOT EDIT.
|
||||
|
||||
package protocol
|
||||
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
_TypeCode_name_0 = "tcNulltcTinyinttcSmallinttcIntegertcBiginttcDecimaltcRealtcDoubletcChartcVarchartcNchartcNvarchartcBinarytcVarbinarytcDatetcTimetcTimestamp"
|
||||
_TypeCode_name_1 = "tcClobtcNclobtcBlobtcBooleantcStringtcNstringtcBlocatortcNlocatortcBstring"
|
||||
_TypeCode_name_2 = "tcVarchar2tcVarchar3tcNvarchar3tcVarbinary3"
|
||||
_TypeCode_name_3 = "tcSmalldecimal"
|
||||
_TypeCode_name_4 = "tcArraytcTexttcShorttext"
|
||||
_TypeCode_name_5 = "tcAlphanum"
|
||||
_TypeCode_name_6 = "tcLongdatetcSeconddatetcDaydatetcSecondtime"
|
||||
)
|
||||
|
||||
var (
|
||||
_TypeCode_index_0 = [...]uint8{0, 6, 15, 25, 34, 42, 51, 57, 65, 71, 80, 87, 97, 105, 116, 122, 128, 139}
|
||||
_TypeCode_index_1 = [...]uint8{0, 6, 13, 19, 28, 36, 45, 55, 65, 74}
|
||||
_TypeCode_index_2 = [...]uint8{0, 10, 20, 31, 43}
|
||||
_TypeCode_index_4 = [...]uint8{0, 7, 13, 24}
|
||||
_TypeCode_index_6 = [...]uint8{0, 10, 22, 31, 43}
|
||||
)
|
||||
|
||||
func (i TypeCode) String() string {
|
||||
switch {
|
||||
case 0 <= i && i <= 16:
|
||||
return _TypeCode_name_0[_TypeCode_index_0[i]:_TypeCode_index_0[i+1]]
|
||||
case 25 <= i && i <= 33:
|
||||
i -= 25
|
||||
return _TypeCode_name_1[_TypeCode_index_1[i]:_TypeCode_index_1[i+1]]
|
||||
case 35 <= i && i <= 38:
|
||||
i -= 35
|
||||
return _TypeCode_name_2[_TypeCode_index_2[i]:_TypeCode_index_2[i+1]]
|
||||
case i == 47:
|
||||
return _TypeCode_name_3
|
||||
case 50 <= i && i <= 52:
|
||||
i -= 50
|
||||
return _TypeCode_name_4[_TypeCode_index_4[i]:_TypeCode_index_4[i+1]]
|
||||
case i == 55:
|
||||
return _TypeCode_name_5
|
||||
case 61 <= i && i <= 64:
|
||||
i -= 61
|
||||
return _TypeCode_name_6[_TypeCode_index_6[i]:_TypeCode_index_6[i+1]]
|
||||
default:
|
||||
return "TypeCode(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,240 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package cesu8 implements functions and constants to support text encoded in CESU-8.
|
||||
// It implements functions comparable to the unicode/utf8 package for UTF-8 de- and encoding.
|
||||
package cesu8
|
||||
|
||||
import (
|
||||
"unicode/utf16"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
// CESUMax is the maximum amount of bytes used by an CESU-8 codepoint encoding.
|
||||
CESUMax = 6
|
||||
)
|
||||
|
||||
// Size returns the amount of bytes needed to encode an UTF-8 byte slice to CESU-8.
|
||||
func Size(p []byte) int {
|
||||
n := 0
|
||||
for i := 0; i < len(p); {
|
||||
r, size, _ := decodeRune(p[i:])
|
||||
i += size
|
||||
n += RuneLen(r)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// StringSize is like Size with a string as parameter.
|
||||
func StringSize(s string) int {
|
||||
n := 0
|
||||
for _, r := range s {
|
||||
n += RuneLen(r)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// EncodeRune writes into p (which must be large enough) the CESU-8 encoding of the rune. It returns the number of bytes written.
|
||||
func EncodeRune(p []byte, r rune) int {
|
||||
if r <= rune3Max {
|
||||
return encodeRune(p, r)
|
||||
}
|
||||
high, low := utf16.EncodeRune(r)
|
||||
n := encodeRune(p, high)
|
||||
n += encodeRune(p[n:], low)
|
||||
return n
|
||||
}
|
||||
|
||||
// FullRune reports whether the bytes in p begin with a full CESU-8 encoding of a rune.
|
||||
func FullRune(p []byte) bool {
|
||||
high, n, short := decodeRune(p)
|
||||
if short {
|
||||
return false
|
||||
}
|
||||
if !utf16.IsSurrogate(high) {
|
||||
return true
|
||||
}
|
||||
_, _, short = decodeRune(p[n:])
|
||||
return !short
|
||||
}
|
||||
|
||||
// DecodeRune unpacks the first CESU-8 encoding in p and returns the rune and its width in bytes.
|
||||
func DecodeRune(p []byte) (rune, int) {
|
||||
high, n1, _ := decodeRune(p)
|
||||
if !utf16.IsSurrogate(high) {
|
||||
return high, n1
|
||||
}
|
||||
low, n2, _ := decodeRune(p[n1:])
|
||||
if low == utf8.RuneError {
|
||||
return low, n1 + n2
|
||||
}
|
||||
return utf16.DecodeRune(high, low), n1 + n2
|
||||
}
|
||||
|
||||
// RuneLen returns the number of bytes required to encode the rune.
|
||||
func RuneLen(r rune) int {
|
||||
switch {
|
||||
case r < 0:
|
||||
return -1
|
||||
case r <= rune1Max:
|
||||
return 1
|
||||
case r <= rune2Max:
|
||||
return 2
|
||||
case r <= rune3Max:
|
||||
return 3
|
||||
case r <= utf8.MaxRune:
|
||||
return CESUMax
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// Copied from unicode utf8
|
||||
// - allow utf8 encoding of utf16 surrogate values
|
||||
// - see (*) for code changes
|
||||
|
||||
// Code points in the surrogate range are not valid for UTF-8.
|
||||
const (
|
||||
surrogateMin = 0xD800
|
||||
surrogateMax = 0xDFFF
|
||||
)
|
||||
|
||||
const (
|
||||
t1 = 0x00 // 0000 0000
|
||||
tx = 0x80 // 1000 0000
|
||||
t2 = 0xC0 // 1100 0000
|
||||
t3 = 0xE0 // 1110 0000
|
||||
t4 = 0xF0 // 1111 0000
|
||||
t5 = 0xF8 // 1111 1000
|
||||
|
||||
maskx = 0x3F // 0011 1111
|
||||
mask2 = 0x1F // 0001 1111
|
||||
mask3 = 0x0F // 0000 1111
|
||||
mask4 = 0x07 // 0000 0111
|
||||
|
||||
rune1Max = 1<<7 - 1
|
||||
rune2Max = 1<<11 - 1
|
||||
rune3Max = 1<<16 - 1
|
||||
)
|
||||
|
||||
func encodeRune(p []byte, r rune) int {
|
||||
// Negative values are erroneous. Making it unsigned addresses the problem.
|
||||
switch i := uint32(r); {
|
||||
case i <= rune1Max:
|
||||
p[0] = byte(r)
|
||||
return 1
|
||||
case i <= rune2Max:
|
||||
p[0] = t2 | byte(r>>6)
|
||||
p[1] = tx | byte(r)&maskx
|
||||
return 2
|
||||
//case i > MaxRune, surrogateMin <= i && i <= surrogateMax: // replaced (*)
|
||||
case i > utf8.MaxRune: // (*)
|
||||
r = utf8.RuneError
|
||||
fallthrough
|
||||
case i <= rune3Max:
|
||||
p[0] = t3 | byte(r>>12)
|
||||
p[1] = tx | byte(r>>6)&maskx
|
||||
p[2] = tx | byte(r)&maskx
|
||||
return 3
|
||||
default:
|
||||
p[0] = t4 | byte(r>>18)
|
||||
p[1] = tx | byte(r>>12)&maskx
|
||||
p[2] = tx | byte(r>>6)&maskx
|
||||
p[3] = tx | byte(r)&maskx
|
||||
return 4
|
||||
}
|
||||
}
|
||||
|
||||
func decodeRune(p []byte) (r rune, size int, short bool) {
|
||||
n := len(p)
|
||||
if n < 1 {
|
||||
return utf8.RuneError, 0, true
|
||||
}
|
||||
c0 := p[0]
|
||||
|
||||
// 1-byte, 7-bit sequence?
|
||||
if c0 < tx {
|
||||
return rune(c0), 1, false
|
||||
}
|
||||
|
||||
// unexpected continuation byte?
|
||||
if c0 < t2 {
|
||||
return utf8.RuneError, 1, false
|
||||
}
|
||||
|
||||
// need first continuation byte
|
||||
if n < 2 {
|
||||
return utf8.RuneError, 1, true
|
||||
}
|
||||
c1 := p[1]
|
||||
if c1 < tx || t2 <= c1 {
|
||||
return utf8.RuneError, 1, false
|
||||
}
|
||||
|
||||
// 2-byte, 11-bit sequence?
|
||||
if c0 < t3 {
|
||||
r = rune(c0&mask2)<<6 | rune(c1&maskx)
|
||||
if r <= rune1Max {
|
||||
return utf8.RuneError, 1, false
|
||||
}
|
||||
return r, 2, false
|
||||
}
|
||||
|
||||
// need second continuation byte
|
||||
if n < 3 {
|
||||
return utf8.RuneError, 1, true
|
||||
}
|
||||
c2 := p[2]
|
||||
if c2 < tx || t2 <= c2 {
|
||||
return utf8.RuneError, 1, false
|
||||
}
|
||||
|
||||
// 3-byte, 16-bit sequence?
|
||||
if c0 < t4 {
|
||||
r = rune(c0&mask3)<<12 | rune(c1&maskx)<<6 | rune(c2&maskx)
|
||||
if r <= rune2Max {
|
||||
return utf8.RuneError, 1, false
|
||||
}
|
||||
// do not throw error on surrogates // (*)
|
||||
//if surrogateMin <= r && r <= surrogateMax {
|
||||
// return RuneError, 1, false
|
||||
//}
|
||||
return r, 3, false
|
||||
}
|
||||
|
||||
// need third continuation byte
|
||||
if n < 4 {
|
||||
return utf8.RuneError, 1, true
|
||||
}
|
||||
c3 := p[3]
|
||||
if c3 < tx || t2 <= c3 {
|
||||
return utf8.RuneError, 1, false
|
||||
}
|
||||
|
||||
// 4-byte, 21-bit sequence?
|
||||
if c0 < t5 {
|
||||
r = rune(c0&mask4)<<18 | rune(c1&maskx)<<12 | rune(c2&maskx)<<6 | rune(c3&maskx)
|
||||
if r <= rune3Max || utf8.MaxRune < r {
|
||||
return utf8.RuneError, 1, false
|
||||
}
|
||||
return r, 4, false
|
||||
}
|
||||
|
||||
// error
|
||||
return utf8.RuneError, 1, false
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
/*
|
||||
Copyright 2014 SAP SE
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package unicode implements UTF-8 to CESU-8 and vice versa transformations.
|
||||
package unicode
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/SAP/go-hdb/internal/unicode/cesu8"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
var (
|
||||
// Utf8ToCesu8Transformer implements the golang.org/x/text/transform/Transformer interface for UTF-8 to CESU-8 transformation.
|
||||
Utf8ToCesu8Transformer = new(utf8ToCesu8Transformer)
|
||||
// Cesu8ToUtf8Transformer implements the golang.org/x/text/transform/Transformer interface for CESU-8 to UTF-8 transformation.
|
||||
Cesu8ToUtf8Transformer = new(cesu8ToUtf8Transformer)
|
||||
// ErrInvalidUtf8 means that a transformer detected invalid UTF-8 data.
|
||||
ErrInvalidUtf8 = errors.New("Invalid UTF-8")
|
||||
// ErrInvalidCesu8 means that a transformer detected invalid CESU-8 data.
|
||||
ErrInvalidCesu8 = errors.New("Invalid CESU-8")
|
||||
)
|
||||
|
||||
type utf8ToCesu8Transformer struct{ transform.NopResetter }
|
||||
|
||||
func (t *utf8ToCesu8Transformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
|
||||
i, j := 0, 0
|
||||
for i < len(src) {
|
||||
if src[i] < utf8.RuneSelf {
|
||||
if j < len(dst) {
|
||||
dst[j] = src[i]
|
||||
i++
|
||||
j++
|
||||
} else {
|
||||
return j, i, transform.ErrShortDst
|
||||
}
|
||||
} else {
|
||||
if !utf8.FullRune(src[i:]) {
|
||||
return j, i, transform.ErrShortSrc
|
||||
}
|
||||
r, n := utf8.DecodeRune(src[i:])
|
||||
if r == utf8.RuneError {
|
||||
return j, i, ErrInvalidUtf8
|
||||
}
|
||||
m := cesu8.RuneLen(r)
|
||||
if m == -1 {
|
||||
panic("internal UTF-8 to CESU-8 transformation error")
|
||||
}
|
||||
if j+m <= len(dst) {
|
||||
cesu8.EncodeRune(dst[j:], r)
|
||||
i += n
|
||||
j += m
|
||||
} else {
|
||||
return j, i, transform.ErrShortDst
|
||||
}
|
||||
}
|
||||
}
|
||||
return j, i, nil
|
||||
}
|
||||
|
||||
type cesu8ToUtf8Transformer struct{ transform.NopResetter }
|
||||
|
||||
func (t *cesu8ToUtf8Transformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
|
||||
i, j := 0, 0
|
||||
for i < len(src) {
|
||||
if src[i] < utf8.RuneSelf {
|
||||
if j < len(dst) {
|
||||
dst[j] = src[i]
|
||||
i++
|
||||
j++
|
||||
} else {
|
||||
return j, i, transform.ErrShortDst
|
||||
}
|
||||
} else {
|
||||
if !cesu8.FullRune(src[i:]) {
|
||||
return j, i, transform.ErrShortSrc
|
||||
}
|
||||
r, n := cesu8.DecodeRune(src[i:])
|
||||
if r == utf8.RuneError {
|
||||
return j, i, ErrInvalidCesu8
|
||||
}
|
||||
m := utf8.RuneLen(r)
|
||||
if m == -1 {
|
||||
panic("internal CESU-8 to UTF-8 transformation error")
|
||||
}
|
||||
if j+m <= len(dst) {
|
||||
utf8.EncodeRune(dst[j:], r)
|
||||
i += n
|
||||
j += m
|
||||
} else {
|
||||
return j, i, transform.ErrShortDst
|
||||
}
|
||||
}
|
||||
}
|
||||
return j, i, nil
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 Sermo Digital LLC
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
JOSE
|
||||
============
|
||||
[![Build Status](https://travis-ci.org/SermoDigital/jose.svg?branch=master)](https://travis-ci.org/SermoDigital/jose)
|
||||
[![GoDoc](https://godoc.org/github.com/SermoDigital/jose?status.svg)](https://godoc.org/github.com/SermoDigital/jose)
|
||||
|
||||
JOSE is a comprehensive set of JWT, JWS, and JWE libraries.
|
||||
|
||||
## Why
|
||||
|
||||
The only other JWS/JWE/JWT implementations are specific to JWT, and none
|
||||
were particularly pleasant to work with.
|
||||
|
||||
These libraries should provide an easy, straightforward way to securely
|
||||
create, parse, and validate JWS, JWE, and JWTs.
|
||||
|
||||
## Notes:
|
||||
JWE is currently unimplemented.
|
||||
|
||||
## Version 0.9:
|
||||
|
||||
## Documentation
|
||||
|
||||
The docs can be found at [godoc.org] [docs], as usual.
|
||||
|
||||
A gopkg.in mirror can be found at https://gopkg.in/jose.v1, thanks to
|
||||
@zia-newversion. (For context, see issue #30.)
|
||||
|
||||
### [JWS RFC][jws]
|
||||
### [JWE RFC][jwe]
|
||||
### [JWT RFC][jwt]
|
||||
|
||||
## License
|
||||
|
||||
[MIT] [license].
|
||||
|
||||
[docs]: https://godoc.org/github.com/SermoDigital/jose
|
||||
[license]: https://github.com/SermoDigital/jose/blob/master/LICENSE.md
|
||||
[jws]: https://tools.ietf.org/html/rfc7515
|
||||
[jwe]: https://tools.ietf.org/html/rfc7516
|
||||
[jwt]: https://tools.ietf.org/html/rfc7519
|
|
@ -0,0 +1,8 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
go build ./...
|
||||
go test ./...
|
||||
golint ./...
|
||||
go vet ./...
|
|
@ -0,0 +1,44 @@
|
|||
package jose
|
||||
|
||||
import "encoding/base64"
|
||||
|
||||
// Encoder is satisfied if the type can marshal itself into a valid
|
||||
// structure for a JWS.
|
||||
type Encoder interface {
|
||||
// Base64 implies T -> JSON -> RawURLEncodingBase64
|
||||
Base64() ([]byte, error)
|
||||
}
|
||||
|
||||
// Base64Decode decodes a base64-encoded byte slice.
|
||||
func Base64Decode(b []byte) ([]byte, error) {
|
||||
buf := make([]byte, base64.RawURLEncoding.DecodedLen(len(b)))
|
||||
n, err := base64.RawURLEncoding.Decode(buf, b)
|
||||
return buf[:n], err
|
||||
}
|
||||
|
||||
// Base64Encode encodes a byte slice.
|
||||
func Base64Encode(b []byte) []byte {
|
||||
buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b)))
|
||||
base64.RawURLEncoding.Encode(buf, b)
|
||||
return buf
|
||||
}
|
||||
|
||||
// EncodeEscape base64-encodes a byte slice but escapes it for JSON.
|
||||
// It'll return the format: `"base64"`
|
||||
func EncodeEscape(b []byte) []byte {
|
||||
buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b))+2)
|
||||
buf[0] = '"'
|
||||
base64.RawURLEncoding.Encode(buf[1:], b)
|
||||
buf[len(buf)-1] = '"'
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeEscaped decodes a base64-encoded byte slice straight from a JSON
|
||||
// structure. It assumes it's in the format: `"base64"`, but can handle
|
||||
// cases where it's not.
|
||||
func DecodeEscaped(b []byte) ([]byte, error) {
|
||||
if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' {
|
||||
b = b[1 : len(b)-1]
|
||||
}
|
||||
return Base64Decode(b)
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
// Package crypto implements "SigningMethods" and "EncryptionMethods";
|
||||
// that is, ways to sign and encrypt JWS and JWEs, respectively, as well
|
||||
// as JWTs.
|
||||
package crypto
|
|
@ -0,0 +1,117 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"encoding/asn1"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// ErrECDSAVerification is missing from crypto/ecdsa compared to crypto/rsa
|
||||
var ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
|
||||
|
||||
// SigningMethodECDSA implements the ECDSA family of signing methods signing
|
||||
// methods
|
||||
type SigningMethodECDSA struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// ECPoint is a marshalling structure for the EC points R and S.
|
||||
type ECPoint struct {
|
||||
R *big.Int
|
||||
S *big.Int
|
||||
}
|
||||
|
||||
// Specific instances of EC SigningMethods.
|
||||
var (
|
||||
// SigningMethodES256 implements ES256.
|
||||
SigningMethodES256 = &SigningMethodECDSA{
|
||||
Name: "ES256",
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
|
||||
// SigningMethodES384 implements ES384.
|
||||
SigningMethodES384 = &SigningMethodECDSA{
|
||||
Name: "ES384",
|
||||
Hash: crypto.SHA384,
|
||||
}
|
||||
|
||||
// SigningMethodES512 implements ES512.
|
||||
SigningMethodES512 = &SigningMethodECDSA{
|
||||
Name: "ES512",
|
||||
Hash: crypto.SHA512,
|
||||
}
|
||||
)
|
||||
|
||||
// Alg returns the name of the SigningMethodECDSA instance.
|
||||
func (m *SigningMethodECDSA) Alg() string { return m.Name }
|
||||
|
||||
// Verify implements the Verify method from SigningMethod.
|
||||
// For this verify method, key must be an *ecdsa.PublicKey.
|
||||
func (m *SigningMethodECDSA) Verify(raw []byte, signature Signature, key interface{}) error {
|
||||
|
||||
ecdsaKey, ok := key.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
|
||||
// Unmarshal asn1 ECPoint
|
||||
var ecpoint ECPoint
|
||||
if _, err := asn1.Unmarshal(signature, &ecpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify the signature
|
||||
if !ecdsa.Verify(ecdsaKey, m.sum(raw), ecpoint.R, ecpoint.S) {
|
||||
return ErrECDSAVerification
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign implements the Sign method from SigningMethod.
|
||||
// For this signing method, key must be an *ecdsa.PrivateKey.
|
||||
func (m *SigningMethodECDSA) Sign(data []byte, key interface{}) (Signature, error) {
|
||||
|
||||
ecdsaKey, ok := key.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
|
||||
r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, m.sum(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signature, err := asn1.Marshal(ECPoint{R: r, S: s})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Signature(signature), nil
|
||||
}
|
||||
|
||||
func (m *SigningMethodECDSA) sum(b []byte) []byte {
|
||||
h := m.Hash.New()
|
||||
h.Write(b)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// Hasher implements the Hasher method from SigningMethod.
|
||||
func (m *SigningMethodECDSA) Hasher() crypto.Hash {
|
||||
return m.Hash
|
||||
}
|
||||
|
||||
// MarshalJSON is in case somebody decides to place SigningMethodECDSA
|
||||
// inside the Header, presumably because they (wrongly) decided it was a good
|
||||
// idea to use the SigningMethod itself instead of the SigningMethod's Alg
|
||||
// method. In order to keep things sane, marshalling this will simply
|
||||
// return the JSON-compatible representation of m.Alg().
|
||||
func (m *SigningMethodECDSA) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodECDSA)(nil)
|
|
@ -0,0 +1,48 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ECDSA parsing errors.
|
||||
var (
|
||||
ErrNotECPublicKey = errors.New("Key is not a valid ECDSA public key")
|
||||
ErrNotECPrivateKey = errors.New("Key is not a valid ECDSA private key")
|
||||
)
|
||||
|
||||
// ParseECPrivateKeyFromPEM will parse a PEM encoded EC Private
|
||||
// Key Structure.
|
||||
func ParseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(key)
|
||||
if block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
}
|
||||
|
||||
// ParseECPublicKeyFromPEM will parse a PEM encoded PKCS1 or PKCS8 public key
|
||||
func ParseECPublicKeyFromPEM(key []byte) (*ecdsa.PublicKey, error) {
|
||||
|
||||
block, _ := pem.Decode(key)
|
||||
if block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
parsedKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedKey = cert.PublicKey
|
||||
}
|
||||
|
||||
pkey, ok := parsedKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, ErrNotECPublicKey
|
||||
}
|
||||
return pkey, nil
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package crypto
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrInvalidKey means the key argument passed to SigningMethod.Verify
|
||||
// was not the correct type.
|
||||
ErrInvalidKey = errors.New("key is invalid")
|
||||
)
|
|
@ -0,0 +1,81 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// SigningMethodHMAC implements the HMAC-SHA family of SigningMethods.
|
||||
type SigningMethodHMAC struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// Specific instances of HMAC-SHA SigningMethods.
|
||||
var (
|
||||
// SigningMethodHS256 implements HS256.
|
||||
SigningMethodHS256 = &SigningMethodHMAC{
|
||||
Name: "HS256",
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
|
||||
// SigningMethodHS384 implements HS384.
|
||||
SigningMethodHS384 = &SigningMethodHMAC{
|
||||
Name: "HS384",
|
||||
Hash: crypto.SHA384,
|
||||
}
|
||||
|
||||
// SigningMethodHS512 implements HS512.
|
||||
SigningMethodHS512 = &SigningMethodHMAC{
|
||||
Name: "HS512",
|
||||
Hash: crypto.SHA512,
|
||||
}
|
||||
|
||||
// ErrSignatureInvalid is returned when the provided signature is found
|
||||
// to be invalid.
|
||||
ErrSignatureInvalid = errors.New("signature is invalid")
|
||||
)
|
||||
|
||||
// Alg implements the SigningMethod interface.
|
||||
func (m *SigningMethodHMAC) Alg() string { return m.Name }
|
||||
|
||||
// Verify implements the Verify method from SigningMethod.
|
||||
// For this signing method, must be a []byte.
|
||||
func (m *SigningMethodHMAC) Verify(raw []byte, signature Signature, key interface{}) error {
|
||||
keyBytes, ok := key.([]byte)
|
||||
if !ok {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||
hasher.Write(raw)
|
||||
if hmac.Equal(signature, hasher.Sum(nil)) {
|
||||
return nil
|
||||
}
|
||||
return ErrSignatureInvalid
|
||||
}
|
||||
|
||||
// Sign implements the Sign method from SigningMethod for this signing method.
|
||||
// Key must be a []byte.
|
||||
func (m *SigningMethodHMAC) Sign(data []byte, key interface{}) (Signature, error) {
|
||||
keyBytes, ok := key.([]byte)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||
hasher.Write(data)
|
||||
return Signature(hasher.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// Hasher implements the SigningMethod interface.
|
||||
func (m *SigningMethodHMAC) Hasher() crypto.Hash { return m.Hash }
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
// See SigningMethodECDSA.MarshalJSON() for information.
|
||||
func (m *SigningMethodHMAC) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodHMAC)(nil)
|
|
@ -0,0 +1,72 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/json"
|
||||
"hash"
|
||||
"io"
|
||||
)
|
||||
|
||||
func init() {
|
||||
crypto.RegisterHash(crypto.Hash(0), h)
|
||||
}
|
||||
|
||||
// h is passed to crypto.RegisterHash.
|
||||
func h() hash.Hash {
|
||||
return &f{Writer: nil}
|
||||
}
|
||||
|
||||
type f struct{ io.Writer }
|
||||
|
||||
// Sum helps implement the hash.Hash interface.
|
||||
func (_ *f) Sum(b []byte) []byte { return nil }
|
||||
|
||||
// Reset helps implement the hash.Hash interface.
|
||||
func (_ *f) Reset() {}
|
||||
|
||||
// Size helps implement the hash.Hash interface.
|
||||
func (_ *f) Size() int { return -1 }
|
||||
|
||||
// BlockSize helps implement the hash.Hash interface.
|
||||
func (_ *f) BlockSize() int { return -1 }
|
||||
|
||||
// Unsecured is the default "none" algorithm.
|
||||
var Unsecured = &SigningMethodNone{
|
||||
Name: "none",
|
||||
Hash: crypto.Hash(0),
|
||||
}
|
||||
|
||||
// SigningMethodNone is the default "none" algorithm.
|
||||
type SigningMethodNone struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// Verify helps implement the SigningMethod interface.
|
||||
func (_ *SigningMethodNone) Verify(_ []byte, _ Signature, _ interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign helps implement the SigningMethod interface.
|
||||
func (_ *SigningMethodNone) Sign(_ []byte, _ interface{}) (Signature, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Alg helps implement the SigningMethod interface.
|
||||
func (m *SigningMethodNone) Alg() string {
|
||||
return m.Name
|
||||
}
|
||||
|
||||
// Hasher helps implement the SigningMethod interface.
|
||||
func (m *SigningMethodNone) Hasher() crypto.Hash {
|
||||
return m.Hash
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
// See SigningMethodECDSA.MarshalJSON() for information.
|
||||
func (m *SigningMethodNone) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodNone)(nil)
|
|
@ -0,0 +1,80 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// SigningMethodRSA implements the RSA family of SigningMethods.
|
||||
type SigningMethodRSA struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// Specific instances of RSA SigningMethods.
|
||||
var (
|
||||
// SigningMethodRS256 implements RS256.
|
||||
SigningMethodRS256 = &SigningMethodRSA{
|
||||
Name: "RS256",
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
|
||||
// SigningMethodRS384 implements RS384.
|
||||
SigningMethodRS384 = &SigningMethodRSA{
|
||||
Name: "RS384",
|
||||
Hash: crypto.SHA384,
|
||||
}
|
||||
|
||||
// SigningMethodRS512 implements RS512.
|
||||
SigningMethodRS512 = &SigningMethodRSA{
|
||||
Name: "RS512",
|
||||
Hash: crypto.SHA512,
|
||||
}
|
||||
)
|
||||
|
||||
// Alg implements the SigningMethod interface.
|
||||
func (m *SigningMethodRSA) Alg() string { return m.Name }
|
||||
|
||||
// Verify implements the Verify method from SigningMethod.
|
||||
// For this signing method, must be an *rsa.PublicKey.
|
||||
func (m *SigningMethodRSA) Verify(raw []byte, sig Signature, key interface{}) error {
|
||||
rsaKey, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, m.sum(raw), sig)
|
||||
}
|
||||
|
||||
// Sign implements the Sign method from SigningMethod.
|
||||
// For this signing method, must be an *rsa.PrivateKey structure.
|
||||
func (m *SigningMethodRSA) Sign(data []byte, key interface{}) (Signature, error) {
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, m.sum(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Signature(sigBytes), nil
|
||||
}
|
||||
|
||||
func (m *SigningMethodRSA) sum(b []byte) []byte {
|
||||
h := m.Hash.New()
|
||||
h.Write(b)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// Hasher implements the SigningMethod interface.
|
||||
func (m *SigningMethodRSA) Hasher() crypto.Hash { return m.Hash }
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
// See SigningMethodECDSA.MarshalJSON() for information.
|
||||
func (m *SigningMethodRSA) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodRSA)(nil)
|
|
@ -0,0 +1,96 @@
|
|||
// +build go1.4
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// SigningMethodRSAPSS implements the RSAPSS family of SigningMethods.
|
||||
type SigningMethodRSAPSS struct {
|
||||
*SigningMethodRSA
|
||||
Options *rsa.PSSOptions
|
||||
}
|
||||
|
||||
// Specific instances for RS/PS SigningMethods.
|
||||
var (
|
||||
// SigningMethodPS256 implements PS256.
|
||||
SigningMethodPS256 = &SigningMethodRSAPSS{
|
||||
&SigningMethodRSA{
|
||||
Name: "PS256",
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
&rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}
|
||||
|
||||
// SigningMethodPS384 implements PS384.
|
||||
SigningMethodPS384 = &SigningMethodRSAPSS{
|
||||
&SigningMethodRSA{
|
||||
Name: "PS384",
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
&rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
}
|
||||
|
||||
// SigningMethodPS512 implements PS512.
|
||||
SigningMethodPS512 = &SigningMethodRSAPSS{
|
||||
&SigningMethodRSA{
|
||||
Name: "PS512",
|
||||
Hash: crypto.SHA512,
|
||||
},
|
||||
&rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA512,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// Verify implements the Verify method from SigningMethod.
|
||||
// For this verify method, key must be an *rsa.PublicKey.
|
||||
func (m *SigningMethodRSAPSS) Verify(raw []byte, signature Signature, key interface{}) error {
|
||||
rsaKey, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
return rsa.VerifyPSS(rsaKey, m.Hash, m.sum(raw), signature, m.Options)
|
||||
}
|
||||
|
||||
// Sign implements the Sign method from SigningMethod.
|
||||
// For this signing method, key must be an *rsa.PrivateKey.
|
||||
func (m *SigningMethodRSAPSS) Sign(raw []byte, key interface{}) (Signature, error) {
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
sigBytes, err := rsa.SignPSS(rand.Reader, rsaKey, m.Hash, m.sum(raw), m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Signature(sigBytes), nil
|
||||
}
|
||||
|
||||
func (m *SigningMethodRSAPSS) sum(b []byte) []byte {
|
||||
h := m.Hash.New()
|
||||
h.Write(b)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// Hasher implements the Hasher method from SigningMethod.
|
||||
func (m *SigningMethodRSAPSS) Hasher() crypto.Hash { return m.Hash }
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
// See SigningMethodECDSA.MarshalJSON() for information.
|
||||
func (m *SigningMethodRSAPSS) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodRSAPSS)(nil)
|
|
@ -0,0 +1,70 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Errors specific to rsa_utils.
|
||||
var (
|
||||
ErrKeyMustBePEMEncoded = errors.New("invalid key: Key must be PEM encoded PKCS1 or PKCS8 private key")
|
||||
ErrNotRSAPrivateKey = errors.New("key is not a valid RSA private key")
|
||||
ErrNotRSAPublicKey = errors.New("key is not a valid RSA public key")
|
||||
)
|
||||
|
||||
// ParseRSAPrivateKeyFromPEM parses a PEM encoded PKCS1 or PKCS8 private key.
|
||||
func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
var block *pem.Block
|
||||
if block, _ = pem.Decode(key); block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
var parsedKey interface{}
|
||||
if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
|
||||
if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var pkey *rsa.PrivateKey
|
||||
var ok bool
|
||||
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
|
||||
return nil, ErrNotRSAPrivateKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
}
|
||||
|
||||
// ParseRSAPublicKeyFromPEM parses PEM encoded PKCS1 or PKCS8 public key.
|
||||
func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
var block *pem.Block
|
||||
if block, _ = pem.Decode(key); block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
var parsedKey interface{}
|
||||
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
|
||||
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||
parsedKey = cert.PublicKey
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var pkey *rsa.PublicKey
|
||||
var ok bool
|
||||
if pkey, ok = parsedKey.(*rsa.PublicKey); !ok {
|
||||
return nil, ErrNotRSAPublicKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/SermoDigital/jose"
|
||||
)
|
||||
|
||||
// Signature is a JWS signature.
|
||||
type Signature []byte
|
||||
|
||||
// MarshalJSON implements json.Marshaler for a signature.
|
||||
func (s Signature) MarshalJSON() ([]byte, error) {
|
||||
return jose.EncodeEscape(s), nil
|
||||
}
|
||||
|
||||
// Base64 helps implements jose.Encoder for Signature.
|
||||
func (s Signature) Base64() ([]byte, error) {
|
||||
return jose.Base64Encode(s), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for signature.
|
||||
func (s *Signature) UnmarshalJSON(b []byte) error {
|
||||
dec, err := jose.DecodeEscaped(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*s = Signature(dec)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (Signature)(nil)
|
||||
_ json.Unmarshaler = (*Signature)(nil)
|
||||
_ jose.Encoder = (Signature)(nil)
|
||||
)
|
|
@ -0,0 +1,24 @@
|
|||
package crypto
|
||||
|
||||
import "crypto"
|
||||
|
||||
// SigningMethod is an interface that provides a way to sign JWS tokens.
|
||||
type SigningMethod interface {
|
||||
// Alg describes the signing algorithm, and is used to uniquely
|
||||
// describe the specific crypto.SigningMethod.
|
||||
Alg() string
|
||||
|
||||
// Verify accepts the raw content, the signature, and the key used
|
||||
// to sign the raw content, and returns any errors found while validating
|
||||
// the signature and content.
|
||||
Verify(raw []byte, sig Signature, key interface{}) error
|
||||
|
||||
// Sign returns a Signature for the raw bytes, as well as any errors
|
||||
// that occurred during the signing.
|
||||
Sign(raw []byte, key interface{}) (Signature, error)
|
||||
|
||||
// Used to cause quick panics when a crypto.SigningMethod whose form of hashing
|
||||
// isn't linked in the binary when you register a crypto.SigningMethod.
|
||||
// To spoof this, see "crypto.SigningMethodNone".
|
||||
Hasher() crypto.Hash
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
// Package jose implements some helper functions and types for the children
|
||||
// packages, jws, jwt, and jwe.
|
||||
package jose
|
|
@ -0,0 +1,124 @@
|
|||
package jose
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Header implements a JOSE Header with the addition of some helper
|
||||
// methods, similar to net/url.Values.
|
||||
type Header map[string]interface{}
|
||||
|
||||
// Get retrieves the value corresponding with key from the Header.
|
||||
func (h Header) Get(key string) interface{} {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
return h[key]
|
||||
}
|
||||
|
||||
// Set sets Claims[key] = val. It'll overwrite without warning.
|
||||
func (h Header) Set(key string, val interface{}) {
|
||||
h[key] = val
|
||||
}
|
||||
|
||||
// Del removes the value that corresponds with key from the Header.
|
||||
func (h Header) Del(key string) {
|
||||
delete(h, key)
|
||||
}
|
||||
|
||||
// Has returns true if a value for the given key exists inside the Header.
|
||||
func (h Header) Has(key string) bool {
|
||||
_, ok := h[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Header.
|
||||
func (h Header) MarshalJSON() ([]byte, error) {
|
||||
if len(h) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := json.Marshal(map[string]interface{}(h))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return EncodeEscape(b), nil
|
||||
}
|
||||
|
||||
// Base64 implements the Encoder interface.
|
||||
func (h Header) Base64() ([]byte, error) {
|
||||
return h.MarshalJSON()
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for Header.
|
||||
func (h *Header) UnmarshalJSON(b []byte) error {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
b, err := DecodeEscaped(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, (*map[string]interface{})(h))
|
||||
}
|
||||
|
||||
// Protected Headers are base64-encoded after they're marshaled into
|
||||
// JSON.
|
||||
type Protected Header
|
||||
|
||||
// Get retrieves the value corresponding with key from the Protected Header.
|
||||
func (p Protected) Get(key string) interface{} {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return p[key]
|
||||
}
|
||||
|
||||
// Set sets Protected[key] = val. It'll overwrite without warning.
|
||||
func (p Protected) Set(key string, val interface{}) {
|
||||
p[key] = val
|
||||
}
|
||||
|
||||
// Del removes the value that corresponds with key from the Protected Header.
|
||||
func (p Protected) Del(key string) {
|
||||
delete(p, key)
|
||||
}
|
||||
|
||||
// Has returns true if a value for the given key exists inside the Protected
|
||||
// Header.
|
||||
func (p Protected) Has(key string) bool {
|
||||
_, ok := p[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Protected.
|
||||
func (p Protected) MarshalJSON() ([]byte, error) {
|
||||
b, err := json.Marshal(map[string]interface{}(p))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return EncodeEscape(b), nil
|
||||
}
|
||||
|
||||
// Base64 implements the Encoder interface.
|
||||
func (p Protected) Base64() ([]byte, error) {
|
||||
b, err := json.Marshal(map[string]interface{}(p))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Base64Encode(b), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for Protected.
|
||||
func (p *Protected) UnmarshalJSON(b []byte) error {
|
||||
var h Header
|
||||
if err := h.UnmarshalJSON(b); err != nil {
|
||||
return err
|
||||
}
|
||||
*p = Protected(h)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (Protected)(nil)
|
||||
_ json.Unmarshaler = (*Protected)(nil)
|
||||
_ json.Marshaler = (Header)(nil)
|
||||
_ json.Unmarshaler = (*Header)(nil)
|
||||
)
|
|
@ -0,0 +1,190 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/SermoDigital/jose"
|
||||
"github.com/SermoDigital/jose/jwt"
|
||||
)
|
||||
|
||||
// Claims represents a set of JOSE Claims.
|
||||
type Claims jwt.Claims
|
||||
|
||||
// Get retrieves the value corresponding with key from the Claims.
|
||||
func (c Claims) Get(key string) interface{} {
|
||||
return jwt.Claims(c).Get(key)
|
||||
}
|
||||
|
||||
// Set sets Claims[key] = val. It'll overwrite without warning.
|
||||
func (c Claims) Set(key string, val interface{}) {
|
||||
jwt.Claims(c).Set(key, val)
|
||||
}
|
||||
|
||||
// Del removes the value that corresponds with key from the Claims.
|
||||
func (c Claims) Del(key string) {
|
||||
jwt.Claims(c).Del(key)
|
||||
}
|
||||
|
||||
// Has returns true if a value for the given key exists inside the Claims.
|
||||
func (c Claims) Has(key string) bool {
|
||||
return jwt.Claims(c).Has(key)
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Claims.
|
||||
func (c Claims) MarshalJSON() ([]byte, error) {
|
||||
return jwt.Claims(c).MarshalJSON()
|
||||
}
|
||||
|
||||
// Base64 implements the Encoder interface.
|
||||
func (c Claims) Base64() ([]byte, error) {
|
||||
return jwt.Claims(c).Base64()
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for Claims.
|
||||
func (c *Claims) UnmarshalJSON(b []byte) error {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
b, err := jose.DecodeEscaped(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Since json.Unmarshal calls UnmarshalJSON,
|
||||
// calling json.Unmarshal on *p would be infinitely recursive
|
||||
// A temp variable is needed because &map[string]interface{}(*p) is
|
||||
// invalid Go.
|
||||
|
||||
tmp := map[string]interface{}(*c)
|
||||
if err = json.Unmarshal(b, &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
*c = Claims(tmp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Issuer retrieves claim "iss" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.1
|
||||
func (c Claims) Issuer() (string, bool) {
|
||||
return jwt.Claims(c).Issuer()
|
||||
}
|
||||
|
||||
// Subject retrieves claim "sub" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.2
|
||||
func (c Claims) Subject() (string, bool) {
|
||||
return jwt.Claims(c).Subject()
|
||||
}
|
||||
|
||||
// Audience retrieves claim "aud" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.3
|
||||
func (c Claims) Audience() ([]string, bool) {
|
||||
return jwt.Claims(c).Audience()
|
||||
}
|
||||
|
||||
// Expiration retrieves claim "exp" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.4
|
||||
func (c Claims) Expiration() (time.Time, bool) {
|
||||
return jwt.Claims(c).Expiration()
|
||||
}
|
||||
|
||||
// NotBefore retrieves claim "nbf" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.5
|
||||
func (c Claims) NotBefore() (time.Time, bool) {
|
||||
return jwt.Claims(c).NotBefore()
|
||||
}
|
||||
|
||||
// IssuedAt retrieves claim "iat" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.6
|
||||
func (c Claims) IssuedAt() (time.Time, bool) {
|
||||
return jwt.Claims(c).IssuedAt()
|
||||
}
|
||||
|
||||
// JWTID retrieves claim "jti" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.7
|
||||
func (c Claims) JWTID() (string, bool) {
|
||||
return jwt.Claims(c).JWTID()
|
||||
}
|
||||
|
||||
// RemoveIssuer deletes claim "iss" from c.
|
||||
func (c Claims) RemoveIssuer() {
|
||||
jwt.Claims(c).RemoveIssuer()
|
||||
}
|
||||
|
||||
// RemoveSubject deletes claim "sub" from c.
|
||||
func (c Claims) RemoveSubject() {
|
||||
jwt.Claims(c).RemoveIssuer()
|
||||
}
|
||||
|
||||
// RemoveAudience deletes claim "aud" from c.
|
||||
func (c Claims) RemoveAudience() {
|
||||
jwt.Claims(c).Audience()
|
||||
}
|
||||
|
||||
// RemoveExpiration deletes claim "exp" from c.
|
||||
func (c Claims) RemoveExpiration() {
|
||||
jwt.Claims(c).RemoveExpiration()
|
||||
}
|
||||
|
||||
// RemoveNotBefore deletes claim "nbf" from c.
|
||||
func (c Claims) RemoveNotBefore() {
|
||||
jwt.Claims(c).NotBefore()
|
||||
}
|
||||
|
||||
// RemoveIssuedAt deletes claim "iat" from c.
|
||||
func (c Claims) RemoveIssuedAt() {
|
||||
jwt.Claims(c).IssuedAt()
|
||||
}
|
||||
|
||||
// RemoveJWTID deletes claim "jti" from c.
|
||||
func (c Claims) RemoveJWTID() {
|
||||
jwt.Claims(c).RemoveJWTID()
|
||||
}
|
||||
|
||||
// SetIssuer sets claim "iss" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.1
|
||||
func (c Claims) SetIssuer(issuer string) {
|
||||
jwt.Claims(c).SetIssuer(issuer)
|
||||
}
|
||||
|
||||
// SetSubject sets claim "iss" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.2
|
||||
func (c Claims) SetSubject(subject string) {
|
||||
jwt.Claims(c).SetSubject(subject)
|
||||
}
|
||||
|
||||
// SetAudience sets claim "aud" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.3
|
||||
func (c Claims) SetAudience(audience ...string) {
|
||||
jwt.Claims(c).SetAudience(audience...)
|
||||
}
|
||||
|
||||
// SetExpiration sets claim "exp" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.4
|
||||
func (c Claims) SetExpiration(expiration time.Time) {
|
||||
jwt.Claims(c).SetExpiration(expiration)
|
||||
}
|
||||
|
||||
// SetNotBefore sets claim "nbf" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.5
|
||||
func (c Claims) SetNotBefore(notBefore time.Time) {
|
||||
jwt.Claims(c).SetNotBefore(notBefore)
|
||||
}
|
||||
|
||||
// SetIssuedAt sets claim "iat" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.6
|
||||
func (c Claims) SetIssuedAt(issuedAt time.Time) {
|
||||
jwt.Claims(c).SetIssuedAt(issuedAt)
|
||||
}
|
||||
|
||||
// SetJWTID sets claim "jti" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.7
|
||||
func (c Claims) SetJWTID(uniqueID string) {
|
||||
jwt.Claims(c).SetJWTID(uniqueID)
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (Claims)(nil)
|
||||
_ json.Unmarshaler = (*Claims)(nil)
|
||||
)
|
|
@ -0,0 +1,2 @@
|
|||
// Package jws implements JWSs per RFC 7515
|
||||
package jws
|
|
@ -0,0 +1,62 @@
|
|||
package jws
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
|
||||
// ErrNotEnoughMethods is returned if New was called _or_ the Flat/Compact
|
||||
// methods were called with 0 SigningMethods.
|
||||
ErrNotEnoughMethods = errors.New("not enough methods provided")
|
||||
|
||||
// ErrCouldNotUnmarshal is returned when Parse's json.Unmarshaler
|
||||
// parameter returns an error.
|
||||
ErrCouldNotUnmarshal = errors.New("custom unmarshal failed")
|
||||
|
||||
// ErrNotCompact signals that the provided potential JWS is not
|
||||
// in its compact representation.
|
||||
ErrNotCompact = errors.New("not a compact JWS")
|
||||
|
||||
// ErrDuplicateHeaderParameter signals that there are duplicate parameters
|
||||
// in the provided Headers.
|
||||
ErrDuplicateHeaderParameter = errors.New("duplicate parameters in the JOSE Header")
|
||||
|
||||
// ErrTwoEmptyHeaders is returned if both Headers are empty.
|
||||
ErrTwoEmptyHeaders = errors.New("both headers cannot be empty")
|
||||
|
||||
// ErrNotEnoughKeys is returned when not enough keys are provided for
|
||||
// the given SigningMethods.
|
||||
ErrNotEnoughKeys = errors.New("not enough keys (for given methods)")
|
||||
|
||||
// ErrDidNotValidate means the given JWT did not properly validate
|
||||
ErrDidNotValidate = errors.New("did not validate")
|
||||
|
||||
// ErrNoAlgorithm means no algorithm ("alg") was found in the Protected
|
||||
// Header.
|
||||
ErrNoAlgorithm = errors.New("no algorithm found")
|
||||
|
||||
// ErrAlgorithmDoesntExist means the algorithm asked for cannot be
|
||||
// found inside the signingMethod cache.
|
||||
ErrAlgorithmDoesntExist = errors.New("algorithm doesn't exist")
|
||||
|
||||
// ErrMismatchedAlgorithms means the algorithm inside the JWT was
|
||||
// different than the algorithm the caller wanted to use.
|
||||
ErrMismatchedAlgorithms = errors.New("mismatched algorithms")
|
||||
|
||||
// ErrCannotValidate means the JWS cannot be validated for various
|
||||
// reasons. For example, if there aren't any signatures/payloads/headers
|
||||
// to actually validate.
|
||||
ErrCannotValidate = errors.New("cannot validate")
|
||||
|
||||
// ErrIsNotJWT means the given JWS is not a JWT.
|
||||
ErrIsNotJWT = errors.New("JWS is not a JWT")
|
||||
|
||||
// ErrHoldsJWE means the given JWS holds a JWE inside its payload.
|
||||
ErrHoldsJWE = errors.New("JWS holds JWE")
|
||||
|
||||
// ErrNotEnoughValidSignatures means the JWS did not meet the required
|
||||
// number of signatures.
|
||||
ErrNotEnoughValidSignatures = errors.New("not enough valid signatures in the JWS")
|
||||
|
||||
// ErrNoTokenInRequest means there's no token present inside the *http.Request.
|
||||
ErrNoTokenInRequest = errors.New("no token present in request")
|
||||
)
|
|
@ -0,0 +1,490 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/SermoDigital/jose"
|
||||
"github.com/SermoDigital/jose/crypto"
|
||||
)
|
||||
|
||||
// JWS implements a JWS per RFC 7515.
|
||||
type JWS interface {
|
||||
// Payload Returns the payload.
|
||||
Payload() interface{}
|
||||
|
||||
// SetPayload sets the payload with the given value.
|
||||
SetPayload(p interface{})
|
||||
|
||||
// Protected returns the JWS' Protected Header.
|
||||
Protected() jose.Protected
|
||||
|
||||
// ProtectedAt returns the JWS' Protected Header.
|
||||
// i represents the index of the Protected Header.
|
||||
ProtectedAt(i int) jose.Protected
|
||||
|
||||
// Header returns the JWS' unprotected Header.
|
||||
Header() jose.Header
|
||||
|
||||
// HeaderAt returns the JWS' unprotected Header.
|
||||
// i represents the index of the unprotected Header.
|
||||
HeaderAt(i int) jose.Header
|
||||
|
||||
// Verify validates the current JWS' signature as-is. Refer to
|
||||
// ValidateMulti for more information.
|
||||
Verify(key interface{}, method crypto.SigningMethod) error
|
||||
|
||||
// ValidateMulti validates the current JWS' signature as-is. Since it's
|
||||
// meant to be called after parsing a stream of bytes into a JWS, it
|
||||
// shouldn't do any internal parsing like the Sign, Flat, Compact, or
|
||||
// General methods do.
|
||||
VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o *SigningOpts) error
|
||||
|
||||
// VerifyCallback validates the current JWS' signature as-is. It
|
||||
// accepts a callback function that can be used to access header
|
||||
// parameters to lookup needed information. For example, looking
|
||||
// up the "kid" parameter.
|
||||
// The return slice must be a slice of keys used in the verification
|
||||
// of the JWS.
|
||||
VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod, o *SigningOpts) error
|
||||
|
||||
// General serializes the JWS into its "general" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.1
|
||||
General(keys ...interface{}) ([]byte, error)
|
||||
|
||||
// Flat serializes the JWS to its "flattened" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.2
|
||||
Flat(key interface{}) ([]byte, error)
|
||||
|
||||
// Compact serializes the JWS into its "compact" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.1
|
||||
Compact(key interface{}) ([]byte, error)
|
||||
|
||||
// IsJWT returns true if the JWS is a JWT.
|
||||
IsJWT() bool
|
||||
}
|
||||
|
||||
// jws represents a specific jws.
|
||||
type jws struct {
|
||||
payload *payload
|
||||
plcache rawBase64
|
||||
clean bool
|
||||
|
||||
sb []sigHead
|
||||
|
||||
isJWT bool
|
||||
}
|
||||
|
||||
// Payload returns the jws' payload.
|
||||
func (j *jws) Payload() interface{} {
|
||||
return j.payload.v
|
||||
}
|
||||
|
||||
// SetPayload sets the jws' raw, unexported payload.
|
||||
func (j *jws) SetPayload(val interface{}) {
|
||||
j.payload.v = val
|
||||
}
|
||||
|
||||
// Protected returns the JWS' Protected Header.
|
||||
func (j *jws) Protected() jose.Protected {
|
||||
return j.sb[0].protected
|
||||
}
|
||||
|
||||
// Protected returns the JWS' Protected Header.
|
||||
// i represents the index of the Protected Header.
|
||||
// Left empty, it defaults to 0.
|
||||
func (j *jws) ProtectedAt(i int) jose.Protected {
|
||||
return j.sb[i].protected
|
||||
}
|
||||
|
||||
// Header returns the JWS' unprotected Header.
|
||||
func (j *jws) Header() jose.Header {
|
||||
return j.sb[0].unprotected
|
||||
}
|
||||
|
||||
// HeaderAt returns the JWS' unprotected Header.
|
||||
// |i| is the index of the unprotected Header.
|
||||
func (j *jws) HeaderAt(i int) jose.Header {
|
||||
return j.sb[i].unprotected
|
||||
}
|
||||
|
||||
// sigHead represents the 'signatures' member of the jws' "general"
|
||||
// serialization form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.1
|
||||
//
|
||||
// It's embedded inside the "flat" structure in order to properly
|
||||
// create the "flat" jws.
|
||||
type sigHead struct {
|
||||
Protected rawBase64 `json:"protected,omitempty"`
|
||||
Unprotected rawBase64 `json:"header,omitempty"`
|
||||
Signature crypto.Signature `json:"signature"`
|
||||
|
||||
protected jose.Protected
|
||||
unprotected jose.Header
|
||||
clean bool
|
||||
|
||||
method crypto.SigningMethod
|
||||
}
|
||||
|
||||
func (s *sigHead) unmarshal() error {
|
||||
if err := s.protected.UnmarshalJSON(s.Protected); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.unprotected.UnmarshalJSON(s.Unprotected)
|
||||
}
|
||||
|
||||
// New creates a JWS with the provided crypto.SigningMethods.
|
||||
func New(content interface{}, methods ...crypto.SigningMethod) JWS {
|
||||
sb := make([]sigHead, len(methods))
|
||||
for i := range methods {
|
||||
sb[i] = sigHead{
|
||||
protected: jose.Protected{
|
||||
"alg": methods[i].Alg(),
|
||||
},
|
||||
unprotected: jose.Header{},
|
||||
method: methods[i],
|
||||
}
|
||||
}
|
||||
return &jws{
|
||||
payload: &payload{v: content},
|
||||
sb: sb,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sigHead) assignMethod(p jose.Protected) error {
|
||||
alg, ok := p.Get("alg").(string)
|
||||
if !ok {
|
||||
return ErrNoAlgorithm
|
||||
}
|
||||
|
||||
sm := GetSigningMethod(alg)
|
||||
if sm == nil {
|
||||
return ErrNoAlgorithm
|
||||
}
|
||||
s.method = sm
|
||||
return nil
|
||||
}
|
||||
|
||||
type generic struct {
|
||||
Payload rawBase64 `json:"payload"`
|
||||
sigHead
|
||||
Signatures []sigHead `json:"signatures,omitempty"`
|
||||
}
|
||||
|
||||
// Parse parses any of the three serialized jws forms into a physical
|
||||
// jws per https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
//
|
||||
// It accepts a json.Unmarshaler in order to properly parse
|
||||
// the payload. In order to keep the caller from having to do extra
|
||||
// parsing of the payload, a json.Unmarshaler can be passed
|
||||
// which will be then to unmarshal the payload however the caller
|
||||
// wishes. Do note that if json.Unmarshal returns an error the
|
||||
// original payload will be used as if no json.Unmarshaler was
|
||||
// passed.
|
||||
//
|
||||
// Internally, Parse applies some heuristics and then calls either
|
||||
// ParseGeneral, ParseFlat, or ParseCompact.
|
||||
// It should only be called if, for whatever reason, you do not
|
||||
// know which form the serialized JWT is in.
|
||||
//
|
||||
// It cannot parse a JWT.
|
||||
func Parse(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
|
||||
// Try and unmarshal into a generic struct that'll
|
||||
// hopefully hold either of the two JSON serialization
|
||||
// formats.
|
||||
var g generic
|
||||
|
||||
// Not valid JSON. Let's try compact.
|
||||
if err := json.Unmarshal(encoded, &g); err != nil {
|
||||
return ParseCompact(encoded, u...)
|
||||
}
|
||||
|
||||
if g.Signatures == nil {
|
||||
return g.parseFlat(u...)
|
||||
}
|
||||
return g.parseGeneral(u...)
|
||||
}
|
||||
|
||||
// ParseGeneral parses a jws serialized into its "general" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.1
|
||||
// into a physical jws per
|
||||
// https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
//
|
||||
// For information on the json.Unmarshaler parameter, see Parse.
|
||||
func ParseGeneral(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
|
||||
var g generic
|
||||
if err := json.Unmarshal(encoded, &g); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.parseGeneral(u...)
|
||||
}
|
||||
|
||||
func (g *generic) parseGeneral(u ...json.Unmarshaler) (JWS, error) {
|
||||
|
||||
var p payload
|
||||
if len(u) > 0 {
|
||||
p.u = u[0]
|
||||
}
|
||||
|
||||
if err := p.UnmarshalJSON(g.Payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range g.Signatures {
|
||||
if err := g.Signatures[i].unmarshal(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := checkHeaders(jose.Header(g.Signatures[i].protected), g.Signatures[i].unprotected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := g.Signatures[i].assignMethod(g.Signatures[i].protected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
g.clean = len(g.Signatures) != 0
|
||||
|
||||
return &jws{
|
||||
payload: &p,
|
||||
plcache: g.Payload,
|
||||
clean: true,
|
||||
sb: g.Signatures,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseFlat parses a jws serialized into its "flat" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.2
|
||||
// into a physical jws per
|
||||
// https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
//
|
||||
// For information on the json.Unmarshaler parameter, see Parse.
|
||||
func ParseFlat(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
|
||||
var g generic
|
||||
if err := json.Unmarshal(encoded, &g); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.parseFlat(u...)
|
||||
}
|
||||
|
||||
func (g *generic) parseFlat(u ...json.Unmarshaler) (JWS, error) {
|
||||
|
||||
var p payload
|
||||
if len(u) > 0 {
|
||||
p.u = u[0]
|
||||
}
|
||||
|
||||
if err := p.UnmarshalJSON(g.Payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := g.sigHead.unmarshal(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
g.sigHead.clean = true
|
||||
|
||||
if err := checkHeaders(jose.Header(g.sigHead.protected), g.sigHead.unprotected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := g.sigHead.assignMethod(g.sigHead.protected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &jws{
|
||||
payload: &p,
|
||||
plcache: g.Payload,
|
||||
clean: true,
|
||||
sb: []sigHead{g.sigHead},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseCompact parses a jws serialized into its "compact" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.1
|
||||
// into a physical jws per
|
||||
// https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
//
|
||||
// For information on the json.Unmarshaler parameter, see Parse.
|
||||
func ParseCompact(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
|
||||
return parseCompact(encoded, false, u...)
|
||||
}
|
||||
|
||||
func parseCompact(encoded []byte, jwt bool, u ...json.Unmarshaler) (*jws, error) {
|
||||
|
||||
// This section loosely follows
|
||||
// https://tools.ietf.org/html/rfc7519#section-7.2
|
||||
// because it's used to parse _both_ jws and JWTs.
|
||||
|
||||
parts := bytes.Split(encoded, []byte{'.'})
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrNotCompact
|
||||
}
|
||||
|
||||
var p jose.Protected
|
||||
if err := p.UnmarshalJSON(parts[0]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := sigHead{
|
||||
Protected: parts[0],
|
||||
protected: p,
|
||||
Signature: parts[2],
|
||||
clean: true,
|
||||
}
|
||||
|
||||
if err := s.assignMethod(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pl payload
|
||||
if len(u) > 0 {
|
||||
pl.u = u[0]
|
||||
}
|
||||
|
||||
j := jws{
|
||||
payload: &pl,
|
||||
plcache: parts[1],
|
||||
sb: []sigHead{s},
|
||||
isJWT: jwt,
|
||||
}
|
||||
|
||||
if err := j.payload.UnmarshalJSON(parts[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
j.clean = true
|
||||
|
||||
if err := j.sb[0].Signature.UnmarshalJSON(parts[2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc7519#section-7.2.8
|
||||
cty, ok := p.Get("cty").(string)
|
||||
if ok && cty == "JWT" {
|
||||
return &j, ErrHoldsJWE
|
||||
}
|
||||
return &j, nil
|
||||
}
|
||||
|
||||
var (
|
||||
// JWSFormKey is the form "key" which should be used inside
|
||||
// ParseFromRequest if the request is a multipart.Form.
|
||||
JWSFormKey = "access_token"
|
||||
|
||||
// MaxMemory is maximum amount of memory which should be used
|
||||
// inside ParseFromRequest while parsing the multipart.Form
|
||||
// if the request is a multipart.Form.
|
||||
MaxMemory int64 = 10e6
|
||||
)
|
||||
|
||||
// Format specifies which "format" the JWS is in -- Flat, General,
|
||||
// or compact. Additionally, constants for JWT/Unknown are added.
|
||||
type Format uint8
|
||||
|
||||
const (
|
||||
// Unknown format.
|
||||
Unknown Format = iota
|
||||
|
||||
// Flat format.
|
||||
Flat
|
||||
|
||||
// General format.
|
||||
General
|
||||
|
||||
// Compact format.
|
||||
Compact
|
||||
)
|
||||
|
||||
var parseJumpTable = [...]func([]byte, ...json.Unmarshaler) (JWS, error){
|
||||
Unknown: Parse,
|
||||
Flat: ParseFlat,
|
||||
General: ParseGeneral,
|
||||
Compact: ParseCompact,
|
||||
1<<8 - 1: Parse, // Max uint8.
|
||||
}
|
||||
|
||||
func init() {
|
||||
for i := range parseJumpTable {
|
||||
if parseJumpTable[i] == nil {
|
||||
parseJumpTable[i] = Parse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func fromHeader(req *http.Request) ([]byte, bool) {
|
||||
if ah := req.Header.Get("Authorization"); len(ah) > 7 && strings.EqualFold(ah[0:7], "BEARER ") {
|
||||
return []byte(ah[7:]), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func fromForm(req *http.Request) ([]byte, bool) {
|
||||
if err := req.ParseMultipartForm(MaxMemory); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if tokStr := req.Form.Get(JWSFormKey); tokStr != "" {
|
||||
return []byte(tokStr), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ParseFromHeader tries to find the JWS in an http.Request header.
|
||||
func ParseFromHeader(req *http.Request, format Format, u ...json.Unmarshaler) (JWS, error) {
|
||||
if b, ok := fromHeader(req); ok {
|
||||
return parseJumpTable[format](b, u...)
|
||||
}
|
||||
return nil, ErrNoTokenInRequest
|
||||
}
|
||||
|
||||
// ParseFromForm tries to find the JWS in an http.Request form request.
|
||||
func ParseFromForm(req *http.Request, format Format, u ...json.Unmarshaler) (JWS, error) {
|
||||
if b, ok := fromForm(req); ok {
|
||||
return parseJumpTable[format](b, u...)
|
||||
}
|
||||
return nil, ErrNoTokenInRequest
|
||||
}
|
||||
|
||||
// ParseFromRequest tries to find the JWS in an http.Request.
|
||||
// This method will call ParseMultipartForm if there's no token in the header.
|
||||
func ParseFromRequest(req *http.Request, format Format, u ...json.Unmarshaler) (JWS, error) {
|
||||
token, err := ParseFromHeader(req, format, u...)
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
token, err = ParseFromForm(req, format, u...)
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// IgnoreDupes should be set to true if the internal duplicate header key check
|
||||
// should ignore duplicate Header keys instead of reporting an error when
|
||||
// duplicate Header keys are found.
|
||||
//
|
||||
// Note:
|
||||
// Duplicate Header keys are defined in
|
||||
// https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
// meaning keys that both the protected and unprotected
|
||||
// Headers possess.
|
||||
var IgnoreDupes bool
|
||||
|
||||
// checkHeaders returns an error per the constraints described in
|
||||
// IgnoreDupes' comment.
|
||||
func checkHeaders(a, b jose.Header) error {
|
||||
if len(a)+len(b) == 0 {
|
||||
return ErrTwoEmptyHeaders
|
||||
}
|
||||
for key := range a {
|
||||
if b.Has(key) && !IgnoreDupes {
|
||||
return ErrDuplicateHeaderParameter
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ JWS = (*jws)(nil)
|
|
@ -0,0 +1,132 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// Flat serializes the JWS to its "flattened" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.2
|
||||
func (j *jws) Flat(key interface{}) ([]byte, error) {
|
||||
if len(j.sb) < 1 {
|
||||
return nil, ErrNotEnoughMethods
|
||||
}
|
||||
if err := j.sign(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(struct {
|
||||
Payload rawBase64 `json:"payload"`
|
||||
sigHead
|
||||
}{
|
||||
Payload: j.plcache,
|
||||
sigHead: j.sb[0],
|
||||
})
|
||||
}
|
||||
|
||||
// General serializes the JWS into its "general" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.1
|
||||
//
|
||||
// If only one key is passed it's used for all the provided
|
||||
// crypto.SigningMethods. Otherwise, len(keys) must equal the number
|
||||
// of crypto.SigningMethods added.
|
||||
func (j *jws) General(keys ...interface{}) ([]byte, error) {
|
||||
if err := j.sign(keys...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(struct {
|
||||
Payload rawBase64 `json:"payload"`
|
||||
Signatures []sigHead `json:"signatures"`
|
||||
}{
|
||||
Payload: j.plcache,
|
||||
Signatures: j.sb,
|
||||
})
|
||||
}
|
||||
|
||||
// Compact serializes the JWS into its "compact" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.1
|
||||
func (j *jws) Compact(key interface{}) ([]byte, error) {
|
||||
if len(j.sb) < 1 {
|
||||
return nil, ErrNotEnoughMethods
|
||||
}
|
||||
|
||||
if err := j.sign(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sig, err := j.sb[0].Signature.Base64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return format(
|
||||
j.sb[0].Protected,
|
||||
j.plcache,
|
||||
sig,
|
||||
), nil
|
||||
}
|
||||
|
||||
// sign signs each index of j's sb member.
|
||||
func (j *jws) sign(keys ...interface{}) error {
|
||||
if err := j.cache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(keys) < 1 ||
|
||||
len(keys) > 1 && len(keys) != len(j.sb) {
|
||||
return ErrNotEnoughKeys
|
||||
}
|
||||
|
||||
if len(keys) == 1 {
|
||||
k := keys[0]
|
||||
keys = make([]interface{}, len(j.sb))
|
||||
for i := range keys {
|
||||
keys[i] = k
|
||||
}
|
||||
}
|
||||
|
||||
for i := range j.sb {
|
||||
if err := j.sb[i].cache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
raw := format(j.sb[i].Protected, j.plcache)
|
||||
sig, err := j.sb[i].method.Sign(raw, keys[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
j.sb[i].Signature = sig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cache marshals the payload, but only if it's changed since the last cache.
|
||||
func (j *jws) cache() (err error) {
|
||||
if !j.clean {
|
||||
j.plcache, err = j.payload.Base64()
|
||||
j.clean = err == nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// cache marshals the protected and unprotected headers, but only if
|
||||
// they've changed since their last cache.
|
||||
func (s *sigHead) cache() (err error) {
|
||||
if !s.clean {
|
||||
s.Protected, err = s.protected.Base64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Unprotected, err = s.unprotected.Base64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.clean = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// format formats a slice of bytes in the order given, joining
|
||||
// them with a period.
|
||||
func format(a ...[]byte) []byte {
|
||||
return bytes.Join(a, []byte{'.'})
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue