Merge pull request #527 from prometheus/nogodep-build

Copy vendored deps manually instead of using Godeps.
pull/537/head
Julius Volz 2015-02-17 17:17:45 +01:00
commit b0c8b56603
192 changed files with 20695 additions and 1064 deletions

2
.gitignore vendored
View File

@ -26,7 +26,7 @@ _cgo_*
core
*-stamp
prometheus
/prometheus
benchmark.txt
.#*

56
Godeps/Godeps.json generated
View File

@ -1,27 +1,65 @@
{
"ImportPath": "github.com/prometheus/prometheus",
"GoVersion": "go1.4",
"GoVersion": "go1.4.1",
"Deps": [
{
"ImportPath": "code.google.com/p/goprotobuf/proto",
"Comment": "go.r60-152",
"Rev": "36be16571e14f67e114bb0af619e5de2c1591679"
},
{
"ImportPath": "github.com/golang/glog",
"Rev": "44145f04b68cf362d9c4df2182967c2275eaefed"
},
{
"ImportPath": "github.com/golang/protobuf/proto",
"Rev": "5677a0e3d5e89854c9974e1256839ee23f8233ca"
},
{
"ImportPath": "github.com/matttproud/golang_protobuf_extensions/ext",
"Rev": "7a864a042e844af638df17ebbabf8183dace556a"
"Rev": "ba7d65ac66e9da93a714ca18f6d1bc7a0c09100c"
},
{
"ImportPath": "github.com/miekg/dns",
"Rev": "6b75215519f9916839204d80413bb178b94ef769"
"Rev": "b65f52f3f0dd1afa25cbbf63f8e7eb15fb5c0641"
},
{
"ImportPath": "github.com/prometheus/client_golang/_vendor/goautoneg",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/_vendor/perks/quantile",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/extraction",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/model",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/prometheus",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/text",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_model/go",
"Comment": "model-0.0.2-12-gfa8ad6f",
"Rev": "fa8ad6fec33561be4280a8f0514318c79d7f6cb6"
},
{
"ImportPath": "github.com/prometheus/procfs",
"Rev": "92faa308558161acab0ada1db048e9996ecec160"
},
{
"ImportPath": "github.com/syndtr/goleveldb/leveldb",
"Rev": "63c9e642efad852f49e20a6f90194cae112fd2ac"
"Rev": "e9e2c8f6d3b9c313fb4acaac5ab06285bcf30b04"
},
{
"ImportPath": "github.com/syndtr/gosnappy/snappy",

View File

@ -1,7 +1,7 @@
# Go support for Protocol Buffers - Google's data interchange format
#
# Copyright 2010 The Go Authors. All rights reserved.
# http://code.google.com/p/goprotobuf/
# https://github.com/golang/protobuf
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
@ -37,4 +37,7 @@ test: install generate-test-pbs
generate-test-pbs:
make install && cd testdata && make
make install
make -C testdata
make -C proto3_proto
make

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -45,7 +45,7 @@ import (
"time"
. "./testdata"
. "code.google.com/p/goprotobuf/proto"
. "github.com/golang/protobuf/proto"
)
var globalO *Buffer
@ -1833,6 +1833,86 @@ func fuzzUnmarshal(t *testing.T, data []byte) {
Unmarshal(data, pb)
}
func TestMapFieldMarshal(t *testing.T) {
m := &MessageWithMap{
NameMapping: map[int32]string{
1: "Rob",
4: "Ian",
8: "Dave",
},
}
b, err := Marshal(m)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
// b should be the concatenation of these three byte sequences in some order.
parts := []string{
"\n\a\b\x01\x12\x03Rob",
"\n\a\b\x04\x12\x03Ian",
"\n\b\b\x08\x12\x04Dave",
}
ok := false
for i := range parts {
for j := range parts {
if j == i {
continue
}
for k := range parts {
if k == i || k == j {
continue
}
try := parts[i] + parts[j] + parts[k]
if bytes.Equal(b, []byte(try)) {
ok = true
break
}
}
}
}
if !ok {
t.Fatalf("Incorrect Marshal output.\n got %q\nwant %q (or a permutation of that)", b, parts[0]+parts[1]+parts[2])
}
t.Logf("FYI b: %q", b)
(new(Buffer)).DebugPrint("Dump of b", b)
}
func TestMapFieldRoundTrips(t *testing.T) {
m := &MessageWithMap{
NameMapping: map[int32]string{
1: "Rob",
4: "Ian",
8: "Dave",
},
MsgMapping: map[int64]*FloatingPoint{
0x7001: &FloatingPoint{F: Float64(2.0)},
},
ByteMapping: map[bool][]byte{
false: []byte("that's not right!"),
true: []byte("aye, 'tis true!"),
},
}
b, err := Marshal(m)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
t.Logf("FYI b: %q", b)
m2 := new(MessageWithMap)
if err := Unmarshal(b, m2); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
for _, pair := range [][2]interface{}{
{m.NameMapping, m2.NameMapping},
{m.MsgMapping, m2.MsgMapping},
{m.ByteMapping, m2.ByteMapping},
} {
if !reflect.DeepEqual(pair[0], pair[1]) {
t.Errorf("Map did not survive a round trip.\ninitial: %v\n final: %v", pair[0], pair[1])
}
}
}
// Benchmarks
func testMsg() *GoTest {

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer deep copy.
// Protocol buffer deep copy and merge.
// TODO: MessageSet and RawMessage.
package proto
@ -113,6 +113,29 @@ func mergeAny(out, in reflect.Value) {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64:
out.Set(in)
case reflect.Map:
if in.Len() == 0 {
return
}
if out.IsNil() {
out.Set(reflect.MakeMap(in.Type()))
}
// For maps with value types of *T or []byte we need to deep copy each value.
elemKind := in.Type().Elem().Kind()
for _, key := range in.MapKeys() {
var val reflect.Value
switch elemKind {
case reflect.Ptr:
val = reflect.New(in.Type().Elem().Elem())
mergeAny(val, in.MapIndex(key))
case reflect.Slice:
val = in.MapIndex(key)
val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
default:
val = in.MapIndex(key)
}
out.SetMapIndex(key, val)
}
case reflect.Ptr:
if in.IsNil() {
return
@ -125,6 +148,14 @@ func mergeAny(out, in reflect.Value) {
if in.IsNil() {
return
}
if in.Type().Elem().Kind() == reflect.Uint8 {
// []byte is a scalar bytes field, not a repeated field.
// Make a deep copy.
// Append to []byte{} instead of []byte(nil) so that we never end up
// with a nil result.
out.SetBytes(append([]byte{}, in.Bytes()...))
return
}
n := in.Len()
if out.IsNil() {
out.Set(reflect.MakeSlice(in.Type(), 0, n))
@ -133,9 +164,6 @@ func mergeAny(out, in reflect.Value) {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64:
out.Set(reflect.AppendSlice(out, in))
case reflect.Uint8:
// []byte is a scalar bytes field.
out.Set(in)
default:
for i := 0; i < n; i++ {
x := reflect.Indirect(reflect.New(in.Type().Elem()))

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -34,7 +34,7 @@ package proto_test
import (
"testing"
"code.google.com/p/goprotobuf/proto"
"github.com/golang/protobuf/proto"
pb "./testdata"
)
@ -79,6 +79,22 @@ func TestClone(t *testing.T) {
if proto.Equal(m, cloneTestMessage) {
t.Error("Mutating clone changed the original")
}
// Byte fields and repeated fields should be copied.
if &m.Pet[0] == &cloneTestMessage.Pet[0] {
t.Error("Pet: repeated field not copied")
}
if &m.Others[0] == &cloneTestMessage.Others[0] {
t.Error("Others: repeated field not copied")
}
if &m.Others[0].Value[0] == &cloneTestMessage.Others[0].Value[0] {
t.Error("Others[0].Value: bytes field not copied")
}
if &m.RepBytes[0] == &cloneTestMessage.RepBytes[0] {
t.Error("RepBytes: repeated field not copied")
}
if &m.RepBytes[0][0] == &cloneTestMessage.RepBytes[0][0] {
t.Error("RepBytes[0]: bytes field not copied")
}
}
func TestCloneNil(t *testing.T) {
@ -173,6 +189,31 @@ var mergeTests = []struct {
dst: &pb.OtherMessage{Value: []byte("bar")},
want: &pb.OtherMessage{Value: []byte("foo")},
},
{
src: &pb.MessageWithMap{
NameMapping: map[int32]string{6: "Nigel"},
MsgMapping: map[int64]*pb.FloatingPoint{
0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
},
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
},
dst: &pb.MessageWithMap{
NameMapping: map[int32]string{
6: "Bruce", // should be overwritten
7: "Andrew",
},
},
want: &pb.MessageWithMap{
NameMapping: map[int32]string{
6: "Nigel",
7: "Andrew",
},
MsgMapping: map[int64]*pb.FloatingPoint{
0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
},
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
},
},
}
func TestMerge(t *testing.T) {

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -178,7 +178,7 @@ func (p *Buffer) DecodeZigzag32() (x uint64, err error) {
func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
n, err := p.DecodeVarint()
if err != nil {
return
return nil, err
}
nb := int(n)
@ -362,7 +362,7 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
}
tag := int(u >> 3)
if tag <= 0 {
return fmt.Errorf("proto: %s: illegal tag %d", st, tag)
return fmt.Errorf("proto: %s: illegal tag %d (wire type %d)", st, tag, wire)
}
fieldnum, ok := prop.decoderTags.get(tag)
if !ok {
@ -465,6 +465,15 @@ func (o *Buffer) dec_bool(p *Properties, base structPointer) error {
return nil
}
func (o *Buffer) dec_proto3_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
*structPointer_BoolVal(base, p.field) = u != 0
return nil
}
// Decode an int32.
func (o *Buffer) dec_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
@ -475,6 +484,15 @@ func (o *Buffer) dec_int32(p *Properties, base structPointer) error {
return nil
}
func (o *Buffer) dec_proto3_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word32Val_Set(structPointer_Word32Val(base, p.field), uint32(u))
return nil
}
// Decode an int64.
func (o *Buffer) dec_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
@ -485,15 +503,31 @@ func (o *Buffer) dec_int64(p *Properties, base structPointer) error {
return nil
}
func (o *Buffer) dec_proto3_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word64Val_Set(structPointer_Word64Val(base, p.field), o, u)
return nil
}
// Decode a string.
func (o *Buffer) dec_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
sp := new(string)
*sp = s
*structPointer_String(base, p.field) = sp
*structPointer_String(base, p.field) = &s
return nil
}
func (o *Buffer) dec_proto3_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
*structPointer_StringVal(base, p.field) = s
return nil
}
@ -632,6 +666,72 @@ func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error {
return nil
}
// Decode a map field.
func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
raw, err := o.DecodeRawBytes(false)
if err != nil {
return err
}
oi := o.index // index at the end of this map entry
o.index -= len(raw) // move buffer back to start of map entry
mptr := structPointer_Map(base, p.field, p.mtype) // *map[K]V
if mptr.Elem().IsNil() {
mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem()))
}
v := mptr.Elem() // map[K]V
// Prepare addressable doubly-indirect placeholders for the key and value types.
// See enc_new_map for why.
keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K
keybase := toStructPointer(keyptr.Addr()) // **K
var valbase structPointer
var valptr reflect.Value
switch p.mtype.Elem().Kind() {
case reflect.Slice:
// []byte
var dummy []byte
valptr = reflect.ValueOf(&dummy) // *[]byte
valbase = toStructPointer(valptr) // *[]byte
case reflect.Ptr:
// message; valptr is **Msg; need to allocate the intermediate pointer
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
valptr.Set(reflect.New(valptr.Type().Elem()))
valbase = toStructPointer(valptr)
default:
// everything else
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
valbase = toStructPointer(valptr.Addr()) // **V
}
// Decode.
// This parses a restricted wire format, namely the encoding of a message
// with two fields. See enc_new_map for the format.
for o.index < oi {
// tagcode for key and value properties are always a single byte
// because they have tags 1 and 2.
tagcode := o.buf[o.index]
o.index++
switch tagcode {
case p.mkeyprop.tagcode[0]:
if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil {
return err
}
case p.mvalprop.tagcode[0]:
if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil {
return err
}
default:
// TODO: Should we silently skip this instead?
return fmt.Errorf("proto: bad map data tag %d", raw[0])
}
}
v.SetMapIndex(keyptr.Elem(), valptr.Elem())
return nil
}
// Decode a group.
func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error {
bas := structPointer_GetStructPointer(base, p.field)

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -247,7 +247,7 @@ func (p *Buffer) Marshal(pb Message) error {
return ErrNil
}
if err == nil {
err = p.enc_struct(t.Elem(), GetProperties(t.Elem()), base)
err = p.enc_struct(GetProperties(t.Elem()), base)
}
if collectStats {
@ -271,7 +271,7 @@ func Size(pb Message) (n int) {
return 0
}
if err == nil {
n = size_struct(t.Elem(), GetProperties(t.Elem()), base)
n = size_struct(GetProperties(t.Elem()), base)
}
if collectStats {
@ -298,6 +298,16 @@ func (o *Buffer) enc_bool(p *Properties, base structPointer) error {
return nil
}
func (o *Buffer) enc_proto3_bool(p *Properties, base structPointer) error {
v := *structPointer_BoolVal(base, p.field)
if !v {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
p.valEnc(o, 1)
return nil
}
func size_bool(p *Properties, base structPointer) int {
v := *structPointer_Bool(base, p.field)
if v == nil {
@ -306,6 +316,14 @@ func size_bool(p *Properties, base structPointer) int {
return len(p.tagcode) + 1 // each bool takes exactly one byte
}
func size_proto3_bool(p *Properties, base structPointer) int {
v := *structPointer_BoolVal(base, p.field)
if !v {
return 0
}
return len(p.tagcode) + 1 // each bool takes exactly one byte
}
// Encode an int32.
func (o *Buffer) enc_int32(p *Properties, base structPointer) error {
v := structPointer_Word32(base, p.field)
@ -318,6 +336,17 @@ func (o *Buffer) enc_int32(p *Properties, base structPointer) error {
return nil
}
func (o *Buffer) enc_proto3_int32(p *Properties, base structPointer) error {
v := structPointer_Word32Val(base, p.field)
x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range
if x == 0 {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
p.valEnc(o, uint64(x))
return nil
}
func size_int32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32(base, p.field)
if word32_IsNil(v) {
@ -329,6 +358,17 @@ func size_int32(p *Properties, base structPointer) (n int) {
return
}
func size_proto3_int32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32Val(base, p.field)
x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range
if x == 0 {
return 0
}
n += len(p.tagcode)
n += p.valSize(uint64(x))
return
}
// Encode a uint32.
// Exactly the same as int32, except for no sign extension.
func (o *Buffer) enc_uint32(p *Properties, base structPointer) error {
@ -342,6 +382,17 @@ func (o *Buffer) enc_uint32(p *Properties, base structPointer) error {
return nil
}
func (o *Buffer) enc_proto3_uint32(p *Properties, base structPointer) error {
v := structPointer_Word32Val(base, p.field)
x := word32Val_Get(v)
if x == 0 {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
p.valEnc(o, uint64(x))
return nil
}
func size_uint32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32(base, p.field)
if word32_IsNil(v) {
@ -353,6 +404,17 @@ func size_uint32(p *Properties, base structPointer) (n int) {
return
}
func size_proto3_uint32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32Val(base, p.field)
x := word32Val_Get(v)
if x == 0 {
return 0
}
n += len(p.tagcode)
n += p.valSize(uint64(x))
return
}
// Encode an int64.
func (o *Buffer) enc_int64(p *Properties, base structPointer) error {
v := structPointer_Word64(base, p.field)
@ -365,6 +427,17 @@ func (o *Buffer) enc_int64(p *Properties, base structPointer) error {
return nil
}
func (o *Buffer) enc_proto3_int64(p *Properties, base structPointer) error {
v := structPointer_Word64Val(base, p.field)
x := word64Val_Get(v)
if x == 0 {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
p.valEnc(o, x)
return nil
}
func size_int64(p *Properties, base structPointer) (n int) {
v := structPointer_Word64(base, p.field)
if word64_IsNil(v) {
@ -376,6 +449,17 @@ func size_int64(p *Properties, base structPointer) (n int) {
return
}
func size_proto3_int64(p *Properties, base structPointer) (n int) {
v := structPointer_Word64Val(base, p.field)
x := word64Val_Get(v)
if x == 0 {
return 0
}
n += len(p.tagcode)
n += p.valSize(x)
return
}
// Encode a string.
func (o *Buffer) enc_string(p *Properties, base structPointer) error {
v := *structPointer_String(base, p.field)
@ -388,6 +472,16 @@ func (o *Buffer) enc_string(p *Properties, base structPointer) error {
return nil
}
func (o *Buffer) enc_proto3_string(p *Properties, base structPointer) error {
v := *structPointer_StringVal(base, p.field)
if v == "" {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
o.EncodeStringBytes(v)
return nil
}
func size_string(p *Properties, base structPointer) (n int) {
v := *structPointer_String(base, p.field)
if v == nil {
@ -399,6 +493,16 @@ func size_string(p *Properties, base structPointer) (n int) {
return
}
func size_proto3_string(p *Properties, base structPointer) (n int) {
v := *structPointer_StringVal(base, p.field)
if v == "" {
return 0
}
n += len(p.tagcode)
n += sizeStringBytes(v)
return
}
// All protocol buffer fields are nillable, but be careful.
func isNil(v reflect.Value) bool {
switch v.Kind() {
@ -429,7 +533,7 @@ func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error {
}
o.buf = append(o.buf, p.tagcode...)
return o.enc_len_struct(p.stype, p.sprop, structp, &state)
return o.enc_len_struct(p.sprop, structp, &state)
}
func size_struct_message(p *Properties, base structPointer) int {
@ -448,7 +552,7 @@ func size_struct_message(p *Properties, base structPointer) int {
}
n0 := len(p.tagcode)
n1 := size_struct(p.stype, p.sprop, structp)
n1 := size_struct(p.sprop, structp)
n2 := sizeVarint(uint64(n1)) // size of encoded length
return n0 + n1 + n2
}
@ -462,7 +566,7 @@ func (o *Buffer) enc_struct_group(p *Properties, base structPointer) error {
}
o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
err := o.enc_struct(p.stype, p.sprop, b)
err := o.enc_struct(p.sprop, b)
if err != nil && !state.shouldContinue(err, nil) {
return err
}
@ -477,7 +581,7 @@ func size_struct_group(p *Properties, base structPointer) (n int) {
}
n += sizeVarint(uint64((p.Tag << 3) | WireStartGroup))
n += size_struct(p.stype, p.sprop, b)
n += size_struct(p.sprop, b)
n += sizeVarint(uint64((p.Tag << 3) | WireEndGroup))
return
}
@ -551,6 +655,16 @@ func (o *Buffer) enc_slice_byte(p *Properties, base structPointer) error {
return nil
}
func (o *Buffer) enc_proto3_slice_byte(p *Properties, base structPointer) error {
s := *structPointer_Bytes(base, p.field)
if len(s) == 0 {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
o.EncodeRawBytes(s)
return nil
}
func size_slice_byte(p *Properties, base structPointer) (n int) {
s := *structPointer_Bytes(base, p.field)
if s == nil {
@ -561,6 +675,16 @@ func size_slice_byte(p *Properties, base structPointer) (n int) {
return
}
func size_proto3_slice_byte(p *Properties, base structPointer) (n int) {
s := *structPointer_Bytes(base, p.field)
if len(s) == 0 {
return 0
}
n += len(p.tagcode)
n += sizeRawBytes(s)
return
}
// Encode a slice of int32s ([]int32).
func (o *Buffer) enc_slice_int32(p *Properties, base structPointer) error {
s := structPointer_Word32Slice(base, p.field)
@ -831,7 +955,7 @@ func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) err
}
o.buf = append(o.buf, p.tagcode...)
err := o.enc_len_struct(p.stype, p.sprop, structp, &state)
err := o.enc_len_struct(p.sprop, structp, &state)
if err != nil && !state.shouldContinue(err, nil) {
if err == ErrNil {
return ErrRepeatedHasNil
@ -861,7 +985,7 @@ func size_slice_struct_message(p *Properties, base structPointer) (n int) {
continue
}
n0 := size_struct(p.stype, p.sprop, structp)
n0 := size_struct(p.sprop, structp)
n1 := sizeVarint(uint64(n0)) // size of encoded length
n += n0 + n1
}
@ -882,7 +1006,7 @@ func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error
o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
err := o.enc_struct(p.stype, p.sprop, b)
err := o.enc_struct(p.sprop, b)
if err != nil && !state.shouldContinue(err, nil) {
if err == ErrNil {
@ -908,7 +1032,7 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) {
return // return size up to this point
}
n += size_struct(p.stype, p.sprop, b)
n += size_struct(p.sprop, b)
}
return
}
@ -945,12 +1069,112 @@ func size_map(p *Properties, base structPointer) int {
return sizeExtensionMap(v)
}
// Encode a map field.
func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
var state errorState // XXX: or do we need to plumb this through?
/*
A map defined as
map<key_type, value_type> map_field = N;
is encoded in the same way as
message MapFieldEntry {
key_type key = 1;
value_type value = 2;
}
repeated MapFieldEntry map_field = N;
*/
v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
if v.Len() == 0 {
return nil
}
keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype)
enc := func() error {
if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil {
return err
}
if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil {
return err
}
return nil
}
keys := v.MapKeys()
sort.Sort(mapKeys(keys))
for _, key := range keys {
val := v.MapIndex(key)
keycopy.Set(key)
valcopy.Set(val)
o.buf = append(o.buf, p.tagcode...)
if err := o.enc_len_thing(enc, &state); err != nil {
return err
}
}
return nil
}
func size_new_map(p *Properties, base structPointer) int {
v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype)
n := 0
for _, key := range v.MapKeys() {
val := v.MapIndex(key)
keycopy.Set(key)
valcopy.Set(val)
// Tag codes are two bytes per map entry.
n += 2
n += p.mkeyprop.size(p.mkeyprop, keybase)
n += p.mvalprop.size(p.mvalprop, valbase)
}
return n
}
// mapEncodeScratch returns a new reflect.Value matching the map's value type,
// and a structPointer suitable for passing to an encoder or sizer.
func mapEncodeScratch(mapType reflect.Type) (keycopy, valcopy reflect.Value, keybase, valbase structPointer) {
// Prepare addressable doubly-indirect placeholders for the key and value types.
// This is needed because the element-type encoders expect **T, but the map iteration produces T.
keycopy = reflect.New(mapType.Key()).Elem() // addressable K
keyptr := reflect.New(reflect.PtrTo(keycopy.Type())).Elem() // addressable *K
keyptr.Set(keycopy.Addr()) //
keybase = toStructPointer(keyptr.Addr()) // **K
// Value types are more varied and require special handling.
switch mapType.Elem().Kind() {
case reflect.Slice:
// []byte
var dummy []byte
valcopy = reflect.ValueOf(&dummy).Elem() // addressable []byte
valbase = toStructPointer(valcopy.Addr())
case reflect.Ptr:
// message; the generated field type is map[K]*Msg (so V is *Msg),
// so we only need one level of indirection.
valcopy = reflect.New(mapType.Elem()).Elem() // addressable V
valbase = toStructPointer(valcopy.Addr())
default:
// everything else
valcopy = reflect.New(mapType.Elem()).Elem() // addressable V
valptr := reflect.New(reflect.PtrTo(valcopy.Type())).Elem() // addressable *V
valptr.Set(valcopy.Addr()) //
valbase = toStructPointer(valptr.Addr()) // **V
}
return
}
// Encode a struct.
func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structPointer) error {
func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
var state errorState
// Encode fields in tag order so that decoders may use optimizations
// that depend on the ordering.
// http://code.google.com/apis/protocolbuffers/docs/encoding.html#order
// https://developers.google.com/protocol-buffers/docs/encoding#order
for _, i := range prop.order {
p := prop.Prop[i]
if p.enc != nil {
@ -978,7 +1202,7 @@ func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structP
return state.err
}
func size_struct(t reflect.Type, prop *StructProperties, base structPointer) (n int) {
func size_struct(prop *StructProperties, base structPointer) (n int) {
for _, i := range prop.order {
p := prop.Prop[i]
if p.size != nil {
@ -998,11 +1222,16 @@ func size_struct(t reflect.Type, prop *StructProperties, base structPointer) (n
var zeroes [20]byte // longer than any conceivable sizeVarint
// Encode a struct, preceded by its encoded length (as a varint).
func (o *Buffer) enc_len_struct(t reflect.Type, prop *StructProperties, base structPointer, state *errorState) error {
func (o *Buffer) enc_len_struct(prop *StructProperties, base structPointer, state *errorState) error {
return o.enc_len_thing(func() error { return o.enc_struct(prop, base) }, state)
}
// Encode something, preceded by its encoded length (as a varint).
func (o *Buffer) enc_len_thing(enc func() error, state *errorState) error {
iLen := len(o.buf)
o.buf = append(o.buf, 0, 0, 0, 0) // reserve four bytes for length
iMsg := len(o.buf)
err := o.enc_struct(t, prop, base)
err := enc()
if err != nil && !state.shouldContinue(err, nil) {
return err
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -57,7 +57,7 @@ Equality is defined in this way:
although represented by []byte, is not a repeated field)
- Two unset fields are equal.
- Two unknown field sets are equal if their current
encoded state is equal. (TODO)
encoded state is equal.
- Two extension sets are equal iff they have corresponding
elements that are pairwise equal.
- Every other combination of things are not equal.
@ -154,6 +154,21 @@ func equalAny(v1, v2 reflect.Value) bool {
return v1.Float() == v2.Float()
case reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
case reflect.Map:
if v1.Len() != v2.Len() {
return false
}
for _, key := range v1.MapKeys() {
val2 := v2.MapIndex(key)
if !val2.IsValid() {
// This key was not found in the second map.
return false
}
if !equalAny(v1.MapIndex(key), val2) {
return false
}
}
return true
case reflect.Ptr:
return equalAny(v1.Elem(), v2.Elem())
case reflect.Slice:

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -35,7 +35,7 @@ import (
"testing"
pb "./testdata"
. "code.google.com/p/goprotobuf/proto"
. "github.com/golang/protobuf/proto"
)
// Four identical base messages.
@ -155,6 +155,31 @@ var EqualTests = []struct {
},
true,
},
{
"map same",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
true,
},
{
"map different entry",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{2: "Rob"}},
false,
},
{
"map different key only",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{2: "Ken"}},
false,
},
{
"map different value only",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}},
false,
},
}
func TestEqual(t *testing.T) {

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -227,7 +227,8 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, er
return nil, err
}
e, ok := pb.ExtensionMap()[extension.Field]
emap := pb.ExtensionMap()
e, ok := emap[extension.Field]
if !ok {
return nil, ErrMissingExtension
}
@ -252,6 +253,7 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, er
e.value = v
e.desc = extension
e.enc = nil
emap[extension.Field] = e
return e.value, nil
}

View File

@ -0,0 +1,137 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"testing"
pb "./testdata"
"github.com/golang/protobuf/proto"
)
func TestGetExtensionsWithMissingExtensions(t *testing.T) {
msg := &pb.MyMessage{}
ext1 := &pb.Ext{}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
t.Fatalf("Could not set ext1: %s", ext1)
}
exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
pb.E_Ext_More,
pb.E_Ext_Text,
})
if err != nil {
t.Fatalf("GetExtensions() failed: %s", err)
}
if exts[0] != ext1 {
t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
}
if exts[1] != nil {
t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
}
}
func TestGetExtensionStability(t *testing.T) {
check := func(m *pb.MyMessage) bool {
ext1, err := proto.GetExtension(m, pb.E_Ext_More)
if err != nil {
t.Fatalf("GetExtension() failed: %s", err)
}
ext2, err := proto.GetExtension(m, pb.E_Ext_More)
if err != nil {
t.Fatalf("GetExtension() failed: %s", err)
}
return ext1 == ext2
}
msg := &pb.MyMessage{Count: proto.Int32(4)}
ext0 := &pb.Ext{}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
t.Fatalf("Could not set ext1: %s", ext0)
}
if !check(msg) {
t.Errorf("GetExtension() not stable before marshaling")
}
bb, err := proto.Marshal(msg)
if err != nil {
t.Fatalf("Marshal() failed: %s", err)
}
msg1 := &pb.MyMessage{}
err = proto.Unmarshal(bb, msg1)
if err != nil {
t.Fatalf("Unmarshal() failed: %s", err)
}
if !check(msg1) {
t.Errorf("GetExtension() not stable after unmarshaling")
}
}
func TestExtensionsRoundTrip(t *testing.T) {
msg := &pb.MyMessage{}
ext1 := &pb.Ext{
Data: proto.String("hi"),
}
ext2 := &pb.Ext{
Data: proto.String("there"),
}
exists := proto.HasExtension(msg, pb.E_Ext_More)
if exists {
t.Error("Extension More present unexpectedly")
}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
t.Error(err)
}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
t.Error(err)
}
e, err := proto.GetExtension(msg, pb.E_Ext_More)
if err != nil {
t.Error(err)
}
x, ok := e.(*pb.Ext)
if !ok {
t.Errorf("e has type %T, expected testdata.Ext", e)
} else if *x.Data != "there" {
t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
}
proto.ClearExtension(msg, pb.E_Ext_More)
if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
t.Errorf("got %v, expected ErrMissingExtension", e)
}
if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
t.Error("expected bad extension error, got nil")
}
if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
t.Error("expected extension err")
}
if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
t.Error("expected some sort of type mismatch error, got nil")
}
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -50,17 +50,16 @@
That is, optional or required field int32 f becomes F *int32.
- Repeated fields are slices.
- Helper functions are available to aid the setting of fields.
Helpers for getting values are superseded by the
GetFoo methods and their use is deprecated.
msg.Foo = proto.String("hello") // set field
- Constants are defined to hold the default values of all fields that
have them. They have the form Default_StructName_FieldName.
Because the getter methods handle defaulted values,
direct use of these constants should be rare.
- Enums are given type names and maps from names to values.
Enum values are prefixed with the enum's type name. Enum types have
a String method, and a Enum method to assist in message construction.
- Nested groups and enums have type names prefixed with the name of
Enum values are prefixed by the enclosing message's name, or by the
enum's type name if it is a top-level enum. Enum types have a String
method, and a Enum method to assist in message construction.
- Nested messages, groups and enums have type names prefixed with the name of
the surrounding message type.
- Extensions are given descriptor names that start with E_,
followed by an underscore-delimited list of the nested messages
@ -74,7 +73,7 @@
package example;
enum FOO { X = 17; };
enum FOO { X = 17; }
message Test {
required string label = 1;
@ -89,7 +88,8 @@
package example
import "code.google.com/p/goprotobuf/proto"
import proto "github.com/golang/protobuf/proto"
import math "math"
type FOO int32
const (
@ -110,6 +110,14 @@
func (x FOO) String() string {
return proto.EnumName(FOO_name, int32(x))
}
func (x *FOO) UnmarshalJSON(data []byte) error {
value, err := proto.UnmarshalJSONEnum(FOO_value, data)
if err != nil {
return err
}
*x = FOO(value)
return nil
}
type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
@ -118,41 +126,41 @@
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (this *Test) Reset() { *this = Test{} }
func (this *Test) String() string { return proto.CompactTextString(this) }
func (m *Test) Reset() { *m = Test{} }
func (m *Test) String() string { return proto.CompactTextString(m) }
func (*Test) ProtoMessage() {}
const Default_Test_Type int32 = 77
func (this *Test) GetLabel() string {
if this != nil && this.Label != nil {
return *this.Label
func (m *Test) GetLabel() string {
if m != nil && m.Label != nil {
return *m.Label
}
return ""
}
func (this *Test) GetType() int32 {
if this != nil && this.Type != nil {
return *this.Type
func (m *Test) GetType() int32 {
if m != nil && m.Type != nil {
return *m.Type
}
return Default_Test_Type
}
func (this *Test) GetOptionalgroup() *Test_OptionalGroup {
if this != nil {
return this.Optionalgroup
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
if m != nil {
return m.Optionalgroup
}
return nil
}
type Test_OptionalGroup struct {
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (this *Test_OptionalGroup) Reset() { *this = Test_OptionalGroup{} }
func (this *Test_OptionalGroup) String() string { return proto.CompactTextString(this) }
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
func (this *Test_OptionalGroup) GetRequiredField() string {
if this != nil && this.RequiredField != nil {
return *this.RequiredField
func (m *Test_OptionalGroup) GetRequiredField() string {
if m != nil && m.RequiredField != nil {
return *m.RequiredField
}
return ""
}
@ -168,15 +176,15 @@
import (
"log"
"code.google.com/p/goprotobuf/proto"
"./example.pb"
"github.com/golang/protobuf/proto"
pb "./example.pb"
)
func main() {
test := &example.Test{
test := &pb.Test{
Label: proto.String("hello"),
Type: proto.Int32(17),
Optionalgroup: &example.Test_OptionalGroup{
Optionalgroup: &pb.Test_OptionalGroup{
RequiredField: proto.String("good bye"),
},
}
@ -184,7 +192,7 @@
if err != nil {
log.Fatal("marshaling error: ", err)
}
newTest := new(example.Test)
newTest := &pb.Test{}
err = proto.Unmarshal(data, newTest)
if err != nil {
log.Fatal("unmarshaling error: ", err)
@ -323,9 +331,7 @@ func Float64(v float64) *float64 {
// Uint32 is a helper routine that allocates a new uint32 value
// to store v and returns a pointer to it.
func Uint32(v uint32) *uint32 {
p := new(uint32)
*p = v
return p
return &v
}
// Uint64 is a helper routine that allocates a new uint64 value
@ -738,3 +744,16 @@ func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
return dm
}
// Map fields may have key types of non-float scalars, strings and enums.
// The easiest way to sort them in some deterministic order is to use fmt.
// If this turns out to be inefficient we can always consider other options,
// such as doing a Schwartzian transform.
type mapKeys []reflect.Value
func (s mapKeys) Len() int { return len(s) }
func (s mapKeys) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s mapKeys) Less(i, j int) bool {
return fmt.Sprint(s[i].Interface()) < fmt.Sprint(s[j].Interface())
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -36,7 +36,10 @@ package proto
*/
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"reflect"
"sort"
)
@ -211,6 +214,61 @@ func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
return nil
}
// MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
// It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) {
var b bytes.Buffer
b.WriteByte('{')
// Process the map in key order for deterministic output.
ids := make([]int32, 0, len(m))
for id := range m {
ids = append(ids, id)
}
sort.Sort(int32Slice(ids)) // int32Slice defined in text.go
for i, id := range ids {
ext := m[id]
if i > 0 {
b.WriteByte(',')
}
msd, ok := messageSetMap[id]
if !ok {
// Unknown type; we can't render it, so skip it.
continue
}
fmt.Fprintf(&b, `"[%s]":`, msd.name)
x := ext.value
if x == nil {
x = reflect.New(msd.t.Elem()).Interface()
if err := Unmarshal(ext.enc, x.(Message)); err != nil {
return nil, err
}
}
d, err := json.Marshal(x)
if err != nil {
return nil, err
}
b.Write(d)
}
b.WriteByte('}')
return b.Bytes(), nil
}
// UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
// It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error {
// Common-case fast path.
if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) {
return nil
}
// This is fairly tricky, and it's not clear that it is needed.
return errors.New("TODO: UnmarshalMessageSetJSON not yet implemented")
}
// A global registry of types that can be used in a MessageSet.
var messageSetMap = make(map[int32]messageSetDesc)
@ -221,9 +279,9 @@ type messageSetDesc struct {
}
// RegisterMessageSetType is called from the generated code.
func RegisterMessageSetType(i messageTypeIder, name string) {
messageSetMap[i.MessageTypeId()] = messageSetDesc{
t: reflect.TypeOf(i),
func RegisterMessageSetType(m Message, fieldNum int32, name string) {
messageSetMap[fieldNum] = messageSetDesc{
t: reflect.TypeOf(m),
name: name,
}
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build appengine,!appenginevm
// +build appengine
// This file contains an implementation of proto field accesses using package reflect.
// It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can
@ -114,6 +114,11 @@ func structPointer_Bool(p structPointer, f field) **bool {
return structPointer_ifield(p, f).(**bool)
}
// BoolVal returns the address of a bool field in the struct.
func structPointer_BoolVal(p structPointer, f field) *bool {
return structPointer_ifield(p, f).(*bool)
}
// BoolSlice returns the address of a []bool field in the struct.
func structPointer_BoolSlice(p structPointer, f field) *[]bool {
return structPointer_ifield(p, f).(*[]bool)
@ -124,6 +129,11 @@ func structPointer_String(p structPointer, f field) **string {
return structPointer_ifield(p, f).(**string)
}
// StringVal returns the address of a string field in the struct.
func structPointer_StringVal(p structPointer, f field) *string {
return structPointer_ifield(p, f).(*string)
}
// StringSlice returns the address of a []string field in the struct.
func structPointer_StringSlice(p structPointer, f field) *[]string {
return structPointer_ifield(p, f).(*[]string)
@ -134,6 +144,11 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return structPointer_ifield(p, f).(*map[int32]Extension)
}
// Map returns the reflect.Value for the address of a map field in the struct.
func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
return structPointer_field(p, f).Addr()
}
// SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
structPointer_field(p, f).Set(q.v)
@ -235,6 +250,49 @@ func structPointer_Word32(p structPointer, f field) word32 {
return word32{structPointer_field(p, f)}
}
// A word32Val represents a field of type int32, uint32, float32, or enum.
// That is, v.Type() is int32, uint32, float32, or enum and v is assignable.
type word32Val struct {
v reflect.Value
}
// Set sets *p to x.
func word32Val_Set(p word32Val, x uint32) {
switch p.v.Type() {
case int32Type:
p.v.SetInt(int64(x))
return
case uint32Type:
p.v.SetUint(uint64(x))
return
case float32Type:
p.v.SetFloat(float64(math.Float32frombits(x)))
return
}
// must be enum
p.v.SetInt(int64(int32(x)))
}
// Get gets the bits pointed at by p, as a uint32.
func word32Val_Get(p word32Val) uint32 {
elem := p.v
switch elem.Kind() {
case reflect.Int32:
return uint32(elem.Int())
case reflect.Uint32:
return uint32(elem.Uint())
case reflect.Float32:
return math.Float32bits(float32(elem.Float()))
}
panic("unreachable")
}
// Word32Val returns a reference to a int32, uint32, float32, or enum field in the struct.
func structPointer_Word32Val(p structPointer, f field) word32Val {
return word32Val{structPointer_field(p, f)}
}
// A word32Slice is a slice of 32-bit values.
// That is, v.Type() is []int32, []uint32, []float32, or []enum.
type word32Slice struct {
@ -339,6 +397,43 @@ func structPointer_Word64(p structPointer, f field) word64 {
return word64{structPointer_field(p, f)}
}
// word64Val is like word32Val but for 64-bit values.
type word64Val struct {
v reflect.Value
}
func word64Val_Set(p word64Val, o *Buffer, x uint64) {
switch p.v.Type() {
case int64Type:
p.v.SetInt(int64(x))
return
case uint64Type:
p.v.SetUint(x)
return
case float64Type:
p.v.SetFloat(math.Float64frombits(x))
return
}
panic("unreachable")
}
func word64Val_Get(p word64Val) uint64 {
elem := p.v
switch elem.Kind() {
case reflect.Int64:
return uint64(elem.Int())
case reflect.Uint64:
return elem.Uint()
case reflect.Float64:
return math.Float64bits(elem.Float())
}
panic("unreachable")
}
func structPointer_Word64Val(p structPointer, f field) word64Val {
return word64Val{structPointer_field(p, f)}
}
type word64Slice struct {
v reflect.Value
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build !appengine appenginevm
// +build !appengine
// This file contains the implementation of the proto field accesses using package unsafe.
@ -100,6 +100,11 @@ func structPointer_Bool(p structPointer, f field) **bool {
return (**bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// BoolVal returns the address of a bool field in the struct.
func structPointer_BoolVal(p structPointer, f field) *bool {
return (*bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// BoolSlice returns the address of a []bool field in the struct.
func structPointer_BoolSlice(p structPointer, f field) *[]bool {
return (*[]bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
@ -110,6 +115,11 @@ func structPointer_String(p structPointer, f field) **string {
return (**string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// StringVal returns the address of a string field in the struct.
func structPointer_StringVal(p structPointer, f field) *string {
return (*string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// StringSlice returns the address of a []string field in the struct.
func structPointer_StringSlice(p structPointer, f field) *[]string {
return (*[]string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
@ -120,6 +130,11 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// Map returns the reflect.Value for the address of a map field in the struct.
func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f)))
}
// SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
*(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q
@ -170,6 +185,24 @@ func structPointer_Word32(p structPointer, f field) word32 {
return word32((**uint32)(unsafe.Pointer(uintptr(p) + uintptr(f))))
}
// A word32Val is the address of a 32-bit value field.
type word32Val *uint32
// Set sets *p to x.
func word32Val_Set(p word32Val, x uint32) {
*p = x
}
// Get gets the value pointed at by p.
func word32Val_Get(p word32Val) uint32 {
return *p
}
// Word32Val returns the address of a *int32, *uint32, *float32, or *enum field in the struct.
func structPointer_Word32Val(p structPointer, f field) word32Val {
return word32Val((*uint32)(unsafe.Pointer(uintptr(p) + uintptr(f))))
}
// A word32Slice is a slice of 32-bit values.
type word32Slice []uint32
@ -206,6 +239,21 @@ func structPointer_Word64(p structPointer, f field) word64 {
return word64((**uint64)(unsafe.Pointer(uintptr(p) + uintptr(f))))
}
// word64Val is like word32Val but for 64-bit values.
type word64Val *uint64
func word64Val_Set(p word64Val, o *Buffer, x uint64) {
*p = x
}
func word64Val_Get(p word64Val) uint64 {
return *p
}
func structPointer_Word64Val(p structPointer, f field) word64Val {
return word64Val((*uint64)(unsafe.Pointer(uintptr(p) + uintptr(f))))
}
// word64Slice is like word32Slice but for 64-bit values.
type word64Slice []uint64

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -155,6 +155,7 @@ type Properties struct {
Repeated bool
Packed bool // relevant for repeated primitives only
Enum string // set for enum types only
proto3 bool // whether this is known to be a proto3 field; set for []byte only
Default string // default value
HasDefault bool // whether an explicit default was provided
@ -170,6 +171,10 @@ type Properties struct {
isMarshaler bool
isUnmarshaler bool
mtype reflect.Type // set for map types only
mkeyprop *Properties // set for map types only
mvalprop *Properties // set for map types only
size sizer
valSize valueSizer // set for bool and numeric types only
@ -200,6 +205,9 @@ func (p *Properties) String() string {
if p.OrigName != p.Name {
s += ",name=" + p.OrigName
}
if p.proto3 {
s += ",proto3"
}
if len(p.Enum) > 0 {
s += ",enum=" + p.Enum
}
@ -274,6 +282,8 @@ func (p *Properties) Parse(s string) {
p.OrigName = f[5:]
case strings.HasPrefix(f, "enum="):
p.Enum = f[5:]
case f == "proto3":
p.proto3 = true
case strings.HasPrefix(f, "def="):
p.HasDefault = true
p.Default = f[4:] // rest of string
@ -293,19 +303,50 @@ func logNoSliceEnc(t1, t2 reflect.Type) {
var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem()
// Initialize the fields for encoding and decoding.
func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lockGetProp bool) {
p.enc = nil
p.dec = nil
p.size = nil
switch t1 := typ; t1.Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no coders for %T\n", t1)
fmt.Fprintf(os.Stderr, "proto: no coders for %v\n", t1)
// proto3 scalar types
case reflect.Bool:
p.enc = (*Buffer).enc_proto3_bool
p.dec = (*Buffer).dec_proto3_bool
p.size = size_proto3_bool
case reflect.Int32:
p.enc = (*Buffer).enc_proto3_int32
p.dec = (*Buffer).dec_proto3_int32
p.size = size_proto3_int32
case reflect.Uint32:
p.enc = (*Buffer).enc_proto3_uint32
p.dec = (*Buffer).dec_proto3_int32 // can reuse
p.size = size_proto3_uint32
case reflect.Int64, reflect.Uint64:
p.enc = (*Buffer).enc_proto3_int64
p.dec = (*Buffer).dec_proto3_int64
p.size = size_proto3_int64
case reflect.Float32:
p.enc = (*Buffer).enc_proto3_uint32 // can just treat them as bits
p.dec = (*Buffer).dec_proto3_int32
p.size = size_proto3_uint32
case reflect.Float64:
p.enc = (*Buffer).enc_proto3_int64 // can just treat them as bits
p.dec = (*Buffer).dec_proto3_int64
p.size = size_proto3_int64
case reflect.String:
p.enc = (*Buffer).enc_proto3_string
p.dec = (*Buffer).dec_proto3_string
p.size = size_proto3_string
case reflect.Ptr:
switch t2 := t1.Elem(); t2.Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no encoder function for %T -> %T\n", t1, t2)
fmt.Fprintf(os.Stderr, "proto: no encoder function for %v -> %v\n", t1, t2)
break
case reflect.Bool:
p.enc = (*Buffer).enc_bool
@ -399,6 +440,10 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
p.enc = (*Buffer).enc_slice_byte
p.dec = (*Buffer).dec_slice_byte
p.size = size_slice_byte
if p.proto3 {
p.enc = (*Buffer).enc_proto3_slice_byte
p.size = size_proto3_slice_byte
}
case reflect.Float32, reflect.Float64:
switch t2.Bits() {
case 32:
@ -461,6 +506,23 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
p.size = size_slice_slice_byte
}
}
case reflect.Map:
p.enc = (*Buffer).enc_new_map
p.dec = (*Buffer).dec_new_map
p.size = size_new_map
p.mtype = t1
p.mkeyprop = &Properties{}
p.mkeyprop.init(reflect.PtrTo(p.mtype.Key()), "Key", f.Tag.Get("protobuf_key"), nil, lockGetProp)
p.mvalprop = &Properties{}
vtype := p.mtype.Elem()
if vtype.Kind() != reflect.Ptr && vtype.Kind() != reflect.Slice {
// The value type is not a message (*T) or bytes ([]byte),
// so we need encoders for the pointer to this type.
vtype = reflect.PtrTo(vtype)
}
p.mvalprop.init(vtype, "Value", f.Tag.Get("protobuf_val"), nil, lockGetProp)
}
// precalculate tag code
@ -529,7 +591,7 @@ func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructF
return
}
p.Parse(tag)
p.setEncAndDec(typ, lockGetProp)
p.setEncAndDec(typ, f, lockGetProp)
}
var (
@ -538,7 +600,11 @@ var (
)
// GetProperties returns the list of properties for the type represented by t.
// t must represent a generated struct type of a protocol message.
func GetProperties(t reflect.Type) *StructProperties {
if t.Kind() != reflect.Struct {
panic("proto: type must have kind struct")
}
mutex.Lock()
sprop := getPropertiesLocked(t)
mutex.Unlock()

View File

@ -0,0 +1,44 @@
# Go support for Protocol Buffers - Google's data interchange format
#
# Copyright 2014 The Go Authors. All rights reserved.
# https://github.com/golang/protobuf
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include ../../Make.protobuf
all: regenerate
regenerate:
rm -f proto3.pb.go
make proto3.pb.go
# The following rules are just aids to development. Not needed for typical testing.
diff: regenerate
git diff proto3.pb.go

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -29,32 +29,30 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
syntax = "proto3";
import (
"testing"
package proto3_proto;
pb "./testdata"
"code.google.com/p/goprotobuf/proto"
)
message Message {
enum Humour {
UNKNOWN = 0;
PUNS = 1;
SLAPSTICK = 2;
BILL_BAILEY = 3;
}
func TestGetExtensionsWithMissingExtensions(t *testing.T) {
msg := &pb.MyMessage{}
ext1 := &pb.Ext{}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
t.Fatalf("Could not set ext1: %s", ext1)
}
exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
pb.E_Ext_More,
pb.E_Ext_Text,
})
if err != nil {
t.Fatalf("GetExtensions() failed: %s", err)
}
if exts[0] != ext1 {
t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
}
if exts[1] != nil {
t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
string name = 1;
Humour hilarity = 2;
uint32 height_in_cm = 3;
bytes data = 4;
int64 result_count = 7;
bool true_scotsman = 8;
float score = 9;
repeated uint64 key = 5;
Nested nested = 6;
}
message Nested {
string bunny = 1;
}

View File

@ -0,0 +1,93 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"testing"
pb "./proto3_proto"
"github.com/golang/protobuf/proto"
)
func TestProto3ZeroValues(t *testing.T) {
tests := []struct {
desc string
m proto.Message
}{
{"zero message", &pb.Message{}},
{"empty bytes field", &pb.Message{Data: []byte{}}},
}
for _, test := range tests {
b, err := proto.Marshal(test.m)
if err != nil {
t.Errorf("%s: proto.Marshal: %v", test.desc, err)
continue
}
if len(b) > 0 {
t.Errorf("%s: Encoding is non-empty: %q", test.desc, b)
}
}
}
func TestRoundTripProto3(t *testing.T) {
m := &pb.Message{
Name: "David", // (2 | 1<<3): 0x0a 0x05 "David"
Hilarity: pb.Message_PUNS, // (0 | 2<<3): 0x10 0x01
HeightInCm: 178, // (0 | 3<<3): 0x18 0xb2 0x01
Data: []byte("roboto"), // (2 | 4<<3): 0x20 0x06 "roboto"
ResultCount: 47, // (0 | 7<<3): 0x38 0x2f
TrueScotsman: true, // (0 | 8<<3): 0x40 0x01
Score: 8.1, // (5 | 9<<3): 0x4d <8.1>
Key: []uint64{1, 0xdeadbeef},
Nested: &pb.Nested{
Bunny: "Monty",
},
}
t.Logf(" m: %v", m)
b, err := proto.Marshal(m)
if err != nil {
t.Fatalf("proto.Marshal: %v", err)
}
t.Logf(" b: %q", b)
m2 := new(pb.Message)
if err := proto.Unmarshal(b, m2); err != nil {
t.Fatalf("proto.Unmarshal: %v", err)
}
t.Logf("m2: %v", m2)
if !proto.Equal(m, m2) {
t.Errorf("proto.Equal returned false:\n m: %v\nm2: %v", m, m2)
}
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -35,8 +35,9 @@ import (
"log"
"testing"
proto3pb "./proto3_proto"
pb "./testdata"
. "code.google.com/p/goprotobuf/proto"
. "github.com/golang/protobuf/proto"
)
var messageWithExtension1 = &pb.MyMessage{Count: Int32(7)}
@ -102,6 +103,20 @@ var SizeTests = []struct {
{"unrecognized", &pb.MoreRepeated{XXX_unrecognized: []byte{13<<3 | 0, 4}}},
{"extension (unencoded)", messageWithExtension1},
{"extension (encoded)", messageWithExtension3},
// proto3 message
{"proto3 empty", &proto3pb.Message{}},
{"proto3 bool", &proto3pb.Message{TrueScotsman: true}},
{"proto3 int64", &proto3pb.Message{ResultCount: 1}},
{"proto3 uint32", &proto3pb.Message{HeightInCm: 123}},
{"proto3 float", &proto3pb.Message{Score: 12.6}},
{"proto3 string", &proto3pb.Message{Name: "Snezana"}},
{"proto3 bytes", &proto3pb.Message{Data: []byte("wowsa")}},
{"proto3 bytes, empty", &proto3pb.Message{Data: []byte{}}},
{"proto3 enum", &proto3pb.Message{Hilarity: proto3pb.Message_PUNS}},
{"map field", &pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob", 7: "Andrew"}}},
{"map field with message", &pb.MessageWithMap{MsgMapping: map[int64]*pb.FloatingPoint{0x7001: &pb.FloatingPoint{F: Float64(2.0)}}}},
{"map field with bytes", &pb.MessageWithMap{ByteMapping: map[bool][]byte{true: []byte("this time for sure")}}},
}
func TestSize(t *testing.T) {

View File

@ -1,7 +1,7 @@
# Go support for Protocol Buffers - Google's data interchange format
#
# Copyright 2010 The Go Authors. All rights reserved.
# http://code.google.com/p/goprotobuf/
# https://github.com/golang/protobuf
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
@ -41,7 +41,7 @@ regenerate:
# The following rules are just aids to development. Not needed for typical testing.
diff: regenerate
hg diff test.pb.go
git diff test.pb.go
restore:
cp test.pb.go.golden test.pb.go

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are

View File

@ -33,10 +33,11 @@ It has these top-level messages:
GroupOld
GroupNew
FloatingPoint
MessageWithMap
*/
package testdata
import proto "code.google.com/p/goprotobuf/proto"
import proto "github.com/golang/protobuf/proto"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
@ -1416,6 +1417,12 @@ func (m *MyMessageSet) Marshal() ([]byte, error) {
func (m *MyMessageSet) Unmarshal(buf []byte) error {
return proto.UnmarshalMessageSet(buf, m.ExtensionMap())
}
func (m *MyMessageSet) MarshalJSON() ([]byte, error) {
return proto.MarshalMessageSetJSON(m.XXX_extensions)
}
func (m *MyMessageSet) UnmarshalJSON(buf []byte) error {
return proto.UnmarshalMessageSetJSON(buf, m.XXX_extensions)
}
// ensure MyMessageSet satisfies proto.Marshaler and proto.Unmarshaler
var _ proto.Marshaler = (*MyMessageSet)(nil)
@ -1879,6 +1886,38 @@ func (m *FloatingPoint) GetF() float64 {
return 0
}
type MessageWithMap struct {
NameMapping map[int32]string `protobuf:"bytes,1,rep,name=name_mapping" json:"name_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
MsgMapping map[int64]*FloatingPoint `protobuf:"bytes,2,rep,name=msg_mapping" json:"msg_mapping,omitempty" protobuf_key:"zigzag64,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
ByteMapping map[bool][]byte `protobuf:"bytes,3,rep,name=byte_mapping" json:"byte_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
XXX_unrecognized []byte `json:"-"`
}
func (m *MessageWithMap) Reset() { *m = MessageWithMap{} }
func (m *MessageWithMap) String() string { return proto.CompactTextString(m) }
func (*MessageWithMap) ProtoMessage() {}
func (m *MessageWithMap) GetNameMapping() map[int32]string {
if m != nil {
return m.NameMapping
}
return nil
}
func (m *MessageWithMap) GetMsgMapping() map[int64]*FloatingPoint {
if m != nil {
return m.MsgMapping
}
return nil
}
func (m *MessageWithMap) GetByteMapping() map[bool][]byte {
if m != nil {
return m.ByteMapping
}
return nil
}
var E_Greeting = &proto.ExtensionDesc{
ExtendedType: (*MyMessage)(nil),
ExtensionType: ([]string)(nil),

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -426,3 +426,9 @@ message GroupNew {
message FloatingPoint {
required double f = 1;
}
message MessageWithMap {
map<int32, string> name_mapping = 1;
map<sint64, FloatingPoint> msg_mapping = 2;
map<bool, bytes> byte_mapping = 3;
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -36,6 +36,7 @@ package proto
import (
"bufio"
"bytes"
"encoding"
"fmt"
"io"
"log"
@ -74,13 +75,6 @@ type textWriter struct {
w writer
}
// textMarshaler is implemented by Messages that can marshal themsleves.
// It is identical to encoding.TextMarshaler, introduced in go 1.2,
// which will eventually replace it.
type textMarshaler interface {
MarshalText() (text []byte, err error)
}
func (w *textWriter) WriteString(s string) (n int, err error) {
if !strings.Contains(s, "\n") {
if !w.compact && w.complete {
@ -250,6 +244,100 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
}
continue
}
if fv.Kind() == reflect.Map {
// Map fields are rendered as a repeated struct with key/value fields.
keys := fv.MapKeys() // TODO: should we sort these for deterministic output?
sort.Sort(mapKeys(keys))
for _, key := range keys {
val := fv.MapIndex(key)
if err := writeName(w, props); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte(' '); err != nil {
return err
}
}
// open struct
if err := w.WriteByte('<'); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte('\n'); err != nil {
return err
}
}
w.indent()
// key
if _, err := w.WriteString("key:"); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte(' '); err != nil {
return err
}
}
if err := writeAny(w, key, props.mkeyprop); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
return err
}
// value
if _, err := w.WriteString("value:"); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte(' '); err != nil {
return err
}
}
if err := writeAny(w, val, props.mvalprop); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
return err
}
// close struct
w.unindent()
if err := w.WriteByte('>'); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
return err
}
}
continue
}
if props.proto3 && fv.Kind() == reflect.Slice && fv.Len() == 0 {
// empty bytes field
continue
}
if fv.Kind() != reflect.Ptr && fv.Kind() != reflect.Slice {
// proto3 non-repeated scalar field; skip if zero value
switch fv.Kind() {
case reflect.Bool:
if !fv.Bool() {
continue
}
case reflect.Int32, reflect.Int64:
if fv.Int() == 0 {
continue
}
case reflect.Uint32, reflect.Uint64:
if fv.Uint() == 0 {
continue
}
case reflect.Float32, reflect.Float64:
if fv.Float() == 0 {
continue
}
case reflect.String:
if fv.String() == "" {
continue
}
}
}
if err := writeName(w, props); err != nil {
return err
@ -358,7 +446,7 @@ func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
}
}
w.indent()
if tm, ok := v.Interface().(textMarshaler); ok {
if tm, ok := v.Interface().(encoding.TextMarshaler); ok {
text, err := tm.MarshalText()
if err != nil {
return err
@ -653,7 +741,7 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
compact: compact,
}
if tm, ok := pb.(textMarshaler); ok {
if tm, ok := pb.(encoding.TextMarshaler); ok {
text, err := tm.MarshalText()
if err != nil {
return err

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -35,6 +35,7 @@ package proto
// TODO: message sets.
import (
"encoding"
"errors"
"fmt"
"reflect"
@ -43,13 +44,6 @@ import (
"unicode/utf8"
)
// textUnmarshaler is implemented by Messages that can unmarshal themsleves.
// It is identical to encoding.TextUnmarshaler, introduced in go 1.2,
// which will eventually replace it.
type textUnmarshaler interface {
UnmarshalText(text []byte) error
}
type ParseError struct {
Message string
Line int // 1-based line number
@ -361,8 +355,20 @@ func (p *textParser) next() *token {
return &p.cur
}
// Return an error indicating which required field was not set.
func (p *textParser) missingRequiredFieldError(sv reflect.Value) *ParseError {
func (p *textParser) consumeToken(s string) error {
tok := p.next()
if tok.err != nil {
return tok.err
}
if tok.value != s {
p.back()
return p.errorf("expected %q, found %q", s, tok.value)
}
return nil
}
// Return a RequiredNotSetError indicating which required field was not set.
func (p *textParser) missingRequiredFieldError(sv reflect.Value) *RequiredNotSetError {
st := sv.Type()
sprops := GetProperties(st)
for i := 0; i < st.NumField(); i++ {
@ -372,10 +378,10 @@ func (p *textParser) missingRequiredFieldError(sv reflect.Value) *ParseError {
props := sprops.Prop[i]
if props.Required {
return p.errorf("message %v missing required field %q", st, props.OrigName)
return &RequiredNotSetError{fmt.Sprintf("%v.%v", st, props.OrigName)}
}
}
return p.errorf("message %v missing required field", st) // should not happen
return &RequiredNotSetError{fmt.Sprintf("%v.<unknown field name>", st)} // should not happen
}
// Returns the index in the struct for the named field, as well as the parsed tag properties.
@ -415,6 +421,10 @@ func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseEr
if typ.Elem().Kind() != reflect.Ptr {
break
}
} else if typ.Kind() == reflect.String {
// The proto3 exception is for a string field,
// which requires a colon.
break
}
needColon = false
}
@ -426,9 +436,11 @@ func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseEr
return nil
}
func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError {
func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
st := sv.Type()
reqCount := GetProperties(st).reqCount
var reqFieldErr error
fieldSet := make(map[string]bool)
// A struct is a sequence of "name: value", terminated by one of
// '>' or '}', or the end of the input. A name may also be
// "[extension]".
@ -489,8 +501,11 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError
ext = reflect.New(typ.Elem()).Elem()
}
if err := p.readAny(ext, props); err != nil {
if _, ok := err.(*RequiredNotSetError); !ok {
return err
}
reqFieldErr = err
}
ep := sv.Addr().Interface().(extendableProto)
if !rep {
SetExtension(ep, desc, ext.Interface())
@ -507,17 +522,71 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError
}
} else {
// This is a normal, non-extension field.
fi, props, ok := structFieldByName(st, tok.value)
name := tok.value
fi, props, ok := structFieldByName(st, name)
if !ok {
return p.errorf("unknown field name %q in %v", tok.value, st)
return p.errorf("unknown field name %q in %v", name, st)
}
dst := sv.Field(fi)
isDstNil := isNil(dst)
if dst.Kind() == reflect.Map {
// Consume any colon.
if err := p.checkForColon(props, dst.Type()); err != nil {
return err
}
// Construct the map if it doesn't already exist.
if dst.IsNil() {
dst.Set(reflect.MakeMap(dst.Type()))
}
key := reflect.New(dst.Type().Key()).Elem()
val := reflect.New(dst.Type().Elem()).Elem()
// The map entry should be this sequence of tokens:
// < key : KEY value : VALUE >
// Technically the "key" and "value" could come in any order,
// but in practice they won't.
tok := p.next()
var terminator string
switch tok.value {
case "<":
terminator = ">"
case "{":
terminator = "}"
default:
return p.errorf("expected '{' or '<', found %q", tok.value)
}
if err := p.consumeToken("key"); err != nil {
return err
}
if err := p.consumeToken(":"); err != nil {
return err
}
if err := p.readAny(key, props.mkeyprop); err != nil {
return err
}
if err := p.consumeToken("value"); err != nil {
return err
}
if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
return err
}
if err := p.readAny(val, props.mvalprop); err != nil {
return err
}
if err := p.consumeToken(terminator); err != nil {
return err
}
dst.SetMapIndex(key, val)
continue
}
// Check that it's not already set if it's not a repeated field.
if !props.Repeated && !isDstNil {
return p.errorf("non-repeated field %q was repeated", tok.value)
if !props.Repeated && fieldSet[name] {
return p.errorf("non-repeated field %q was repeated", name)
}
if err := p.checkForColon(props, st.Field(fi).Type); err != nil {
@ -525,11 +594,13 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError
}
// Parse into the field.
fieldSet[name] = true
if err := p.readAny(dst, props); err != nil {
if _, ok := err.(*RequiredNotSetError); !ok {
return err
}
if props.Required {
reqFieldErr = err
} else if props.Required {
reqCount--
}
}
@ -547,10 +618,10 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError
if reqCount > 0 {
return p.missingRequiredFieldError(sv)
}
return nil
return reqFieldErr
}
func (p *textParser) readAny(v reflect.Value, props *Properties) *ParseError {
func (p *textParser) readAny(v reflect.Value, props *Properties) error {
tok := p.next()
if tok.err != nil {
return tok.err
@ -652,7 +723,7 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) *ParseError {
default:
return p.errorf("expected '{' or '<', found %q", tok.value)
}
// TODO: Handle nested messages which implement textUnmarshaler.
// TODO: Handle nested messages which implement encoding.TextUnmarshaler.
return p.readStruct(fv, terminator)
case reflect.Uint32:
if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil {
@ -670,8 +741,10 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) *ParseError {
// UnmarshalText reads a protocol buffer in Text format. UnmarshalText resets pb
// before starting to unmarshal, so any existing data in pb is always removed.
// If a required field is not set and no other error occurs,
// UnmarshalText returns *RequiredNotSetError.
func UnmarshalText(s string, pb Message) error {
if um, ok := pb.(textUnmarshaler); ok {
if um, ok := pb.(encoding.TextUnmarshaler); ok {
err := um.UnmarshalText([]byte(s))
return err
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -36,8 +36,9 @@ import (
"reflect"
"testing"
proto3pb "./proto3_proto"
. "./testdata"
. "code.google.com/p/goprotobuf/proto"
. "github.com/golang/protobuf/proto"
)
type UnmarshalTextTest struct {
@ -294,8 +295,11 @@ var unMarshalTextTests = []UnmarshalTextTest{
// Missing required field
{
in: ``,
err: `line 1.0: message testdata.MyMessage missing required field "count"`,
in: `name: "Pawel"`,
err: `proto: required field "testdata.MyMessage.count" not set`,
out: &MyMessage{
Name: String("Pawel"),
},
},
// Repeated non-repeated field
@ -408,6 +412,9 @@ func TestUnmarshalText(t *testing.T) {
} else if err.Error() != test.err {
t.Errorf("Test %d: Incorrect error.\nHave: %v\nWant: %v",
i, err.Error(), test.err)
} else if _, ok := err.(*RequiredNotSetError); ok && test.out != nil && !reflect.DeepEqual(pb, test.out) {
t.Errorf("Test %d: Incorrect populated \nHave: %v\nWant: %v",
i, pb, test.out)
}
}
}
@ -437,6 +444,48 @@ func TestRepeatedEnum(t *testing.T) {
}
}
func TestProto3TextParsing(t *testing.T) {
m := new(proto3pb.Message)
const in = `name: "Wallace" true_scotsman: true`
want := &proto3pb.Message{
Name: "Wallace",
TrueScotsman: true,
}
if err := UnmarshalText(in, m); err != nil {
t.Fatal(err)
}
if !Equal(m, want) {
t.Errorf("\n got %v\nwant %v", m, want)
}
}
func TestMapParsing(t *testing.T) {
m := new(MessageWithMap)
const in = `name_mapping:<key:1234 value:"Feist"> name_mapping:<key:1 value:"Beatles">` +
`msg_mapping:<key:-4 value:<f: 2.0>>` +
`msg_mapping<key:-2 value<f: 4.0>>` + // no colon after "value"
`byte_mapping:<key:true value:"so be it">`
want := &MessageWithMap{
NameMapping: map[int32]string{
1: "Beatles",
1234: "Feist",
},
MsgMapping: map[int64]*FloatingPoint{
-4: {F: Float64(2.0)},
-2: {F: Float64(4.0)},
},
ByteMapping: map[bool][]byte{
true: []byte("so be it"),
},
}
if err := UnmarshalText(in, m); err != nil {
t.Fatal(err)
}
if !Equal(m, want) {
t.Errorf("\n got %v\nwant %v", m, want)
}
}
var benchInput string
func init() {

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -39,8 +39,9 @@ import (
"strings"
"testing"
"code.google.com/p/goprotobuf/proto"
"github.com/golang/protobuf/proto"
proto3pb "./proto3_proto"
pb "./testdata"
)
@ -406,3 +407,30 @@ Message <nil>
t.Errorf(" got: %s\nwant: %s", s, want)
}
}
func TestProto3Text(t *testing.T) {
tests := []struct {
m proto.Message
want string
}{
// zero message
{&proto3pb.Message{}, ``},
// zero message except for an empty byte slice
{&proto3pb.Message{Data: []byte{}}, ``},
// trivial case
{&proto3pb.Message{Name: "Rob", HeightInCm: 175}, `name:"Rob" height_in_cm:175`},
// empty map
{&pb.MessageWithMap{}, ``},
// non-empty map; current map format is the same as a repeated struct
{
&pb.MessageWithMap{NameMapping: map[int32]string{1234: "Feist"}},
`name_mapping:<key:1234 value:"Feist" >`,
},
}
for _, test := range tests {
got := strings.TrimSpace(test.m.String())
if got != test.want {
t.Errorf("\n got %s\nwant %s", got, test.want)
}
}
}

View File

@ -21,8 +21,8 @@ import (
"testing"
"testing/quick"
. "code.google.com/p/goprotobuf/proto"
. "code.google.com/p/goprotobuf/proto/testdata"
. "github.com/golang/protobuf/proto"
. "github.com/golang/protobuf/proto/testdata"
)
func TestWriteDelimited(t *testing.T) {

View File

@ -19,7 +19,7 @@ import (
"errors"
"io"
"code.google.com/p/goprotobuf/proto"
"github.com/golang/protobuf/proto"
)
var errInvalidVarint = errors.New("invalid varint32 encountered")

View File

@ -18,7 +18,7 @@ import (
"encoding/binary"
"io"
"code.google.com/p/goprotobuf/proto"
"github.com/golang/protobuf/proto"
)
// WriteDelimited encodes and dumps a message to the provided writer prefixed

View File

@ -1,5 +1,5 @@
// Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/
// http://github.com/golang/protobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
@ -30,11 +30,11 @@
package ext
import (
. "code.google.com/p/goprotobuf/proto"
. "code.google.com/p/goprotobuf/proto/testdata"
. "github.com/golang/protobuf/proto"
. "github.com/golang/protobuf/proto/testdata"
)
// FROM https://code.google.com/p/goprotobuf/source/browse/proto/all_test.go.
// FROM https://github.com/golang/protobuf/blob/master/proto/all_test.go.
func initGoTestField() *GoTestField {
f := new(GoTestField)

View File

@ -16,4 +16,6 @@ script:
- GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go build
# only test on linux
- if [ $BUILD_GOOS == "linux" ]; then GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go test -bench=.; fi
# also specify -short; the crypto tests fail in weird ways *sometimes*
# See issue #151
- if [ $BUILD_GOOS == "linux" ]; then GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go test -short -bench=.; fi

View File

@ -34,6 +34,7 @@ A not-so-up-to-date-list-that-may-be-actually-current:
* https://github.com/skynetservices/skydns
* https://github.com/DevelopersPL/godnsagent
* https://github.com/duedil-ltd/discodns
* https://github.com/StalkR/dns-reverse-proxy
Send pull request if you want to be listed here.
@ -49,7 +50,7 @@ Send pull request if you want to be listed here.
* DNSSEC: signing, validating and key generation for DSA, RSA and ECDSA;
* EDNS0, NSID;
* AXFR/IXFR;
* TSIG;
* TSIG, SIG(0);
* DNS name compression;
* Depends only on the standard library.
@ -95,7 +96,8 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
* 3225 - DO bit (DNSSEC OK)
* 340{1,2,3} - NAPTR record
* 3445 - Limiting the scope of (DNS)KEY
* 3597 - Unkown RRs
* 3597 - Unknown RRs
* 4025 - IPSECKEY
* 403{3,4,5} - DNSSEC + validation functions
* 4255 - SSHFP record
* 4343 - Case insensitivity
@ -112,6 +114,7 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
* 5936 - AXFR
* 5966 - TCP implementation recommendations
* 6605 - ECDSA
* 6725 - IANA Registry Update
* 6742 - ILNP DNS
* 6891 - EDNS0 update
* 6895 - DNS IANA considerations
@ -131,10 +134,8 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
## TODO
* privatekey.Precompute() when signing?
* Last remaining RRs: APL, ATMA, A6 and NXT;
* Last remaining RRs: APL, ATMA, A6 and NXT and IPSECKEY;
* Missing in parsing: ISDN, UNSPEC, ATMA;
* CAA parsing is broken;
* NSEC(3) cover/match/closest enclose;
* Replies with TC bit are not parsed to the end;
* SIG(0);
* Create IsMsg to validate a message before fully parsing it.

View File

@ -9,7 +9,7 @@ import (
"time"
)
const dnsTimeout time.Duration = 2 * 1e9
const dnsTimeout time.Duration = 2 * time.Second
const tcpIdleTimeout time.Duration = 8 * time.Second
// A Conn represents a connection to a DNS server.
@ -26,9 +26,9 @@ type Conn struct {
type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
UDPSize uint16 // minimum receive buffer for UDP messages
DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
group singleflight

View File

@ -118,54 +118,7 @@ Loop:
}
}
/*
func TestClientTsigAXFR(t *testing.T) {
m := new(Msg)
m.SetAxfr("example.nl.")
m.SetTsig("axfr.", HmacMD5, 300, time.Now().Unix())
tr := new(Transfer)
tr.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
if a, err := tr.In(m, "176.58.119.54:53"); err != nil {
t.Log("failed to setup axfr: " + err.Error())
t.Fatal()
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("error %s\n", ex.Error.Error())
t.Fail()
break
}
for _, rr := range ex.RR {
t.Logf("%s\n", rr.String())
}
}
}
}
func TestClientAXFRMultipleEnvelopes(t *testing.T) {
m := new(Msg)
m.SetAxfr("nlnetlabs.nl.")
tr := new(Transfer)
if a, err := tr.In(m, "213.154.224.1:53"); err != nil {
t.Log("Failed to setup axfr" + err.Error())
t.Fail()
return
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("Error %s\n", ex.Error.Error())
t.Fail()
break
}
}
}
}
*/
// ExapleUpdateLeaseTSIG shows how to update a lease signed with TSIG.
// ExampleUpdateLeaseTSIG shows how to update a lease signed with TSIG.
func ExampleUpdateLeaseTSIG(t *testing.T) {
m := new(Msg)
m.SetUpdate("t.local.ip6.io.")

View File

@ -26,14 +26,19 @@ func ClientConfigFromFile(resolvconf string) (*ClientConfig, error) {
}
defer file.Close()
c := new(ClientConfig)
b := bufio.NewReader(file)
scanner := bufio.NewScanner(file)
c.Servers = make([]string, 0)
c.Search = make([]string, 0)
c.Port = "53"
c.Ndots = 1
c.Timeout = 5
c.Attempts = 2
for line, ok := b.ReadString('\n'); ok == nil; line, ok = b.ReadString('\n') {
for scanner.Scan() {
if err := scanner.Err(); err != nil {
return nil, err
}
line := scanner.Text()
f := strings.Fields(line)
if len(f) < 1 {
continue

View File

@ -0,0 +1,55 @@
package dns
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
)
const normal string = `
# Comment
domain somedomain.com
nameserver 10.28.10.2
nameserver 11.28.10.1
`
const missingNewline string = `
domain somedomain.com
nameserver 10.28.10.2
nameserver 11.28.10.1` // <- NOTE: NO newline.
func testConfig(t *testing.T, data string) {
tempDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatalf("TempDir: %v", err)
}
defer os.RemoveAll(tempDir)
path := filepath.Join(tempDir, "resolv.conf")
if err := ioutil.WriteFile(path, []byte(data), 0644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
cc, err := ClientConfigFromFile(path)
if err != nil {
t.Errorf("error parsing resolv.conf: %s", err)
}
if l := len(cc.Servers); l != 2 {
t.Errorf("incorrect number of nameservers detected: %d", l)
}
if l := len(cc.Search); l != 1 {
t.Errorf("domain directive not parsed correctly: %v", cc.Search)
} else {
if cc.Search[0] != "somedomain.com" {
t.Errorf("domain is unexpected: %v", cc.Search[0])
}
}
}
func TestNameserver(t *testing.T) {
testConfig(t, normal)
}
func TestMissingFinalNewLine(t *testing.T) {
testConfig(t, missingNewline)
}

View File

@ -73,13 +73,15 @@ func (dns *Msg) SetUpdate(z string) *Msg {
}
// SetIxfr creates message for requesting an IXFR.
func (dns *Msg) SetIxfr(z string, serial uint32) *Msg {
func (dns *Msg) SetIxfr(z string, serial uint32, ns, mbox string) *Msg {
dns.Id = Id()
dns.Question = make([]Question, 1)
dns.Ns = make([]RR, 1)
s := new(SOA)
s.Hdr = RR_Header{z, TypeSOA, ClassINET, defaultTtl, 0}
s.Serial = serial
s.Ns = ns
s.Mbox = mbox
dns.Question[0] = Question{z, TypeIXFR, ClassINET}
dns.Ns[0] = s
return dns
@ -169,7 +171,7 @@ func IsMsg(buf []byte) error {
return errors.New("dns: bad message header")
}
// Header: Opcode
// TODO(miek): more checks here, e.g. check all header bits.
return nil
}

View File

@ -93,9 +93,7 @@
// spaces, semicolons and the at symbol are escaped.
package dns
import (
"strconv"
)
import "strconv"
const (
year68 = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits.
@ -124,7 +122,7 @@ type RR interface {
String() string
// copy returns a copy of the RR
copy() RR
// len returns the length (in octects) of the uncompressed RR in wire format.
// len returns the length (in octets) of the uncompressed RR in wire format.
len() int
}

View File

@ -426,11 +426,21 @@ func BenchmarkUnpackDomainName(b *testing.B) {
}
}
func BenchmarkUnpackDomainNameUnprintable(b *testing.B) {
name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123."
buf := make([]byte, len(name1)+1)
_, _ = PackDomainName(name1, buf, 0, nil, false)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackDomainName(buf, 0)
}
}
func TestToRFC3597(t *testing.T) {
a, _ := NewRR("miek.nl. IN A 10.0.1.1")
x := new(RFC3597)
x.ToRFC3597(a)
if x.String() != `miek.nl. 3600 IN A \# 4 0a000101` {
if x.String() != `miek.nl. 3600 CLASS1 TYPE1 \# 4 0a000101` {
t.Fail()
}
}
@ -499,3 +509,72 @@ func TestCopy(t *testing.T) {
t.Fatalf("Copy() failed %s != %s", rr.String(), rr1.String())
}
}
func TestMsgCopy(t *testing.T) {
m := new(Msg)
m.SetQuestion("miek.nl.", TypeA)
rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Answer = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1")
m.Ns = []RR{rr}
m1 := m.Copy()
if m.String() != m1.String() {
t.Fatalf("Msg.Copy() failed %s != %s", m.String(), m1.String())
}
m1.Answer[0], _ = NewRR("somethingelse.nl. 2311 IN A 127.0.0.1")
if m.String() == m1.String() {
t.Fatalf("Msg.Copy() failed; change to copy changed template %s", m.String())
}
rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.2")
m1.Answer = append(m1.Answer, rr)
if m1.Ns[0].String() == m1.Answer[1].String() {
t.Fatalf("Msg.Copy() failed; append changed underlying array %s", m1.Ns[0].String())
}
}
func BenchmarkCopy(b *testing.B) {
b.ReportAllocs()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeA)
rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Answer = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1")
m.Ns = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Extra = []RR{rr}
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.Copy()
}
}
func TestPackIPSECKEY(t *testing.T) {
tests := []string{
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 0 2 . AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.3 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.1.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 3 2 mygateway.example.com. AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"0.d.4.0.3.0.e.f.f.f.3.f.0.1.2.0 7200 IN IPSECKEY ( 10 2 2 2001:0DB8:0:8002::2000:1 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
}
buf := make([]byte, 1024)
for _, t1 := range tests {
rr, _ := NewRR(t1)
off, e := PackRR(rr, buf, 0, nil, false)
if e != nil {
t.Logf("failed to pack IPSECKEY %s: %s\n", e, t1)
t.Fail()
continue
}
rr, _, e = UnpackRR(buf[:off], 0)
if e != nil {
t.Logf("failed to unpack IPSECKEY %s: %s\n", e, t1)
t.Fail()
}
t.Logf("%s\n", rr)
}
}

View File

@ -20,7 +20,6 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
@ -36,26 +35,29 @@ import (
// DNSSEC encryption algorithm codes.
const (
RSAMD5 = 1
DH = 2
DSA = 3
ECC = 4
RSASHA1 = 5
DSANSEC3SHA1 = 6
RSASHA1NSEC3SHA1 = 7
RSASHA256 = 8
RSASHA512 = 10
ECCGOST = 12
ECDSAP256SHA256 = 13
ECDSAP384SHA384 = 14
INDIRECT = 252
PRIVATEDNS = 253 // Private (experimental keys)
PRIVATEOID = 254
_ uint8 = iota
RSAMD5
DH
DSA
_ // Skip 4, RFC 6725, section 2.1
RSASHA1
DSANSEC3SHA1
RSASHA1NSEC3SHA1
RSASHA256
_ // Skip 9, RFC 6725, section 2.1
RSASHA512
_ // Skip 11, RFC 6725, section 2.1
ECCGOST
ECDSAP256SHA256
ECDSAP384SHA384
INDIRECT uint8 = 252
PRIVATEDNS uint8 = 253 // Private (experimental keys)
PRIVATEOID uint8 = 254
)
// DNSSEC hashing algorithm codes.
const (
_ = iota
_ uint8 = iota
SHA1 // RFC 4034
SHA256 // RFC 4509
GOST94 // RFC 5933
@ -94,6 +96,10 @@ type dnskeyWireFmt struct {
/* Nothing is left out */
}
func divRoundUp(a, b int) int {
return (a + b - 1) / b
}
// KeyTag calculates the keytag (or key-id) of the DNSKEY.
func (k *DNSKEY) KeyTag() uint16 {
if k == nil {
@ -105,7 +111,7 @@ func (k *DNSKEY) KeyTag() uint16 {
// Look at the bottom two bytes of the modules, which the last
// item in the pubkey. We could do this faster by looking directly
// at the base64 values. But I'm lazy.
modulus, _ := packBase64([]byte(k.PublicKey))
modulus, _ := fromBase64([]byte(k.PublicKey))
if len(modulus) > 1 {
x, _ := unpackUint16(modulus, len(modulus)-2)
keytag = int(x)
@ -136,7 +142,7 @@ func (k *DNSKEY) KeyTag() uint16 {
}
// ToDS converts a DNSKEY record to a DS record.
func (k *DNSKEY) ToDS(h int) *DS {
func (k *DNSKEY) ToDS(h uint8) *DS {
if k == nil {
return nil
}
@ -146,7 +152,7 @@ func (k *DNSKEY) ToDS(h int) *DS {
ds.Hdr.Rrtype = TypeDS
ds.Hdr.Ttl = k.Hdr.Ttl
ds.Algorithm = k.Algorithm
ds.DigestType = uint8(h)
ds.DigestType = h
ds.KeyTag = k.KeyTag()
keywire := new(dnskeyWireFmt)
@ -249,60 +255,36 @@ func (rr *RRSIG) Sign(k PrivateKey, rrset []RR) error {
}
signdata = append(signdata, wire...)
var sighash []byte
var h hash.Hash
var ch crypto.Hash // Only need for RSA
switch rr.Algorithm {
case DSA, DSANSEC3SHA1:
// Implicit in the ParameterSizes
// TODO: this seems bugged, will panic
case RSASHA1, RSASHA1NSEC3SHA1:
h = sha1.New()
ch = crypto.SHA1
case RSASHA256, ECDSAP256SHA256:
h = sha256.New()
ch = crypto.SHA256
case ECDSAP384SHA384:
h = sha512.New384()
case RSASHA512:
h = sha512.New()
ch = crypto.SHA512
case RSAMD5:
fallthrough // Deprecated in RFC 6725
default:
return ErrAlg
}
io.WriteString(h, string(signdata))
sighash = h.Sum(nil)
switch p := k.(type) {
case *dsa.PrivateKey:
r1, s1, err := dsa.Sign(rand.Reader, p, sighash)
_, err = h.Write(signdata)
if err != nil {
return err
}
signature := []byte{0x4D} // T value, here the ASCII M for Miek (not used in DNSSEC)
signature = append(signature, r1.Bytes()...)
signature = append(signature, s1.Bytes()...)
rr.Signature = unpackBase64(signature)
case *rsa.PrivateKey:
// We can use nil as rand.Reader here (says AGL)
signature, err := rsa.SignPKCS1v15(nil, p, ch, sighash)
sighash := h.Sum(nil)
signature, err := k.Sign(sighash, rr.Algorithm)
if err != nil {
return err
}
rr.Signature = unpackBase64(signature)
case *ecdsa.PrivateKey:
r1, s1, err := ecdsa.Sign(rand.Reader, p, sighash)
if err != nil {
return err
}
signature := r1.Bytes()
signature = append(signature, s1.Bytes()...)
rr.Signature = unpackBase64(signature)
default:
// Not given the correct key
return ErrKeyAlg
}
rr.Signature = toBase64(signature)
return nil
}
@ -395,7 +377,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
sighash := h.Sum(nil)
return rsa.VerifyPKCS1v15(pubkey, ch, sighash, sigbuf)
case ECDSAP256SHA256, ECDSAP384SHA384:
pubkey := k.publicKeyCurve()
pubkey := k.publicKeyECDSA()
if pubkey == nil {
return ErrKey
}
@ -441,41 +423,16 @@ func (rr *RRSIG) ValidityPeriod(t time.Time) bool {
// Return the signatures base64 encodedig sigdata as a byte slice.
func (s *RRSIG) sigBuf() []byte {
sigbuf, err := packBase64([]byte(s.Signature))
sigbuf, err := fromBase64([]byte(s.Signature))
if err != nil {
return nil
}
return sigbuf
}
// setPublicKeyInPrivate sets the public key in the private key.
func (k *DNSKEY) setPublicKeyInPrivate(p PrivateKey) bool {
switch t := p.(type) {
case *dsa.PrivateKey:
x := k.publicKeyDSA()
if x == nil {
return false
}
t.PublicKey = *x
case *rsa.PrivateKey:
x := k.publicKeyRSA()
if x == nil {
return false
}
t.PublicKey = *x
case *ecdsa.PrivateKey:
x := k.publicKeyCurve()
if x == nil {
return false
}
t.PublicKey = *x
}
return true
}
// publicKeyRSA returns the RSA public key from a DNSKEY record.
func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
keybuf, err := packBase64([]byte(k.PublicKey))
keybuf, err := fromBase64([]byte(k.PublicKey))
if err != nil {
return nil
}
@ -511,9 +468,9 @@ func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
return pubkey
}
// publicKeyCurve returns the Curve public key from the DNSKEY record.
func (k *DNSKEY) publicKeyCurve() *ecdsa.PublicKey {
keybuf, err := packBase64([]byte(k.PublicKey))
// publicKeyECDSA returns the Curve public key from the DNSKEY record.
func (k *DNSKEY) publicKeyECDSA() *ecdsa.PublicKey {
keybuf, err := fromBase64([]byte(k.PublicKey))
if err != nil {
return nil
}
@ -540,97 +497,29 @@ func (k *DNSKEY) publicKeyCurve() *ecdsa.PublicKey {
}
func (k *DNSKEY) publicKeyDSA() *dsa.PublicKey {
keybuf, err := packBase64([]byte(k.PublicKey))
keybuf, err := fromBase64([]byte(k.PublicKey))
if err != nil {
return nil
}
if len(keybuf) < 22 { // TODO: check
if len(keybuf) < 22 {
return nil
}
t := int(keybuf[0])
t, keybuf := int(keybuf[0]), keybuf[1:]
size := 64 + t*8
q, keybuf := keybuf[:20], keybuf[20:]
if len(keybuf) != 3*size {
return nil
}
p, keybuf := keybuf[:size], keybuf[size:]
g, y := keybuf[:size], keybuf[size:]
pubkey := new(dsa.PublicKey)
pubkey.Parameters.Q = big.NewInt(0)
pubkey.Parameters.Q.SetBytes(keybuf[1:21]) // +/- 1 ?
pubkey.Parameters.P = big.NewInt(0)
pubkey.Parameters.P.SetBytes(keybuf[22 : 22+size])
pubkey.Parameters.G = big.NewInt(0)
pubkey.Parameters.G.SetBytes(keybuf[22+size+1 : 22+size*2])
pubkey.Y = big.NewInt(0)
pubkey.Y.SetBytes(keybuf[22+size*2+1 : 22+size*3])
pubkey.Parameters.Q = big.NewInt(0).SetBytes(q)
pubkey.Parameters.P = big.NewInt(0).SetBytes(p)
pubkey.Parameters.G = big.NewInt(0).SetBytes(g)
pubkey.Y = big.NewInt(0).SetBytes(y)
return pubkey
}
// Set the public key (the value E and N)
func (k *DNSKEY) setPublicKeyRSA(_E int, _N *big.Int) bool {
if _E == 0 || _N == nil {
return false
}
buf := exponentToBuf(_E)
buf = append(buf, _N.Bytes()...)
k.PublicKey = unpackBase64(buf)
return true
}
// Set the public key for Elliptic Curves
func (k *DNSKEY) setPublicKeyCurve(_X, _Y *big.Int) bool {
if _X == nil || _Y == nil {
return false
}
buf := curveToBuf(_X, _Y)
// Check the length of the buffer, either 64 or 92 bytes
k.PublicKey = unpackBase64(buf)
return true
}
// Set the public key for DSA
func (k *DNSKEY) setPublicKeyDSA(_Q, _P, _G, _Y *big.Int) bool {
if _Q == nil || _P == nil || _G == nil || _Y == nil {
return false
}
buf := dsaToBuf(_Q, _P, _G, _Y)
k.PublicKey = unpackBase64(buf)
return true
}
// Set the public key (the values E and N) for RSA
// RFC 3110: Section 2. RSA Public KEY Resource Records
func exponentToBuf(_E int) []byte {
var buf []byte
i := big.NewInt(int64(_E))
if len(i.Bytes()) < 256 {
buf = make([]byte, 1)
buf[0] = uint8(len(i.Bytes()))
} else {
buf = make([]byte, 3)
buf[0] = 0
buf[1] = uint8(len(i.Bytes()) >> 8)
buf[2] = uint8(len(i.Bytes()))
}
buf = append(buf, i.Bytes()...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func curveToBuf(_X, _Y *big.Int) []byte {
buf := _X.Bytes()
buf = append(buf, _Y.Bytes()...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func dsaToBuf(_Q, _P, _G, _Y *big.Int) []byte {
t := byte((len(_G.Bytes()) - 64) / 8)
buf := []byte{t}
buf = append(buf, _Q.Bytes()...)
buf = append(buf, _P.Bytes()...)
buf = append(buf, _G.Bytes()...)
buf = append(buf, _Y.Bytes()...)
return buf
}
type wireSlice [][]byte
func (p wireSlice) Len() int { return len(p) }

View File

@ -0,0 +1,155 @@
package dns
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"math/big"
)
// Generate generates a DNSKEY of the given bit size.
// The public part is put inside the DNSKEY record.
// The Algorithm in the key must be set as this will define
// what kind of DNSKEY will be generated.
// The ECDSA algorithms imply a fixed keysize, in that case
// bits should be set to the size of the algorithm.
func (r *DNSKEY) Generate(bits int) (PrivateKey, error) {
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
if bits != 1024 {
return nil, ErrKeySize
}
case RSAMD5, RSASHA1, RSASHA256, RSASHA1NSEC3SHA1:
if bits < 512 || bits > 4096 {
return nil, ErrKeySize
}
case RSASHA512:
if bits < 1024 || bits > 4096 {
return nil, ErrKeySize
}
case ECDSAP256SHA256:
if bits != 256 {
return nil, ErrKeySize
}
case ECDSAP384SHA384:
if bits != 384 {
return nil, ErrKeySize
}
}
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
params := new(dsa.Parameters)
if err := dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160); err != nil {
return nil, err
}
priv := new(dsa.PrivateKey)
priv.PublicKey.Parameters = *params
err := dsa.GenerateKey(priv, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyDSA(params.Q, params.P, params.G, priv.PublicKey.Y)
return (*DSAPrivateKey)(priv), nil
case RSAMD5, RSASHA1, RSASHA256, RSASHA512, RSASHA1NSEC3SHA1:
priv, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
r.setPublicKeyRSA(priv.PublicKey.E, priv.PublicKey.N)
return (*RSAPrivateKey)(priv), nil
case ECDSAP256SHA256, ECDSAP384SHA384:
var c elliptic.Curve
switch r.Algorithm {
case ECDSAP256SHA256:
c = elliptic.P256()
case ECDSAP384SHA384:
c = elliptic.P384()
}
priv, err := ecdsa.GenerateKey(c, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyECDSA(priv.PublicKey.X, priv.PublicKey.Y)
return (*ECDSAPrivateKey)(priv), nil
default:
return nil, ErrAlg
}
}
// Set the public key (the value E and N)
func (k *DNSKEY) setPublicKeyRSA(_E int, _N *big.Int) bool {
if _E == 0 || _N == nil {
return false
}
buf := exponentToBuf(_E)
buf = append(buf, _N.Bytes()...)
k.PublicKey = toBase64(buf)
return true
}
// Set the public key for Elliptic Curves
func (k *DNSKEY) setPublicKeyECDSA(_X, _Y *big.Int) bool {
if _X == nil || _Y == nil {
return false
}
var intlen int
switch k.Algorithm {
case ECDSAP256SHA256:
intlen = 32
case ECDSAP384SHA384:
intlen = 48
}
k.PublicKey = toBase64(curveToBuf(_X, _Y, intlen))
return true
}
// Set the public key for DSA
func (k *DNSKEY) setPublicKeyDSA(_Q, _P, _G, _Y *big.Int) bool {
if _Q == nil || _P == nil || _G == nil || _Y == nil {
return false
}
buf := dsaToBuf(_Q, _P, _G, _Y)
k.PublicKey = toBase64(buf)
return true
}
// Set the public key (the values E and N) for RSA
// RFC 3110: Section 2. RSA Public KEY Resource Records
func exponentToBuf(_E int) []byte {
var buf []byte
i := big.NewInt(int64(_E))
if len(i.Bytes()) < 256 {
buf = make([]byte, 1)
buf[0] = uint8(len(i.Bytes()))
} else {
buf = make([]byte, 3)
buf[0] = 0
buf[1] = uint8(len(i.Bytes()) >> 8)
buf[2] = uint8(len(i.Bytes()))
}
buf = append(buf, i.Bytes()...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func curveToBuf(_X, _Y *big.Int, intlen int) []byte {
buf := intToBytes(_X, intlen)
buf = append(buf, intToBytes(_Y, intlen)...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func dsaToBuf(_Q, _P, _G, _Y *big.Int) []byte {
t := divRoundUp(divRoundUp(_G.BitLen(), 8)-64, 8)
buf := []byte{byte(t)}
buf = append(buf, intToBytes(_Q, 20)...)
buf = append(buf, intToBytes(_P, 64+t*8)...)
buf = append(buf, intToBytes(_G, 64+t*8)...)
buf = append(buf, intToBytes(_Y, 64+t*8)...)
return buf
}

View File

@ -18,8 +18,8 @@ func (k *DNSKEY) NewPrivateKey(s string) (PrivateKey, error) {
// ReadPrivateKey reads a private key from the io.Reader q. The string file is
// only used in error reporting.
// The public key must be
// known, because some cryptographic algorithms embed the public inside the privatekey.
// The public key must be known, because some cryptographic algorithms embed
// the public inside the privatekey.
func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (PrivateKey, error) {
m, e := parseKey(q, file)
if m == nil {
@ -34,14 +34,16 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (PrivateKey, error) {
// TODO(mg): check if the pubkey matches the private key
switch m["algorithm"] {
case "3 (DSA)":
p, e := readPrivateKeyDSA(m)
priv, e := readPrivateKeyDSA(m)
if e != nil {
return nil, e
}
if !k.setPublicKeyInPrivate(p) {
return nil, ErrPrivKey
pub := k.publicKeyDSA()
if pub == nil {
return nil, ErrKey
}
return p, e
priv.PublicKey = *pub
return (*DSAPrivateKey)(priv), e
case "1 (RSAMD5)":
fallthrough
case "5 (RSASHA1)":
@ -51,44 +53,44 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (PrivateKey, error) {
case "8 (RSASHA256)":
fallthrough
case "10 (RSASHA512)":
p, e := readPrivateKeyRSA(m)
priv, e := readPrivateKeyRSA(m)
if e != nil {
return nil, e
}
if !k.setPublicKeyInPrivate(p) {
return nil, ErrPrivKey
pub := k.publicKeyRSA()
if pub == nil {
return nil, ErrKey
}
return p, e
priv.PublicKey = *pub
return (*RSAPrivateKey)(priv), e
case "12 (ECC-GOST)":
p, e := readPrivateKeyGOST(m)
if e != nil {
return nil, e
}
// setPublicKeyInPrivate(p)
return p, e
return nil, ErrPrivKey
case "13 (ECDSAP256SHA256)":
fallthrough
case "14 (ECDSAP384SHA384)":
p, e := readPrivateKeyECDSA(m)
priv, e := readPrivateKeyECDSA(m)
if e != nil {
return nil, e
}
if !k.setPublicKeyInPrivate(p) {
pub := k.publicKeyECDSA()
if pub == nil {
return nil, ErrKey
}
priv.PublicKey = *pub
return (*ECDSAPrivateKey)(priv), e
default:
return nil, ErrPrivKey
}
return p, e
}
return nil, ErrPrivKey
}
// Read a private key (file) string and create a public key. Return the private key.
func readPrivateKeyRSA(m map[string]string) (PrivateKey, error) {
func readPrivateKeyRSA(m map[string]string) (*rsa.PrivateKey, error) {
p := new(rsa.PrivateKey)
p.Primes = []*big.Int{nil, nil}
for k, v := range m {
switch k {
case "modulus", "publicexponent", "privateexponent", "prime1", "prime2":
v1, err := packBase64([]byte(v))
v1, err := fromBase64([]byte(v))
if err != nil {
return nil, err
}
@ -119,13 +121,13 @@ func readPrivateKeyRSA(m map[string]string) (PrivateKey, error) {
return p, nil
}
func readPrivateKeyDSA(m map[string]string) (PrivateKey, error) {
func readPrivateKeyDSA(m map[string]string) (*dsa.PrivateKey, error) {
p := new(dsa.PrivateKey)
p.X = big.NewInt(0)
for k, v := range m {
switch k {
case "private_value(x)":
v1, err := packBase64([]byte(v))
v1, err := fromBase64([]byte(v))
if err != nil {
return nil, err
}
@ -137,14 +139,14 @@ func readPrivateKeyDSA(m map[string]string) (PrivateKey, error) {
return p, nil
}
func readPrivateKeyECDSA(m map[string]string) (PrivateKey, error) {
func readPrivateKeyECDSA(m map[string]string) (*ecdsa.PrivateKey, error) {
p := new(ecdsa.PrivateKey)
p.D = big.NewInt(0)
// TODO: validate that the required flags are present
for k, v := range m {
switch k {
case "privatekey":
v1, err := packBase64([]byte(v))
v1, err := fromBase64([]byte(v))
if err != nil {
return nil, err
}
@ -156,11 +158,6 @@ func readPrivateKeyECDSA(m map[string]string) (PrivateKey, error) {
return p, nil
}
func readPrivateKeyGOST(m map[string]string) (PrivateKey, error) {
// TODO(miek)
return nil, nil
}
// parseKey reads a private key from r. It returns a map[string]string,
// with the key-value pairs, or an error when the file is not correct.
func parseKey(r io.Reader, file string) (map[string]string, error) {

View File

@ -0,0 +1,143 @@
package dns
import (
"crypto"
"crypto/dsa"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"math/big"
"strconv"
)
const _FORMAT = "Private-key-format: v1.3\n"
type PrivateKey interface {
Sign([]byte, uint8) ([]byte, error)
String(uint8) string
}
// PrivateKeyString converts a PrivateKey to a string. This string has the same
// format as the private-key-file of BIND9 (Private-key-format: v1.3).
// It needs some info from the key (the algorithm), so its a method of the
// DNSKEY and calls PrivateKey.String(alg).
func (r *DNSKEY) PrivateKeyString(p PrivateKey) string {
return p.String(r.Algorithm)
}
type RSAPrivateKey rsa.PrivateKey
func (p *RSAPrivateKey) Sign(hashed []byte, alg uint8) ([]byte, error) {
var hash crypto.Hash
switch alg {
case RSASHA1, RSASHA1NSEC3SHA1:
hash = crypto.SHA1
case RSASHA256:
hash = crypto.SHA256
case RSASHA512:
hash = crypto.SHA512
default:
return nil, ErrAlg
}
return rsa.SignPKCS1v15(nil, (*rsa.PrivateKey)(p), hash, hashed)
}
func (p *RSAPrivateKey) String(alg uint8) string {
algorithm := strconv.Itoa(int(alg)) + " (" + AlgorithmToString[alg] + ")"
modulus := toBase64(p.PublicKey.N.Bytes())
e := big.NewInt(int64(p.PublicKey.E))
publicExponent := toBase64(e.Bytes())
privateExponent := toBase64(p.D.Bytes())
prime1 := toBase64(p.Primes[0].Bytes())
prime2 := toBase64(p.Primes[1].Bytes())
// Calculate Exponent1/2 and Coefficient as per: http://en.wikipedia.org/wiki/RSA#Using_the_Chinese_remainder_algorithm
// and from: http://code.google.com/p/go/issues/detail?id=987
one := big.NewInt(1)
p_1 := big.NewInt(0).Sub(p.Primes[0], one)
q_1 := big.NewInt(0).Sub(p.Primes[1], one)
exp1 := big.NewInt(0).Mod(p.D, p_1)
exp2 := big.NewInt(0).Mod(p.D, q_1)
coeff := big.NewInt(0).ModInverse(p.Primes[1], p.Primes[0])
exponent1 := toBase64(exp1.Bytes())
exponent2 := toBase64(exp2.Bytes())
coefficient := toBase64(coeff.Bytes())
return _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Modulus: " + modulus + "\n" +
"PublicExponent: " + publicExponent + "\n" +
"PrivateExponent: " + privateExponent + "\n" +
"Prime1: " + prime1 + "\n" +
"Prime2: " + prime2 + "\n" +
"Exponent1: " + exponent1 + "\n" +
"Exponent2: " + exponent2 + "\n" +
"Coefficient: " + coefficient + "\n"
}
type ECDSAPrivateKey ecdsa.PrivateKey
func (p *ECDSAPrivateKey) Sign(hashed []byte, alg uint8) ([]byte, error) {
var intlen int
switch alg {
case ECDSAP256SHA256:
intlen = 32
case ECDSAP384SHA384:
intlen = 48
default:
return nil, ErrAlg
}
r1, s1, err := ecdsa.Sign(rand.Reader, (*ecdsa.PrivateKey)(p), hashed)
if err != nil {
return nil, err
}
signature := intToBytes(r1, intlen)
signature = append(signature, intToBytes(s1, intlen)...)
return signature, nil
}
func (p *ECDSAPrivateKey) String(alg uint8) string {
algorithm := strconv.Itoa(int(alg)) + " (" + AlgorithmToString[alg] + ")"
var intlen int
switch alg {
case ECDSAP256SHA256:
intlen = 32
case ECDSAP384SHA384:
intlen = 48
}
private := toBase64(intToBytes(p.D, intlen))
return _FORMAT +
"Algorithm: " + algorithm + "\n" +
"PrivateKey: " + private + "\n"
}
type DSAPrivateKey dsa.PrivateKey
func (p *DSAPrivateKey) Sign(hashed []byte, alg uint8) ([]byte, error) {
r1, s1, err := dsa.Sign(rand.Reader, (*dsa.PrivateKey)(p), hashed)
if err != nil {
return nil, err
}
t := divRoundUp(divRoundUp(p.PublicKey.Y.BitLen(), 8)-64, 8)
signature := []byte{byte(t)}
signature = append(signature, intToBytes(r1, 20)...)
signature = append(signature, intToBytes(s1, 20)...)
return signature, nil
}
func (p *DSAPrivateKey) String(alg uint8) string {
algorithm := strconv.Itoa(int(alg)) + " (" + AlgorithmToString[alg] + ")"
T := divRoundUp(divRoundUp(p.PublicKey.Parameters.G.BitLen(), 8)-64, 8)
prime := toBase64(intToBytes(p.PublicKey.Parameters.P, 64+T*8))
subprime := toBase64(intToBytes(p.PublicKey.Parameters.Q, 20))
base := toBase64(intToBytes(p.PublicKey.Parameters.G, 64+T*8))
priv := toBase64(intToBytes(p.X, 20))
pub := toBase64(intToBytes(p.PublicKey.Y, 64+T*8))
return _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Prime(p): " + prime + "\n" +
"Subprime(q): " + subprime + "\n" +
"Base(g): " + base + "\n" +
"Private_value(x): " + priv + "\n" +
"Public_value(y): " + pub + "\n"
}

View File

@ -1,7 +1,6 @@
package dns
import (
"crypto/rsa"
"reflect"
"strings"
"testing"
@ -34,6 +33,9 @@ func getSoa() *SOA {
}
func TestGenerateEC(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl."
@ -48,6 +50,9 @@ func TestGenerateEC(t *testing.T) {
}
func TestGenerateDSA(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl."
@ -62,6 +67,9 @@ func TestGenerateDSA(t *testing.T) {
}
func TestGenerateRSA(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl."
@ -240,12 +248,13 @@ func Test65534(t *testing.T) {
}
func TestDnskey(t *testing.T) {
// f, _ := os.Open("t/Kmiek.nl.+010+05240.key")
pubkey, _ := ReadRR(strings.NewReader(`
pubkey, err := ReadRR(strings.NewReader(`
miek.nl. IN DNSKEY 256 3 10 AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL ;{id = 5240 (zsk), size = 1024b}
`), "Kmiek.nl.+010+05240.key")
privkey, _ := pubkey.(*DNSKEY).ReadPrivateKey(strings.NewReader(`
Private-key-format: v1.2
if err != nil {
t.Fatal(err)
}
privStr := `Private-key-format: v1.3
Algorithm: 10 (RSASHA512)
Modulus: m4wK7YV26AeROtdiCXmqLG9wPDVoMOW8vjr/EkpscEAdjXp81RvZvrlzCSjYmz9onFRgltmTl3AINnFh+t9tlW0M9C5zejxBoKFXELv8ljPYAdz2oe+pDWPhWsfvVFYg2VCjpViPM38EakyE5mhk4TDOnUd+w4TeU1hyhZTWyYs=
PublicExponent: AQAB
@ -255,13 +264,21 @@ Prime2: xA1bF8M0RTIQ6+A11AoVG6GIR/aPGg5sogRkIZ7ID/sF6g9HMVU/CM2TqVEBJLRPp73cv6Ze
Exponent1: xzkblyZ96bGYxTVZm2/vHMOXswod4KWIyMoOepK6B/ZPcZoIT6omLCgtypWtwHLfqyCz3MK51Nc0G2EGzg8rFQ==
Exponent2: Pu5+mCEb7T5F+kFNZhQadHUklt0JUHbi3hsEvVoHpEGSw3BGDQrtIflDde0/rbWHgDPM4WQY+hscd8UuTXrvLw==
Coefficient: UuRoNqe7YHnKmQzE6iDWKTMIWTuoqqrFAmXPmKQnC+Y+BQzOVEHUo9bXdDnoI9hzXP1gf8zENMYwYLeWpuYlFQ==
`), "Kmiek.nl.+010+05240.private")
`
privkey, err := pubkey.(*DNSKEY).ReadPrivateKey(strings.NewReader(privStr),
"Kmiek.nl.+010+05240.private")
if err != nil {
t.Fatal(err)
}
if pubkey.(*DNSKEY).PublicKey != "AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL" {
t.Log("pubkey is not what we've read")
t.Fail()
}
// Coefficient looks fishy...
t.Logf("%s", pubkey.(*DNSKEY).PrivateKeyString(privkey))
if pubkey.(*DNSKEY).PrivateKeyString(privkey) != privStr {
t.Log("privkey is not what we've read")
t.Logf("%v", pubkey.(*DNSKEY).PrivateKeyString(privkey))
t.Fail()
}
}
func TestTag(t *testing.T) {
@ -283,6 +300,9 @@ func TestTag(t *testing.T) {
}
func TestKeyRSA(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
key := new(DNSKEY)
key.Hdr.Name = "miek.nl."
key.Hdr.Rrtype = TypeDNSKEY
@ -368,7 +388,7 @@ Activate: 20110302104537`
t.Fail()
}
switch priv := p.(type) {
case *rsa.PrivateKey:
case *RSAPrivateKey:
if 65537 != priv.PublicKey.E {
t.Log("exponenent should be 65537")
t.Fail()
@ -443,12 +463,19 @@ PrivateKey: WURgWHCcYIYUPWgeLmiPY2DJJk02vgrmTfitxgqcL4vwW7BOrbawVmVe0d9V94SR`
sig.SignerName = eckey.(*DNSKEY).Hdr.Name
sig.Algorithm = eckey.(*DNSKEY).Algorithm
sig.Sign(privkey, []RR{a})
if sig.Sign(privkey, []RR{a}) != nil {
t.Fatal("failure to sign the record")
}
t.Logf("%s", sig.String())
if e := sig.Verify(eckey.(*DNSKEY), []RR{a}); e != nil {
t.Logf("failure to validate: %s", e.Error())
t.Fail()
t.Logf("\n%s\n%s\n%s\n\n%s\n\n",
eckey.(*DNSKEY).String(),
a.String(),
sig.String(),
eckey.(*DNSKEY).PrivateKeyString(privkey),
)
t.Fatalf("failure to validate: %s", e.Error())
}
}
@ -491,6 +518,13 @@ func TestSignVerifyECDSA2(t *testing.T) {
err = sig.Verify(key, []RR{srv})
if err != nil {
t.Logf("\n%s\n%s\n%s\n\n%s\n\n",
key.String(),
srv.String(),
sig.String(),
key.PrivateKeyString(privkey),
)
t.Fatal("Failure to validate:", err)
}
}

View File

@ -287,7 +287,7 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error {
case 1:
addr := make([]byte, 4)
for i := 0; i < int(e.SourceNetmask/8); i++ {
if 4+i > len(b) {
if i >= len(addr) || 4+i >= len(b) {
return ErrBuf
}
addr[i] = b[4+i]
@ -296,7 +296,7 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error {
case 2:
addr := make([]byte, 16)
for i := 0; i < int(e.SourceNetmask/8); i++ {
if 4+i > len(b) {
if i >= len(addr) || 4+i >= len(b) {
return ErrBuf
}
addr[i] = b[4+i]

View File

@ -49,7 +49,7 @@ func ExampleDS(zone string) {
}
for _, k := range r.Answer {
if key, ok := k.(*dns.DNSKEY); ok {
for _, alg := range []int{dns.SHA1, dns.SHA256, dns.SHA384} {
for _, alg := range []uint8{dns.SHA1, dns.SHA256, dns.SHA384} {
fmt.Printf("%s; %d\n", key.ToDS(alg).String(), key.Flags)
}
}

96
Godeps/_workspace/src/github.com/miekg/dns/format.go generated vendored Normal file
View File

@ -0,0 +1,96 @@
package dns
import (
"net"
"reflect"
"strconv"
)
// NumField returns the number of rdata fields r has.
func NumField(r RR) int {
return reflect.ValueOf(r).Elem().NumField() - 1 // Remove RR_Header
}
// Field returns the rdata field i as a string. Fields are indexed starting from 1.
// RR types that holds slice data, for instance the NSEC type bitmap will return a single
// string where the types are concatenated using a space.
// Accessing non existing fields will cause a panic.
func Field(r RR, i int) string {
if i == 0 {
return ""
}
d := reflect.ValueOf(r).Elem().Field(i)
switch k := d.Kind(); k {
case reflect.String:
return d.String()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(d.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(d.Uint(), 10)
case reflect.Slice:
switch reflect.ValueOf(r).Elem().Type().Field(i).Tag {
case `dns:"a"`:
// TODO(miek): Hmm store this as 16 bytes
if d.Len() < net.IPv6len {
return net.IPv4(byte(d.Index(0).Uint()),
byte(d.Index(1).Uint()),
byte(d.Index(2).Uint()),
byte(d.Index(3).Uint())).String()
}
return net.IPv4(byte(d.Index(12).Uint()),
byte(d.Index(13).Uint()),
byte(d.Index(14).Uint()),
byte(d.Index(15).Uint())).String()
case `dns:"aaaa"`:
return net.IP{
byte(d.Index(0).Uint()),
byte(d.Index(1).Uint()),
byte(d.Index(2).Uint()),
byte(d.Index(3).Uint()),
byte(d.Index(4).Uint()),
byte(d.Index(5).Uint()),
byte(d.Index(6).Uint()),
byte(d.Index(7).Uint()),
byte(d.Index(8).Uint()),
byte(d.Index(9).Uint()),
byte(d.Index(10).Uint()),
byte(d.Index(11).Uint()),
byte(d.Index(12).Uint()),
byte(d.Index(13).Uint()),
byte(d.Index(14).Uint()),
byte(d.Index(15).Uint()),
}.String()
case `dns:"nsec"`:
if d.Len() == 0 {
return ""
}
s := Type(d.Index(0).Uint()).String()
for i := 1; i < d.Len(); i++ {
s += " " + Type(d.Index(i).Uint()).String()
}
return s
case `dns:"wks"`:
if d.Len() == 0 {
return ""
}
s := strconv.Itoa(int(d.Index(0).Uint()))
for i := 0; i < d.Len(); i++ {
s += " " + strconv.Itoa(int(d.Index(i).Uint()))
}
return s
default:
// if it does not have a tag its a string slice
fallthrough
case `dns:"txt"`:
if d.Len() == 0 {
return ""
}
s := d.Index(0).String()
for i := 1; i < d.Len(); i++ {
s += " " + d.Index(i).String()
}
return s
}
}
return ""
}

View File

@ -1,149 +0,0 @@
package dns
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"math/big"
"strconv"
)
const _FORMAT = "Private-key-format: v1.3\n"
// Empty interface that is used as a wrapper around all possible
// private key implementations from the crypto package.
type PrivateKey interface{}
// Generate generates a DNSKEY of the given bit size.
// The public part is put inside the DNSKEY record.
// The Algorithm in the key must be set as this will define
// what kind of DNSKEY will be generated.
// The ECDSA algorithms imply a fixed keysize, in that case
// bits should be set to the size of the algorithm.
func (r *DNSKEY) Generate(bits int) (PrivateKey, error) {
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
if bits != 1024 {
return nil, ErrKeySize
}
case RSAMD5, RSASHA1, RSASHA256, RSASHA1NSEC3SHA1:
if bits < 512 || bits > 4096 {
return nil, ErrKeySize
}
case RSASHA512:
if bits < 1024 || bits > 4096 {
return nil, ErrKeySize
}
case ECDSAP256SHA256:
if bits != 256 {
return nil, ErrKeySize
}
case ECDSAP384SHA384:
if bits != 384 {
return nil, ErrKeySize
}
}
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
params := new(dsa.Parameters)
if err := dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160); err != nil {
return nil, err
}
priv := new(dsa.PrivateKey)
priv.PublicKey.Parameters = *params
err := dsa.GenerateKey(priv, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyDSA(params.Q, params.P, params.G, priv.PublicKey.Y)
return priv, nil
case RSAMD5, RSASHA1, RSASHA256, RSASHA512, RSASHA1NSEC3SHA1:
priv, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
r.setPublicKeyRSA(priv.PublicKey.E, priv.PublicKey.N)
return priv, nil
case ECDSAP256SHA256, ECDSAP384SHA384:
var c elliptic.Curve
switch r.Algorithm {
case ECDSAP256SHA256:
c = elliptic.P256()
case ECDSAP384SHA384:
c = elliptic.P384()
}
priv, err := ecdsa.GenerateKey(c, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyCurve(priv.PublicKey.X, priv.PublicKey.Y)
return priv, nil
default:
return nil, ErrAlg
}
return nil, nil // Dummy return
}
// PrivateKeyString converts a PrivateKey to a string. This
// string has the same format as the private-key-file of BIND9 (Private-key-format: v1.3).
// It needs some info from the key (hashing, keytag), so its a method of the DNSKEY.
func (r *DNSKEY) PrivateKeyString(p PrivateKey) (s string) {
switch t := p.(type) {
case *rsa.PrivateKey:
algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
modulus := unpackBase64(t.PublicKey.N.Bytes())
e := big.NewInt(int64(t.PublicKey.E))
publicExponent := unpackBase64(e.Bytes())
privateExponent := unpackBase64(t.D.Bytes())
prime1 := unpackBase64(t.Primes[0].Bytes())
prime2 := unpackBase64(t.Primes[1].Bytes())
// Calculate Exponent1/2 and Coefficient as per: http://en.wikipedia.org/wiki/RSA#Using_the_Chinese_remainder_algorithm
// and from: http://code.google.com/p/go/issues/detail?id=987
one := big.NewInt(1)
minusone := big.NewInt(-1)
p_1 := big.NewInt(0).Sub(t.Primes[0], one)
q_1 := big.NewInt(0).Sub(t.Primes[1], one)
exp1 := big.NewInt(0).Mod(t.D, p_1)
exp2 := big.NewInt(0).Mod(t.D, q_1)
coeff := big.NewInt(0).Exp(t.Primes[1], minusone, t.Primes[0])
exponent1 := unpackBase64(exp1.Bytes())
exponent2 := unpackBase64(exp2.Bytes())
coefficient := unpackBase64(coeff.Bytes())
s = _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Modules: " + modulus + "\n" +
"PublicExponent: " + publicExponent + "\n" +
"PrivateExponent: " + privateExponent + "\n" +
"Prime1: " + prime1 + "\n" +
"Prime2: " + prime2 + "\n" +
"Exponent1: " + exponent1 + "\n" +
"Exponent2: " + exponent2 + "\n" +
"Coefficient: " + coefficient + "\n"
case *ecdsa.PrivateKey:
algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
private := unpackBase64(t.D.Bytes())
s = _FORMAT +
"Algorithm: " + algorithm + "\n" +
"PrivateKey: " + private + "\n"
case *dsa.PrivateKey:
algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
prime := unpackBase64(t.PublicKey.Parameters.P.Bytes())
subprime := unpackBase64(t.PublicKey.Parameters.Q.Bytes())
base := unpackBase64(t.PublicKey.Parameters.G.Bytes())
priv := unpackBase64(t.X.Bytes())
pub := unpackBase64(t.PublicKey.Y.Bytes())
s = _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Prime(p): " + prime + "\n" +
"Subprime(q): " + subprime + "\n" +
"Base(g): " + base + "\n" +
"Private_value(x): " + priv + "\n" +
"Public_value(y): " + pub + "\n"
}
return
}

View File

@ -12,7 +12,7 @@ import (
"encoding/base32"
"encoding/base64"
"encoding/hex"
"fmt"
"math/big"
"math/rand"
"net"
"reflect"
@ -91,6 +91,7 @@ var TypeToString = map[uint16]string{
TypeATMA: "ATMA",
TypeAXFR: "AXFR", // Meta RR
TypeCAA: "CAA",
TypeCDNSKEY: "CDNSKEY",
TypeCDS: "CDS",
TypeCERT: "CERT",
TypeCNAME: "CNAME",
@ -109,6 +110,7 @@ var TypeToString = map[uint16]string{
TypeIPSECKEY: "IPSECKEY",
TypeISDN: "ISDN",
TypeIXFR: "IXFR", // Meta RR
TypeKEY: "KEY",
TypeKX: "KX",
TypeL32: "L32",
TypeL64: "L64",
@ -139,6 +141,7 @@ var TypeToString = map[uint16]string{
TypeRP: "RP",
TypeRRSIG: "RRSIG",
TypeRT: "RT",
TypeSIG: "SIG",
TypeSOA: "SOA",
TypeSPF: "SPF",
TypeSRV: "SRV",
@ -419,7 +422,15 @@ Loop:
s = append(s, '\\', 'r')
default:
if b < 32 || b >= 127 { // unprintable use \DDD
s = append(s, fmt.Sprintf("\\%03d", b)...)
var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
} else {
s = append(s, b)
}
@ -554,7 +565,15 @@ func unpackTxtString(msg []byte, offset int) (string, int, error) {
s = append(s, `\n`...)
default:
if b < 32 || b > 127 { // unprintable
s = append(s, fmt.Sprintf("\\%03d", b)...)
var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
} else {
s = append(s, b)
}
@ -633,6 +652,12 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += len(b)
}
case `dns:"a"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 1
if val.Field(2).Uint() != 1 {
continue
}
}
// It must be a slice of 4, even if it is 16, we encode
// only the first 4
if off+net.IPv4len > lenmsg {
@ -657,6 +682,12 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
return lenmsg, &Error{err: "overflow packing a"}
}
case `dns:"aaaa"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 2
if val.Field(2).Uint() != 2 {
continue
}
}
if fv.Len() == 0 {
break
}
@ -668,6 +699,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off++
}
case `dns:"wks"`:
// TODO(miek): this is wrong should be lenrd
if off == lenmsg {
break // dyn. updates
}
@ -794,13 +826,20 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
default:
return lenmsg, &Error{"bad tag packing string: " + typefield.Tag.Get("dns")}
case `dns:"base64"`:
b64, e := packBase64([]byte(s))
b64, e := fromBase64([]byte(s))
if e != nil {
return lenmsg, e
}
copy(msg[off:off+len(b64)], b64)
off += len(b64)
case `dns:"domain-name"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, 1 and 2 or used for addresses
x := val.Field(2).Uint()
if x == 1 || x == 2 {
continue
}
}
if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil {
return lenmsg, err
}
@ -815,7 +854,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
msg[off-1] = 20
fallthrough
case `dns:"base32"`:
b32, e := packBase32([]byte(s))
b32, e := fromBase32([]byte(s))
if e != nil {
return lenmsg, e
}
@ -875,9 +914,12 @@ func packStructCompress(any interface{}, msg []byte, off int, compression map[st
// Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue.
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
var rdend int
var lenrd int
lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ {
if lenrd != 0 && lenrd == off {
break
}
if off > lenmsg {
return lenmsg, &Error{"bad offset unpacking"}
}
@ -889,7 +931,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// therefore it's expected that this interface would be PrivateRdata
switch data := fv.Interface().(type) {
case PrivateRdata:
n, err := data.Unpack(msg[off:rdend])
n, err := data.Unpack(msg[off:lenrd])
if err != nil {
return lenmsg, err
}
@ -905,7 +947,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// HIP record slice of name (or none)
servers := make([]string, 0)
var s string
for off < rdend {
for off < lenrd {
s, off, err = UnpackDomainName(msg, off)
if err != nil {
return lenmsg, err
@ -914,17 +956,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
}
fv.Set(reflect.ValueOf(servers))
case `dns:"txt"`:
if off == lenmsg || rdend == off {
if off == lenmsg || lenrd == off {
break
}
var txt []string
txt, off, err = unpackTxt(msg, off, rdend)
txt, off, err = unpackTxt(msg, off, lenrd)
if err != nil {
return lenmsg, err
}
fv.Set(reflect.ValueOf(txt))
case `dns:"opt"`: // edns0
if off == rdend {
if off == lenrd {
// This is an EDNS0 (OPT Record) with no rdata
// We can safely return here.
break
@ -937,7 +979,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
}
code, off = unpackUint16(msg, off)
optlen, off1 := unpackUint16(msg, off)
if off1+int(optlen) > rdend {
if off1+int(optlen) > lenrd {
return lenmsg, &Error{err: "overflow unpacking opt"}
}
switch code {
@ -997,24 +1039,36 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// do nothing?
off = off1 + int(optlen)
}
if off < rdend {
if off < lenrd {
goto Option
}
fv.Set(reflect.ValueOf(edns))
case `dns:"a"`:
if off == lenmsg {
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 1
if val.Field(2).Uint() != 1 {
continue
}
}
if off == lenrd {
break // dyn. update
}
if off+net.IPv4len > rdend {
if off+net.IPv4len > lenrd || off+net.IPv4len > lenmsg {
return lenmsg, &Error{err: "overflow unpacking a"}
}
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
off += net.IPv4len
case `dns:"aaaa"`:
if off == rdend {
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 2
if val.Field(2).Uint() != 2 {
continue
}
}
if off == lenrd {
break
}
if off+net.IPv6len > rdend || off+net.IPv6len > lenmsg {
if off+net.IPv6len > lenrd || off+net.IPv6len > lenmsg {
return lenmsg, &Error{err: "overflow unpacking aaaa"}
}
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
@ -1025,7 +1079,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// Rest of the record is the bitmap
serv := make([]uint16, 0)
j := 0
for off < rdend {
for off < lenrd {
if off+1 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking wks"}
}
@ -1060,17 +1114,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
}
fv.Set(reflect.ValueOf(serv))
case `dns:"nsec"`: // NSEC/NSEC3
if off == rdend {
if off == lenrd {
break
}
// Rest of the record is the type bitmap
if off+2 > rdend || off+2 > lenmsg {
if off+2 > lenrd || off+2 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking nsecx"}
}
nsec := make([]uint16, 0)
length := 0
window := 0
for off+2 < rdend {
for off+2 < lenrd {
window = int(msg[off])
length = int(msg[off+1])
//println("off, windows, length, end", off, window, length, endrr)
@ -1127,7 +1181,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
return lenmsg, err
}
if val.Type().Field(i).Name == "Hdr" {
rdend = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
lenrd = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
}
case reflect.Uint8:
if off == lenmsg {
@ -1184,29 +1238,36 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
default:
return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")}
case `dns:"hex"`:
hexend := rdend
hexend := lenrd
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
hexend = off + int(val.FieldByName("HitLength").Uint())
}
if hexend > rdend || hexend > lenmsg {
if hexend > lenrd || hexend > lenmsg {
return lenmsg, &Error{err: "overflow unpacking hex"}
}
s = hex.EncodeToString(msg[off:hexend])
off = hexend
case `dns:"base64"`:
// Rest of the RR is base64 encoded value
b64end := rdend
b64end := lenrd
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
b64end = off + int(val.FieldByName("PublicKeyLength").Uint())
}
if b64end > rdend || b64end > lenmsg {
if b64end > lenrd || b64end > lenmsg {
return lenmsg, &Error{err: "overflow unpacking base64"}
}
s = unpackBase64(msg[off:b64end])
s = toBase64(msg[off:b64end])
off = b64end
case `dns:"cdomain-name"`:
fallthrough
case `dns:"domain-name"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, 1 and 2 or used for addresses
x := val.Field(2).Uint()
if x == 1 || x == 2 {
continue
}
}
if off == lenmsg {
// zero rdata foo, OK for dyn. updates
break
@ -1228,7 +1289,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
if off+size > lenmsg {
return lenmsg, &Error{err: "overflow unpacking base32"}
}
s = unpackBase32(msg[off : off+size])
s = toBase32(msg[off : off+size])
off += size
case `dns:"size-hex"`:
// a "size" string, but it must be encoded in hex in the string
@ -1276,58 +1337,53 @@ func dddToByte(s []byte) byte {
return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
}
// Helper function for unpacking
func unpackUint16(msg []byte, off int) (v uint16, off1 int) {
v = uint16(msg[off])<<8 | uint16(msg[off+1])
off1 = off + 2
return
}
// UnpackStruct unpacks a binary message from offset off to the interface
// value given.
func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
off, err = unpackStructValue(structValue(any), msg, off)
return off, err
func UnpackStruct(any interface{}, msg []byte, off int) (int, error) {
return unpackStructValue(structValue(any), msg, off)
}
func unpackBase32(b []byte) string {
b32 := make([]byte, base32.HexEncoding.EncodedLen(len(b)))
base32.HexEncoding.Encode(b32, b)
return string(b32)
// Helper function for packing and unpacking
func intToBytes(i *big.Int, length int) []byte {
buf := i.Bytes()
if len(buf) < length {
b := make([]byte, length)
copy(b[length-len(buf):], buf)
return b
}
return buf
}
func unpackBase64(b []byte) string {
b64 := make([]byte, base64.StdEncoding.EncodedLen(len(b)))
base64.StdEncoding.Encode(b64, b)
return string(b64)
func unpackUint16(msg []byte, off int) (uint16, int) {
return uint16(msg[off])<<8 | uint16(msg[off+1]), off + 2
}
// Helper function for packing
func packUint16(i uint16) (byte, byte) {
return byte(i >> 8), byte(i)
}
func packBase64(s []byte) ([]byte, error) {
b64len := base64.StdEncoding.DecodedLen(len(s))
buf := make([]byte, b64len)
n, err := base64.StdEncoding.Decode(buf, []byte(s))
if err != nil {
return nil, err
}
buf = buf[:n]
return buf, nil
func toBase32(b []byte) string {
return base32.HexEncoding.EncodeToString(b)
}
// Helper function for packing, mostly used in dnssec.go
func packBase32(s []byte) ([]byte, error) {
b32len := base32.HexEncoding.DecodedLen(len(s))
buf := make([]byte, b32len)
n, err := base32.HexEncoding.Decode(buf, []byte(s))
if err != nil {
return nil, err
}
func fromBase32(s []byte) (buf []byte, err error) {
buflen := base32.HexEncoding.DecodedLen(len(s))
buf = make([]byte, buflen)
n, err := base32.HexEncoding.Decode(buf, s)
buf = buf[:n]
return buf, nil
return
}
func toBase64(b []byte) string {
return base64.StdEncoding.EncodeToString(b)
}
func fromBase64(s []byte) (buf []byte, err error) {
buflen := base64.StdEncoding.DecodedLen(len(s))
buf = make([]byte, buflen)
n, err := base64.StdEncoding.Decode(buf, s)
buf = buf[:n]
return
}
// PackRR packs a resource record rr into msg[off:].
@ -1856,25 +1912,34 @@ func (dns *Msg) Copy() *Msg {
copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy
}
rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
var rri int
if len(dns.Answer) > 0 {
r1.Answer = make([]RR, len(dns.Answer))
rrbegin := rri
for i := 0; i < len(dns.Answer); i++ {
r1.Answer[i] = dns.Answer[i].copy()
rrArr[rri] = dns.Answer[i].copy()
rri++
}
r1.Answer = rrArr[rrbegin:rri:rri]
}
if len(dns.Ns) > 0 {
r1.Ns = make([]RR, len(dns.Ns))
rrbegin := rri
for i := 0; i < len(dns.Ns); i++ {
r1.Ns[i] = dns.Ns[i].copy()
rrArr[rri] = dns.Ns[i].copy()
rri++
}
r1.Ns = rrArr[rrbegin:rri:rri]
}
if len(dns.Extra) > 0 {
r1.Extra = make([]RR, len(dns.Extra))
rrbegin := rri
for i := 0; i < len(dns.Extra); i++ {
r1.Extra[i] = dns.Extra[i].copy()
rrArr[rri] = dns.Extra[i].copy()
rri++
}
r1.Extra = rrArr[rrbegin:rri:rri]
}
return r1

View File

@ -47,7 +47,7 @@ func HashName(label string, ha uint8, iter uint16, salt string) string {
io.WriteString(s, string(nsec3))
nsec3 = s.Sum(nil)
}
return unpackBase32(nsec3)
return toBase32(nsec3)
}
type Denialer interface {

View File

@ -386,8 +386,8 @@ func TestNSEC(t *testing.T) {
func TestParseLOC(t *testing.T) {
lt := map[string]string{
"SW1A2AA.find.me.uk. LOC 51 30 12.748 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 30 12.748 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m",
"SW1A2AA.find.me.uk. LOC 51 0 0.0 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 00 0.000 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m",
"SW1A2AA.find.me.uk. LOC 51 30 12.748 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 30 12.748 N 00 07 39.611 W 0m 0.00m 0.00m 0.00m",
"SW1A2AA.find.me.uk. LOC 51 0 0.0 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 00 0.000 N 00 07 39.611 W 0m 0.00m 0.00m 0.00m",
}
for i, o := range lt {
rr, e := NewRR(i)
@ -1225,35 +1225,21 @@ func TestMalformedPackets(t *testing.T) {
}
}
func TestDynamicUpdateParsing(t *testing.T) {
prefix := "example.com. IN "
for _, typ := range TypeToString {
if typ == "CAA" || typ == "OPT" || typ == "AXFR" || typ == "IXFR" || typ == "ANY" || typ == "TKEY" ||
typ == "TSIG" || typ == "ISDN" || typ == "UNSPEC" || typ == "NULL" || typ == "ATMA" {
continue
}
r, e := NewRR(prefix + typ)
if e != nil {
t.Log("failure to parse: " + prefix + typ)
t.Fail()
} else {
t.Logf("parsed: %s", r.String())
}
}
}
type algorithm struct {
name uint8
bits int
}
func TestNewPrivateKeyECDSA(t *testing.T) {
func TestNewPrivateKey(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
algorithms := []algorithm{
algorithm{ECDSAP256SHA256, 256},
algorithm{ECDSAP384SHA384, 384},
algorithm{RSASHA1, 1024},
algorithm{RSASHA256, 2048},
// algorithm{DSA, 1024}, // TODO: STILL BROKEN!
algorithm{DSA, 1024},
}
for _, algo := range algorithms {
@ -1272,12 +1258,15 @@ func TestNewPrivateKeyECDSA(t *testing.T) {
newPrivKey, err := key.NewPrivateKey(key.PrivateKeyString(privkey))
if err != nil {
t.Log(key.String())
t.Log(key.PrivateKeyString(privkey))
t.Fatal(err.Error())
}
switch newPrivKey := newPrivKey.(type) {
case *rsa.PrivateKey:
newPrivKey.Precompute()
case *RSAPrivateKey:
(*rsa.PrivateKey)(newPrivKey).Precompute()
}
if !reflect.DeepEqual(privkey, newPrivKey) {
@ -1285,3 +1274,136 @@ func TestNewPrivateKeyECDSA(t *testing.T) {
}
}
}
// special input test
func TestNewRRSpecial(t *testing.T) {
var (
rr RR
err error
expect string
)
rr, err = NewRR("; comment")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR("")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR("$ORIGIN foo.")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR(" ")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR("\n")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR("foo. A 1.1.1.1\nbar. A 2.2.2.2")
expect = "foo.\t3600\tIN\tA\t1.1.1.1"
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr == nil || rr.String() != expect {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
}
func TestPrintfVerbsRdata(t *testing.T) {
x, _ := NewRR("www.miek.nl. IN MX 20 mx.miek.nl.")
if Field(x, 1) != "20" {
t.Errorf("should be 20")
}
if Field(x, 2) != "mx.miek.nl." {
t.Errorf("should be mx.miek.nl.")
}
x, _ = NewRR("www.miek.nl. IN A 127.0.0.1")
if Field(x, 1) != "127.0.0.1" {
t.Errorf("should be 127.0.0.1")
}
x, _ = NewRR("www.miek.nl. IN AAAA ::1")
if Field(x, 1) != "::1" {
t.Errorf("should be ::1")
}
x, _ = NewRR("www.miek.nl. IN NSEC a.miek.nl. A NS SOA MX AAAA")
if Field(x, 1) != "a.miek.nl." {
t.Errorf("should be a.miek.nl.")
}
if Field(x, 2) != "A NS SOA MX AAAA" {
t.Errorf("should be A NS SOA MX AAAA")
}
x, _ = NewRR("www.miek.nl. IN TXT \"first\" \"second\"")
if Field(x, 1) != "first second" {
t.Errorf("should be first second")
}
if Field(x, 0) != "" {
t.Errorf("should be empty")
}
}
func TestParseIPSECKEY(t *testing.T) {
tests := []string{
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa.\t7200\tIN\tIPSECKEY\t10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 0 2 . AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa.\t7200\tIN\tIPSECKEY\t10 0 2 . AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.3 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa.\t7200\tIN\tIPSECKEY\t10 1 2 192.0.2.3 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
"38.1.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 3 2 mygateway.example.com. AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.1.0.192.in-addr.arpa.\t7200\tIN\tIPSECKEY\t10 3 2 mygateway.example.com. AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
"0.d.4.0.3.0.e.f.f.f.3.f.0.1.2.0 7200 IN IPSECKEY ( 10 2 2 2001:0DB8:0:8002::2000:1 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"0.d.4.0.3.0.e.f.f.f.3.f.0.1.2.0.\t7200\tIN\tIPSECKEY\t10 2 2 2001:db8:0:8002::2000:1 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
}
for i := 0; i < len(tests)-1; i++ {
t1 := tests[i]
e1 := tests[i+1]
r, e := NewRR(t1)
if e != nil {
t.Logf("failed to parse IPSECKEY %s", e)
continue
}
if r.String() != e1 {
t.Logf("these two IPSECKEY records should match")
t.Logf("\n%s\n%s\n", r.String(), e1)
t.Fail()
}
i++
}
}

View File

@ -77,8 +77,7 @@ func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
f(w, r)
}
// FailedHandler returns a HandlerFunc
// returns SERVFAIL for every request it gets.
// FailedHandler returns a HandlerFunc that returns SERVFAIL for every request it gets.
func HandleFailed(w ResponseWriter, r *Msg) {
m := new(Msg)
m.SetRcode(r, RcodeServerFailure)
@ -170,10 +169,10 @@ func (mux *ServeMux) HandleRemove(pattern string) {
// is sought.
// If no handler is found a standard SERVFAIL message is returned
// If the request message does not have exactly one question in the
// question section a SERVFAIL is returned.
// question section a SERVFAIL is returned, unlesss Unsafe is true.
func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
var h Handler
if len(request.Question) != 1 {
if len(request.Question) < 1 { // allow more than one question
h = failedHandler()
} else {
if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
@ -221,6 +220,11 @@ type Server struct {
IdleTimeout func() time.Duration
// Secret(s) for Tsig map[<zonename>]<base64 secret>.
TsigSecret map[string]string
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
// the handler. It will specfically not check if the query has the QR bit not set.
Unsafe bool
// If NotifyStartedFunc is set is is called, once the server has started listening.
NotifyStartedFunc func()
// For graceful shutdown.
stopUDP chan bool
@ -237,6 +241,7 @@ type Server struct {
func (srv *Server) ListenAndServe() error {
srv.lock.Lock()
if srv.started {
srv.lock.Unlock()
return &Error{err: "server already started"}
}
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
@ -282,14 +287,12 @@ func (srv *Server) ListenAndServe() error {
func (srv *Server) ActivateAndServe() error {
srv.lock.Lock()
if srv.started {
srv.lock.Unlock()
return &Error{err: "server already started"}
}
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
srv.started = true
srv.lock.Unlock()
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
if srv.PacketConn != nil {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
@ -316,6 +319,7 @@ func (srv *Server) ActivateAndServe() error {
func (srv *Server) Shutdown() error {
srv.lock.Lock()
if !srv.started {
srv.lock.Unlock()
return &Error{err: "server not started"}
}
srv.started = false
@ -371,6 +375,11 @@ func (srv *Server) getReadTimeout() time.Duration {
// Each request is handled in a seperate goroutine.
func (srv *Server) serveTCP(l *net.TCPListener) error {
defer l.Close()
if srv.NotifyStartedFunc != nil {
srv.NotifyStartedFunc()
}
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
@ -402,6 +411,10 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close()
if srv.NotifyStartedFunc != nil {
srv.NotifyStartedFunc()
}
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
@ -445,6 +458,9 @@ Redo:
w.WriteMsg(x)
goto Exit
}
if !srv.Unsafe && req.Response {
goto Exit
}
w.tsigStatus = nil
if w.tsigSecret != nil {

View File

@ -32,10 +32,37 @@ func RunLocalUDPServer(laddr string) (*Server, string, error) {
return nil, "", err
}
server := &Server{PacketConn: pc}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
go func() {
server.ActivateAndServe()
pc.Close()
}()
waitLock.Lock()
return server, pc.LocalAddr().String(), nil
}
func RunLocalUDPServerUnsafe(laddr string) (*Server, string, error) {
pc, err := net.ListenPacket("udp", laddr)
if err != nil {
return nil, "", err
}
server := &Server{PacketConn: pc, Unsafe: true}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
go func() {
server.ActivateAndServe()
pc.Close()
}()
waitLock.Lock()
return server, pc.LocalAddr().String(), nil
}
@ -44,11 +71,19 @@ func RunLocalTCPServer(laddr string) (*Server, string, error) {
if err != nil {
return nil, "", err
}
server := &Server{Listener: l}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
go func() {
server.ActivateAndServe()
l.Close()
}()
waitLock.Lock()
return server, l.Addr().String(), nil
}
@ -68,7 +103,7 @@ func TestServing(t *testing.T) {
m := new(Msg)
m.SetQuestion("miek.nl.", TypeTXT)
r, _, err := c.Exchange(m, addrstr)
if err != nil {
if err != nil || len(r.Extra) == 0 {
t.Log("failed to exchange miek.nl", err)
t.Fatal()
}
@ -302,14 +337,52 @@ func TestServingLargeResponses(t *testing.T) {
}
}
func TestServingResponse(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
HandleFunc("miek.nl.", HelloServer)
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
c := new(Client)
m := new(Msg)
m.SetQuestion("miek.nl.", TypeTXT)
m.Response = false
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Log("failed to exchange", err)
t.Fatal()
}
m.Response = true
_, _, err = c.Exchange(m, addrstr)
if err == nil {
t.Log("exchanged response message")
t.Fatal()
}
s.Shutdown()
s, addrstr, err = RunLocalUDPServerUnsafe("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
m.Response = true
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Log("could exchanged response message in Unsafe mode")
t.Fatal()
}
}
func TestShutdownTCP(t *testing.T) {
s, _, err := RunLocalTCPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
// it normally is too early to shutting down because server
// activates in goroutine.
runtime.Gosched()
err = s.Shutdown()
if err != nil {
t.Errorf("Could not shutdown test TCP server, %s", err)
@ -321,9 +394,6 @@ func TestShutdownUDP(t *testing.T) {
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
// it normally is too early to shutting down because server
// activates in goroutine.
runtime.Gosched()
err = s.Shutdown()
if err != nil {
t.Errorf("Could not shutdown test UDP server, %s", err)

236
Godeps/_workspace/src/github.com/miekg/dns/sig0.go generated vendored Normal file
View File

@ -0,0 +1,236 @@
// SIG(0)
//
// From RFC 2931:
//
// SIG(0) provides protection for DNS transactions and requests ....
// ... protection for glue records, DNS requests, protection for message headers
// on requests and responses, and protection of the overall integrity of a response.
//
// It works like TSIG, except that SIG(0) uses public key cryptography, instead of the shared
// secret approach in TSIG.
// Supported algorithms: DSA, ECDSAP256SHA256, ECDSAP384SHA384, RSASHA1, RSASHA256 and
// RSASHA512.
//
// Signing subsequent messages in multi-message sessions is not implemented.
//
package dns
import (
"crypto"
"crypto/dsa"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"time"
)
// Sign signs a dns.Msg. It fills the signature with the appropriate data.
// The SIG record should have the SignerName, KeyTag, Algorithm, Inception
// and Expiration set.
func (rr *SIG) Sign(k PrivateKey, m *Msg) ([]byte, error) {
if k == nil {
return nil, ErrPrivKey
}
if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
return nil, ErrKey
}
rr.Header().Rrtype = TypeSIG
rr.Header().Class = ClassANY
rr.Header().Ttl = 0
rr.Header().Name = "."
rr.OrigTtl = 0
rr.TypeCovered = 0
rr.Labels = 0
buf := make([]byte, m.Len()+rr.len())
mbuf, err := m.PackBuffer(buf)
if err != nil {
return nil, err
}
if &buf[0] != &mbuf[0] {
return nil, ErrBuf
}
off, err := PackRR(rr, buf, len(mbuf), nil, false)
if err != nil {
return nil, err
}
buf = buf[:off:cap(buf)]
var hash crypto.Hash
switch rr.Algorithm {
case DSA, RSASHA1:
hash = crypto.SHA1
case RSASHA256, ECDSAP256SHA256:
hash = crypto.SHA256
case ECDSAP384SHA384:
hash = crypto.SHA384
case RSASHA512:
hash = crypto.SHA512
default:
return nil, ErrAlg
}
hasher := hash.New()
// Write SIG rdata
hasher.Write(buf[len(mbuf)+1+2+2+4+2:])
// Write message
hasher.Write(buf[:len(mbuf)])
hashed := hasher.Sum(nil)
sig, err := k.Sign(hashed, rr.Algorithm)
if err != nil {
return nil, err
}
rr.Signature = toBase64(sig)
buf = append(buf, sig...)
if len(buf) > int(^uint16(0)) {
return nil, ErrBuf
}
// Adjust sig data length
rdoff := len(mbuf) + 1 + 2 + 2 + 4
rdlen, _ := unpackUint16(buf, rdoff)
rdlen += uint16(len(sig))
buf[rdoff], buf[rdoff+1] = packUint16(rdlen)
// Adjust additional count
adc, _ := unpackUint16(buf, 10)
adc += 1
buf[10], buf[11] = packUint16(adc)
return buf, nil
}
// Verify validates the message buf using the key k.
// It's assumed that buf is a valid message from which rr was unpacked.
func (rr *SIG) Verify(k *KEY, buf []byte) error {
if k == nil {
return ErrKey
}
if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
return ErrKey
}
var hash crypto.Hash
switch rr.Algorithm {
case DSA, RSASHA1:
hash = crypto.SHA1
case RSASHA256, ECDSAP256SHA256:
hash = crypto.SHA256
case ECDSAP384SHA384:
hash = crypto.SHA384
case RSASHA512:
hash = crypto.SHA512
default:
return ErrAlg
}
hasher := hash.New()
buflen := len(buf)
qdc, _ := unpackUint16(buf, 4)
anc, _ := unpackUint16(buf, 6)
auc, _ := unpackUint16(buf, 8)
adc, offset := unpackUint16(buf, 10)
var err error
for i := uint16(0); i < qdc && offset < buflen; i++ {
_, offset, err = UnpackDomainName(buf, offset)
if err != nil {
return err
}
// Skip past Type and Class
offset += 2 + 2
}
for i := uint16(1); i < anc+auc+adc && offset < buflen; i++ {
_, offset, err = UnpackDomainName(buf, offset)
if err != nil {
return err
}
// Skip past Type, Class and TTL
offset += 2 + 2 + 4
if offset+1 >= buflen {
continue
}
var rdlen uint16
rdlen, offset = unpackUint16(buf, offset)
offset += int(rdlen)
}
if offset >= buflen {
return &Error{err: "overflowing unpacking signed message"}
}
// offset should be just prior to SIG
bodyend := offset
// owner name SHOULD be root
_, offset, err = UnpackDomainName(buf, offset)
if err != nil {
return err
}
// Skip Type, Class, TTL, RDLen
offset += 2 + 2 + 4 + 2
sigstart := offset
// Skip Type Covered, Algorithm, Labels, Original TTL
offset += 2 + 1 + 1 + 4
if offset+4+4 >= buflen {
return &Error{err: "overflow unpacking signed message"}
}
expire := uint32(buf[offset])<<24 | uint32(buf[offset+1])<<16 | uint32(buf[offset+2])<<8 | uint32(buf[offset+3])
offset += 4
incept := uint32(buf[offset])<<24 | uint32(buf[offset+1])<<16 | uint32(buf[offset+2])<<8 | uint32(buf[offset+3])
offset += 4
now := uint32(time.Now().Unix())
if now < incept || now > expire {
return ErrTime
}
// Skip key tag
offset += 2
var signername string
signername, offset, err = UnpackDomainName(buf, offset)
if err != nil {
return err
}
// If key has come from the DNS name compression might
// have mangled the case of the name
if strings.ToLower(signername) != strings.ToLower(k.Header().Name) {
return &Error{err: "signer name doesn't match key name"}
}
sigend := offset
hasher.Write(buf[sigstart:sigend])
hasher.Write(buf[:10])
hasher.Write([]byte{
byte((adc - 1) << 8),
byte(adc - 1),
})
hasher.Write(buf[12:bodyend])
hashed := hasher.Sum(nil)
sig := buf[sigend:]
switch k.Algorithm {
case DSA:
pk := k.publicKeyDSA()
sig = sig[1:]
r := big.NewInt(0)
r.SetBytes(sig[:len(sig)/2])
s := big.NewInt(0)
s.SetBytes(sig[len(sig)/2:])
if pk != nil {
if dsa.Verify(pk, hashed, r, s) {
return nil
}
return ErrSig
}
case RSASHA1, RSASHA256, RSASHA512:
pk := k.publicKeyRSA()
if pk != nil {
return rsa.VerifyPKCS1v15(pk, hash, hashed, sig)
}
case ECDSAP256SHA256, ECDSAP384SHA384:
pk := k.publicKeyECDSA()
r := big.NewInt(0)
r.SetBytes(sig[:len(sig)/2])
s := big.NewInt(0)
s.SetBytes(sig[len(sig)/2:])
if pk != nil {
if ecdsa.Verify(pk, hashed, r, s) {
return nil
}
return ErrSig
}
}
return ErrKeyAlg
}

View File

@ -0,0 +1,96 @@
package dns
import (
"testing"
"time"
)
func TestSIG0(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
m := new(Msg)
m.SetQuestion("example.org.", TypeSOA)
for _, alg := range []uint8{DSA, ECDSAP256SHA256, ECDSAP384SHA384, RSASHA1, RSASHA256, RSASHA512} {
algstr := AlgorithmToString[alg]
keyrr := new(KEY)
keyrr.Hdr.Name = algstr + "."
keyrr.Hdr.Rrtype = TypeKEY
keyrr.Hdr.Class = ClassINET
keyrr.Algorithm = alg
keysize := 1024
switch alg {
case ECDSAP256SHA256:
keysize = 256
case ECDSAP384SHA384:
keysize = 384
}
pk, err := keyrr.Generate(keysize)
if err != nil {
t.Logf("Failed to generate key for “%s”: %v", algstr, err)
t.Fail()
continue
}
now := uint32(time.Now().Unix())
sigrr := new(SIG)
sigrr.Hdr.Name = "."
sigrr.Hdr.Rrtype = TypeSIG
sigrr.Hdr.Class = ClassANY
sigrr.Algorithm = alg
sigrr.Expiration = now + 300
sigrr.Inception = now - 300
sigrr.KeyTag = keyrr.KeyTag()
sigrr.SignerName = keyrr.Hdr.Name
mb, err := sigrr.Sign(pk, m)
if err != nil {
t.Logf("Failed to sign message using “%s”: %v", algstr, err)
t.Fail()
continue
}
m := new(Msg)
if err := m.Unpack(mb); err != nil {
t.Logf("Failed to unpack message signed using “%s”: %v", algstr, err)
t.Fail()
continue
}
if len(m.Extra) != 1 {
t.Logf("Missing SIG for message signed using “%s”", algstr)
t.Fail()
continue
}
var sigrrwire *SIG
switch rr := m.Extra[0].(type) {
case *SIG:
sigrrwire = rr
default:
t.Logf("Expected SIG RR, instead: %v", rr)
t.Fail()
continue
}
for _, rr := range []*SIG{sigrr, sigrrwire} {
id := "sigrr"
if rr == sigrrwire {
id = "sigrrwire"
}
if err := rr.Verify(keyrr, mb); err != nil {
t.Logf("Failed to verify “%s” signed SIG(%s): %v", algstr, id, err)
t.Fail()
continue
}
}
mb[13]++
if err := sigrr.Verify(keyrr, mb); err == nil {
t.Logf("Verify succeeded on an altered message using “%s”", algstr)
t.Fail()
continue
}
sigrr.Expiration = 2
sigrr.Inception = 1
mb, _ = sigrr.Sign(pk, m)
if err := sigrr.Verify(keyrr, mb); err == nil {
t.Logf("Verify succeeded on an expired message using “%s”", algstr)
t.Fail()
continue
}
}
}

View File

@ -1,7 +1,7 @@
// TRANSACTION SIGNATURE
//
// An TSIG or transaction signature adds a HMAC TSIG record to each message sent.
// The supported algorithms include: HmacMD5, HmacSHA1 and HmacSHA256.
// The supported algorithms include: HmacMD5, HmacSHA1, HmacSHA256 and HmacSHA512.
//
// Basic use pattern when querying with a TSIG name "axfr." (note that these key names
// must be fully qualified - as they are domain names) and the base64 secret
@ -58,6 +58,7 @@ import (
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/hex"
"hash"
"io"
@ -71,6 +72,7 @@ const (
HmacMD5 = "hmac-md5.sig-alg.reg.int."
HmacSHA1 = "hmac-sha1."
HmacSHA256 = "hmac-sha256."
HmacSHA512 = "hmac-sha512."
)
type TSIG struct {
@ -159,7 +161,7 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
panic("dns: TSIG not last RR in additional")
}
// If we barf here, the caller is to blame
rawsecret, err := packBase64([]byte(secret))
rawsecret, err := fromBase64([]byte(secret))
if err != nil {
return nil, "", err
}
@ -181,6 +183,8 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
h = hmac.New(sha1.New, []byte(rawsecret))
case HmacSHA256:
h = hmac.New(sha256.New, []byte(rawsecret))
case HmacSHA512:
h = hmac.New(sha512.New, []byte(rawsecret))
default:
return nil, "", ErrKeyAlg
}
@ -209,7 +213,7 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
// If the signature does not validate err contains the
// error, otherwise it is nil.
func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
rawsecret, err := packBase64([]byte(secret))
rawsecret, err := fromBase64([]byte(secret))
if err != nil {
return err
}
@ -225,7 +229,14 @@ func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
}
buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly)
ti := uint64(time.Now().Unix()) - tsig.TimeSigned
// Fudge factor works both ways. A message can arrive before it was signed because
// of clock skew.
now := uint64(time.Now().Unix())
ti := now - tsig.TimeSigned
if now < tsig.TimeSigned {
ti = tsig.TimeSigned - now
}
if uint64(tsig.Fudge) < ti {
return ErrTime
}
@ -238,6 +249,8 @@ func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
h = hmac.New(sha1.New, rawsecret)
case HmacSHA256:
h = hmac.New(sha256.New, rawsecret)
case HmacSHA512:
h = hmac.New(sha512.New, rawsecret)
default:
return ErrKeyAlg
}

View File

@ -75,6 +75,7 @@ const (
TypeRKEY uint16 = 57
TypeTALINK uint16 = 58
TypeCDS uint16 = 59
TypeCDNSKEY uint16 = 60
TypeOPENPGPKEY uint16 = 61
TypeSPF uint16 = 99
TypeUINFO uint16 = 100
@ -159,10 +160,14 @@ const (
_AD = 1 << 5 // authticated data
_CD = 1 << 4 // checking disabled
)
LOC_EQUATOR = 1 << 31 // RFC 1876, Section 2.
LOC_PRIMEMERIDIAN = 1 << 31 // RFC 1876, Section 2.
// RFC 1876, Section 2
const _LOC_EQUATOR = 1 << 31
LOC_HOURS = 60 * 1000
LOC_DEGREES = 60 * LOC_HOURS
LOC_ALTITUDEBASE = 100000
)
// RFC 4398, Section 2.1
const (
@ -307,7 +312,7 @@ func (rr *MF) copy() RR { return &MF{*rr.Hdr.copyHeader(), rr.Mf} }
func (rr *MF) len() int { return rr.Hdr.len() + len(rr.Mf) + 1 }
func (rr *MF) String() string {
return rr.Hdr.String() + " " + sprintName(rr.Mf)
return rr.Hdr.String() + sprintName(rr.Mf)
}
type MD struct {
@ -320,7 +325,7 @@ func (rr *MD) copy() RR { return &MD{*rr.Hdr.copyHeader(), rr.Md} }
func (rr *MD) len() int { return rr.Hdr.len() + len(rr.Md) + 1 }
func (rr *MD) String() string {
return rr.Hdr.String() + " " + sprintName(rr.Md)
return rr.Hdr.String() + sprintName(rr.Md)
}
type MX struct {
@ -527,7 +532,17 @@ func appendTXTStringByte(s []byte, b byte) []byte {
return append(s, '\\', b)
}
if b < ' ' || b > '~' {
return append(s, fmt.Sprintf("\\%03d", b)...)
var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
return s
}
return append(s, b)
}
@ -772,51 +787,77 @@ func (rr *LOC) copy() RR {
return &LOC{*rr.Hdr.copyHeader(), rr.Version, rr.Size, rr.HorizPre, rr.VertPre, rr.Latitude, rr.Longitude, rr.Altitude}
}
// cmToM takes a cm value expressed in RFC1876 SIZE mantissa/exponent
// format and returns a string in m (two decimals for the cm)
func cmToM(m, e uint8) string {
if e < 2 {
if e == 1 {
m *= 10
}
return fmt.Sprintf("0.%02d", m)
}
s := fmt.Sprintf("%d", m)
for e > 2 {
s += "0"
e -= 1
}
return s
}
// String returns a string version of a LOC
func (rr *LOC) String() string {
s := rr.Hdr.String()
// Copied from ldns
// Latitude
lat := rr.Latitude
north := "N"
if lat > _LOC_EQUATOR {
lat = lat - _LOC_EQUATOR
} else {
north = "S"
lat = _LOC_EQUATOR - lat
}
h := lat / (1000 * 60 * 60)
lat = lat % (1000 * 60 * 60)
m := lat / (1000 * 60)
lat = lat % (1000 * 60)
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float32(lat) / 1000), north)
// Longitude
lon := rr.Longitude
east := "E"
if lon > _LOC_EQUATOR {
lon = lon - _LOC_EQUATOR
} else {
east = "W"
lon = _LOC_EQUATOR - lon
}
h = lon / (1000 * 60 * 60)
lon = lon % (1000 * 60 * 60)
m = lon / (1000 * 60)
lon = lon % (1000 * 60)
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float32(lon) / 1000), east)
s1 := rr.Altitude / 100.00
s1 -= 100000
if rr.Altitude%100 == 0 {
s += fmt.Sprintf("%.2fm ", float32(s1))
lat := rr.Latitude
ns := "N"
if lat > LOC_EQUATOR {
lat = lat - LOC_EQUATOR
} else {
s += fmt.Sprintf("%.0fm ", float32(s1))
ns = "S"
lat = LOC_EQUATOR - lat
}
s += cmToString((rr.Size&0xf0)>>4, rr.Size&0x0f) + "m "
s += cmToString((rr.HorizPre&0xf0)>>4, rr.HorizPre&0x0f) + "m "
s += cmToString((rr.VertPre&0xf0)>>4, rr.VertPre&0x0f) + "m"
h := lat / LOC_DEGREES
lat = lat % LOC_DEGREES
m := lat / LOC_HOURS
lat = lat % LOC_HOURS
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float64(lat) / 1000), ns)
lon := rr.Longitude
ew := "E"
if lon > LOC_PRIMEMERIDIAN {
lon = lon - LOC_PRIMEMERIDIAN
} else {
ew = "W"
lon = LOC_PRIMEMERIDIAN - lon
}
h = lon / LOC_DEGREES
lon = lon % LOC_DEGREES
m = lon / LOC_HOURS
lon = lon % LOC_HOURS
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float64(lon) / 1000), ew)
var alt float64 = float64(rr.Altitude) / 100
alt -= LOC_ALTITUDEBASE
if rr.Altitude%100 != 0 {
s += fmt.Sprintf("%.2fm ", alt)
} else {
s += fmt.Sprintf("%.0fm ", alt)
}
s += cmToM((rr.Size&0xf0)>>4, rr.Size&0x0f) + "m "
s += cmToM((rr.HorizPre&0xf0)>>4, rr.HorizPre&0x0f) + "m "
s += cmToM((rr.VertPre&0xf0)>>4, rr.VertPre&0x0f) + "m"
return s
}
// SIG is identical to RRSIG and nowadays only used for SIG(0), RFC2931.
type SIG struct {
RRSIG
}
type RRSIG struct {
Hdr RR_Header
TypeCovered uint16
@ -888,6 +929,14 @@ func (rr *NSEC) len() int {
return l
}
type DLV struct {
DS
}
type CDS struct {
DS
}
type DS struct {
Hdr RR_Header
KeyTag uint16
@ -909,48 +958,6 @@ func (rr *DS) String() string {
" " + strings.ToUpper(rr.Digest)
}
type CDS struct {
Hdr RR_Header
KeyTag uint16
Algorithm uint8
DigestType uint8
Digest string `dns:"hex"`
}
func (rr *CDS) Header() *RR_Header { return &rr.Hdr }
func (rr *CDS) len() int { return rr.Hdr.len() + 4 + len(rr.Digest)/2 }
func (rr *CDS) copy() RR {
return &CDS{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
}
func (rr *CDS) String() string {
return rr.Hdr.String() + strconv.Itoa(int(rr.KeyTag)) +
" " + strconv.Itoa(int(rr.Algorithm)) +
" " + strconv.Itoa(int(rr.DigestType)) +
" " + strings.ToUpper(rr.Digest)
}
type DLV struct {
Hdr RR_Header
KeyTag uint16
Algorithm uint8
DigestType uint8
Digest string `dns:"hex"`
}
func (rr *DLV) Header() *RR_Header { return &rr.Hdr }
func (rr *DLV) len() int { return rr.Hdr.len() + 4 + len(rr.Digest)/2 }
func (rr *DLV) copy() RR {
return &DLV{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
}
func (rr *DLV) String() string {
return rr.Hdr.String() + strconv.Itoa(int(rr.KeyTag)) +
" " + strconv.Itoa(int(rr.Algorithm)) +
" " + strconv.Itoa(int(rr.DigestType)) +
" " + strings.ToUpper(rr.Digest)
}
type KX struct {
Hdr RR_Header
Preference uint16
@ -999,7 +1006,7 @@ func (rr *TALINK) len() int { return rr.Hdr.len() + len(rr.PreviousNam
func (rr *TALINK) String() string {
return rr.Hdr.String() +
" " + sprintName(rr.PreviousName) + " " + sprintName(rr.NextName)
sprintName(rr.PreviousName) + " " + sprintName(rr.NextName)
}
type SSHFP struct {
@ -1024,28 +1031,65 @@ func (rr *SSHFP) String() string {
type IPSECKEY struct {
Hdr RR_Header
Precedence uint8
// GatewayType: 1: A record, 2: AAAA record, 3: domainname.
// 0 is use for no type and GatewayName should be "." then.
GatewayType uint8
Algorithm uint8
Gateway string `dns:"ipseckey"`
// Gateway can be an A record, AAAA record or a domain name.
GatewayA net.IP `dns:"a"`
GatewayAAAA net.IP `dns:"aaaa"`
GatewayName string `dns:"domain-name"`
PublicKey string `dns:"base64"`
}
func (rr *IPSECKEY) Header() *RR_Header { return &rr.Hdr }
func (rr *IPSECKEY) copy() RR {
return &IPSECKEY{*rr.Hdr.copyHeader(), rr.Precedence, rr.GatewayType, rr.Algorithm, rr.Gateway, rr.PublicKey}
return &IPSECKEY{*rr.Hdr.copyHeader(), rr.Precedence, rr.GatewayType, rr.Algorithm, rr.GatewayA, rr.GatewayAAAA, rr.GatewayName, rr.PublicKey}
}
func (rr *IPSECKEY) String() string {
return rr.Hdr.String() + strconv.Itoa(int(rr.Precedence)) +
s := rr.Hdr.String() + strconv.Itoa(int(rr.Precedence)) +
" " + strconv.Itoa(int(rr.GatewayType)) +
" " + strconv.Itoa(int(rr.Algorithm)) +
" " + rr.Gateway +
" " + rr.PublicKey
" " + strconv.Itoa(int(rr.Algorithm))
switch rr.GatewayType {
case 0:
fallthrough
case 3:
s += " " + rr.GatewayName
case 1:
s += " " + rr.GatewayA.String()
case 2:
s += " " + rr.GatewayAAAA.String()
default:
s += " ."
}
s += " " + rr.PublicKey
return s
}
func (rr *IPSECKEY) len() int {
return rr.Hdr.len() + 3 + len(rr.Gateway) + 1 +
base64.StdEncoding.DecodedLen(len(rr.PublicKey))
l := rr.Hdr.len() + 3 + 1
switch rr.GatewayType {
default:
fallthrough
case 0:
fallthrough
case 3:
l += len(rr.GatewayName)
case 1:
l += 4
case 2:
l += 16
}
return l + base64.StdEncoding.DecodedLen(len(rr.PublicKey))
}
type KEY struct {
DNSKEY
}
type CDNSKEY struct {
DNSKEY
}
type DNSKEY struct {
@ -1221,11 +1265,23 @@ func (rr *RFC3597) copy() RR { return &RFC3597{*rr.Hdr.copyHeader(), r
func (rr *RFC3597) len() int { return rr.Hdr.len() + len(rr.Rdata)/2 + 2 }
func (rr *RFC3597) String() string {
s := rr.Hdr.String()
// Let's call it a hack
s := rfc3597Header(rr.Hdr)
s += "\\# " + strconv.Itoa(len(rr.Rdata)/2) + " " + rr.Rdata
return s
}
func rfc3597Header(h RR_Header) string {
var s string
s += sprintName(h.Name) + "\t"
s += strconv.FormatInt(int64(h.Ttl), 10) + "\t"
s += "CLASS" + strconv.Itoa(int(h.Class)) + "\t"
s += "TYPE" + strconv.Itoa(int(h.Rrtype)) + "\t"
return s
}
type URI struct {
Hdr RR_Header
Priority uint16
@ -1280,7 +1336,7 @@ func (rr *TLSA) copy() RR {
func (rr *TLSA) String() string {
return rr.Hdr.String() +
" " + strconv.Itoa(int(rr.Usage)) +
strconv.Itoa(int(rr.Usage)) +
" " + strconv.Itoa(int(rr.Selector)) +
" " + strconv.Itoa(int(rr.MatchingType)) +
" " + rr.Certificate
@ -1305,7 +1361,7 @@ func (rr *HIP) copy() RR {
func (rr *HIP) String() string {
s := rr.Hdr.String() +
" " + strconv.Itoa(int(rr.PublicKeyAlgorithm)) +
strconv.Itoa(int(rr.PublicKeyAlgorithm)) +
" " + rr.Hit +
" " + rr.PublicKey
for _, d := range rr.RendezvousServers {
@ -1367,6 +1423,7 @@ func (rr *WKS) String() (s string) {
if rr.Address != nil {
s += rr.Address.String()
}
// TODO(miek): missing protocol here, see /etc/protocols
for i := 0; i < len(rr.BitMap); i++ {
// should lookup the port
s += " " + strconv.Itoa(int(rr.BitMap[i]))
@ -1578,23 +1635,6 @@ func saltToString(s string) string {
return strings.ToUpper(s)
}
func cmToString(mantissa, exponent uint8) string {
switch exponent {
case 0, 1:
if exponent == 1 {
mantissa *= 10
}
return fmt.Sprintf("%.02f", float32(mantissa))
default:
s := fmt.Sprintf("%d", mantissa)
for i := uint8(0); i < exponent-2; i++ {
s += "0"
}
return s
}
panic("dns: not reached")
}
func euiToString(eui uint64, bits int) (hex string) {
switch bits {
case 64:
@ -1628,6 +1668,7 @@ var typeToRR = map[uint16]func() RR{
TypeDHCID: func() RR { return new(DHCID) },
TypeDLV: func() RR { return new(DLV) },
TypeDNAME: func() RR { return new(DNAME) },
TypeKEY: func() RR { return new(KEY) },
TypeDNSKEY: func() RR { return new(DNSKEY) },
TypeDS: func() RR { return new(DS) },
TypeEUI48: func() RR { return new(EUI48) },
@ -1637,6 +1678,7 @@ var typeToRR = map[uint16]func() RR{
TypeEID: func() RR { return new(EID) },
TypeHINFO: func() RR { return new(HINFO) },
TypeHIP: func() RR { return new(HIP) },
TypeIPSECKEY: func() RR { return new(IPSECKEY) },
TypeKX: func() RR { return new(KX) },
TypeL32: func() RR { return new(L32) },
TypeL64: func() RR { return new(L64) },
@ -1665,6 +1707,7 @@ var typeToRR = map[uint16]func() RR{
TypeRKEY: func() RR { return new(RKEY) },
TypeRP: func() RR { return new(RP) },
TypePX: func() RR { return new(PX) },
TypeSIG: func() RR { return new(SIG) },
TypeRRSIG: func() RR { return new(RRSIG) },
TypeRT: func() RR { return new(RT) },
TypeSOA: func() RR { return new(SOA) },

View File

@ -0,0 +1,42 @@
package dns
import (
"testing"
)
func TestCmToM(t *testing.T) {
s := cmToM(0, 0)
if s != "0.00" {
t.Error("0, 0")
}
s = cmToM(1, 0)
if s != "0.01" {
t.Error("1, 0")
}
s = cmToM(3, 1)
if s != "0.30" {
t.Error("3, 1")
}
s = cmToM(4, 2)
if s != "4" {
t.Error("4, 2")
}
s = cmToM(5, 3)
if s != "50" {
t.Error("5, 3")
}
s = cmToM(7, 5)
if s != "7000" {
t.Error("7, 5")
}
s = cmToM(9, 9)
if s != "90000000" {
t.Error("9, 9")
}
}

View File

@ -106,10 +106,7 @@ func (u *Msg) Insert(rr []RR) {
func (u *Msg) RemoveRRset(rr []RR) {
u.Ns = make([]RR, len(rr))
for i, r := range rr {
u.Ns[i] = r
u.Ns[i].Header().Class = ClassANY
u.Ns[i].Header().Rdlength = 0
u.Ns[i].Header().Ttl = 0
u.Ns[i] = &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: r.Header().Rrtype, Class: ClassANY}}
}
}

View File

@ -0,0 +1,89 @@
package dns
import (
"bytes"
"testing"
)
func TestDynamicUpdateParsing(t *testing.T) {
prefix := "example.com. IN "
for _, typ := range TypeToString {
if typ == "CAA" || typ == "OPT" || typ == "AXFR" || typ == "IXFR" || typ == "ANY" || typ == "TKEY" ||
typ == "TSIG" || typ == "ISDN" || typ == "UNSPEC" || typ == "NULL" || typ == "ATMA" {
continue
}
r, e := NewRR(prefix + typ)
if e != nil {
t.Log("failure to parse: " + prefix + typ)
t.Fail()
} else {
t.Logf("parsed: %s", r.String())
}
}
}
func TestDynamicUpdateUnpack(t *testing.T) {
// From https://github.com/miekg/dns/issues/150#issuecomment-62296803
// It should be an update message for the zone "example.",
// deleting the A RRset "example." and then adding an A record at "example.".
// class ANY, TYPE A
buf := []byte{171, 68, 40, 0, 0, 1, 0, 0, 0, 2, 0, 0, 7, 101, 120, 97, 109, 112, 108, 101, 0, 0, 6, 0, 1, 192, 12, 0, 1, 0, 255, 0, 0, 0, 0, 0, 0, 192, 12, 0, 1, 0, 1, 0, 0, 0, 0, 0, 4, 127, 0, 0, 1}
msg := new(Msg)
err := msg.Unpack(buf)
if err != nil {
t.Log("failed to unpack: " + err.Error() + "\n" + msg.String())
t.Fail()
}
}
func TestDynamicUpdateZeroRdataUnpack(t *testing.T) {
m := new(Msg)
rr := &RR_Header{Name: ".", Rrtype: 0, Class: 1, Ttl: ^uint32(0), Rdlength: 0}
m.Answer = []RR{rr, rr, rr, rr, rr}
m.Ns = m.Answer
for n, s := range TypeToString {
rr.Rrtype = n
bytes, err := m.Pack()
if err != nil {
t.Logf("failed to pack %s: %v", s, err)
t.Fail()
continue
}
if err := new(Msg).Unpack(bytes); err != nil {
t.Logf("failed to unpack %s: %v", s, err)
t.Fail()
}
}
}
func TestRemoveRRset(t *testing.T) {
// Should add a zero data RR in Class ANY with a TTL of 0
// for each set mentioned in the RRs provided to it.
rr, err := NewRR(". 100 IN A 127.0.0.1")
if err != nil {
t.Fatalf("Error constructing RR: %v", err)
}
m := new(Msg)
m.Ns = []RR{&RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY, Ttl: 0, Rdlength: 0}}
expectstr := m.String()
expect, err := m.Pack()
if err != nil {
t.Fatalf("Error packing expected msg: %v", err)
}
m.Ns = nil
m.RemoveRRset([]RR{rr})
actual, err := m.Pack()
if err != nil {
t.Fatalf("Error packing actual msg: %v", err)
}
if !bytes.Equal(actual, expect) {
tmp := new(Msg)
if err := tmp.Unpack(actual); err != nil {
t.Fatalf("Error unpacking actual msg: %v", err)
}
t.Logf("Expected msg:\n%s", expectstr)
t.Logf("Actual msg:\n%v", tmp)
t.Fail()
}
}

View File

@ -13,9 +13,9 @@ type Envelope struct {
// A Transfer defines parameters that are used during a zone transfer.
type Transfer struct {
*Conn
DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
tsigTimersOnly bool
}
@ -160,22 +160,18 @@ func (t *Transfer) inIxfr(id uint16, c chan *Envelope) {
// The server is responsible for sending the correct sequence of RRs through the
// channel ch.
func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
for x := range ch {
r := new(Msg)
// Compress?
r.SetReply(q)
r.Authoritative = true
go func() {
for x := range ch {
// assume it fits TODO(miek): fix
r.Answer = append(r.Answer, x.RR...)
if err := w.WriteMsg(r); err != nil {
return
return err
}
}
w.TsigTimersOnly(true)
r.Answer = nil
}()
return nil
}

99
Godeps/_workspace/src/github.com/miekg/dns/xfr_test.go generated vendored Normal file
View File

@ -0,0 +1,99 @@
package dns
import (
"net"
"testing"
"time"
)
func getIP(s string) string {
a, e := net.LookupAddr(s)
if e != nil {
return ""
}
return a[0]
}
// flaky, need to setup local server and test from
// that.
func testClientAXFR(t *testing.T) {
if testing.Short() {
return
}
m := new(Msg)
m.SetAxfr("miek.nl.")
server := getIP("linode.atoom.net")
tr := new(Transfer)
if a, err := tr.In(m, net.JoinHostPort(server, "53")); err != nil {
t.Log("failed to setup axfr: " + err.Error())
t.Fatal()
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("error %s\n", ex.Error.Error())
t.Fail()
break
}
for _, rr := range ex.RR {
t.Logf("%s\n", rr.String())
}
}
}
}
// fails.
func testClientAXFRMultipleEnvelopes(t *testing.T) {
if testing.Short() {
return
}
m := new(Msg)
m.SetAxfr("nlnetlabs.nl.")
server := getIP("open.nlnetlabs.nl.")
tr := new(Transfer)
if a, err := tr.In(m, net.JoinHostPort(server, "53")); err != nil {
t.Log("Failed to setup axfr" + err.Error() + "for server: " + server)
t.Fail()
return
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("Error %s\n", ex.Error.Error())
t.Fail()
break
}
}
}
}
func testClientTsigAXFR(t *testing.T) {
if testing.Short() {
return
}
m := new(Msg)
m.SetAxfr("example.nl.")
m.SetTsig("axfr.", HmacMD5, 300, time.Now().Unix())
tr := new(Transfer)
tr.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
if a, err := tr.In(m, "176.58.119.54:53"); err != nil {
t.Log("failed to setup axfr: " + err.Error())
t.Fatal()
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("error %s\n", ex.Error.Error())
t.Fail()
break
}
for _, rr := range ex.RR {
t.Logf("%s\n", rr.String())
}
}
}
}

View File

@ -102,12 +102,13 @@ type Token struct {
Comment string // a potential comment positioned after the RR and on the same line
}
// NewRR reads the RR contained in the string s. Only the first RR is returned.
// The class defaults to IN and TTL defaults to 3600. The full zone file
// syntax like $TTL, $ORIGIN, etc. is supported.
// All fields of the returned RR are set, except RR.Header().Rdlength which is set to 0.
// NewRR reads the RR contained in the string s. Only the first RR is
// returned. If s contains no RR, return nil with no error. The class
// defaults to IN and TTL defaults to 3600. The full zone file syntax
// like $TTL, $ORIGIN, etc. is supported. All fields of the returned
// RR are set, except RR.Header().Rdlength which is set to 0.
func NewRR(s string) (RR, error) {
if s[len(s)-1] != '\n' { // We need a closing newline
if len(s) > 0 && s[len(s)-1] != '\n' { // We need a closing newline
return ReadRR(strings.NewReader(s+"\n"), "")
}
return ReadRR(strings.NewReader(s), "")
@ -117,6 +118,10 @@ func NewRR(s string) (RR, error) {
// See NewRR for more documentation.
func ReadRR(q io.Reader, filename string) (RR, error) {
r := <-parseZoneHelper(q, ".", filename, 1)
if r == nil {
return nil, nil
}
if r.Error != nil {
return nil, r.Error
}
@ -899,9 +904,9 @@ func appendOrigin(name, origin string) string {
func locCheckNorth(token string, latitude uint32) (uint32, bool) {
switch token {
case "n", "N":
return _LOC_EQUATOR + latitude, true
return LOC_EQUATOR + latitude, true
case "s", "S":
return _LOC_EQUATOR - latitude, true
return LOC_EQUATOR - latitude, true
}
return latitude, false
}
@ -910,9 +915,9 @@ func locCheckNorth(token string, latitude uint32) (uint32, bool) {
func locCheckEast(token string, longitude uint32) (uint32, bool) {
switch token {
case "e", "E":
return _LOC_EQUATOR + longitude, true
return LOC_EQUATOR + longitude, true
case "w", "W":
return _LOC_EQUATOR - longitude, true
return LOC_EQUATOR - longitude, true
}
return longitude, false
}

View File

@ -1065,6 +1065,14 @@ func setOPENPGPKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, strin
return rr, nil, c1
}
func setSIG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setRRSIG(h, c, o, f)
if r != nil {
return &SIG{*r.(*RRSIG)}, e, s
}
return nil, e, s
}
func setRRSIG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(RRSIG)
rr.Hdr = h
@ -1452,7 +1460,7 @@ func setSSHFP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, ""
}
func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
func setDNSKEYs(h RR_Header, c chan lex, o, f, typ string) (RR, *ParseError, string) {
rr := new(DNSKEY)
rr.Hdr = h
@ -1461,25 +1469,25 @@ func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, l.comment
}
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DNSKEY Flags", l}, ""
return nil, &ParseError{f, "bad " + typ + " Flags", l}, ""
} else {
rr.Flags = uint16(i)
}
<-c // _BLANK
l = <-c // _STRING
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DNSKEY Protocol", l}, ""
return nil, &ParseError{f, "bad " + typ + " Protocol", l}, ""
} else {
rr.Protocol = uint8(i)
}
<-c // _BLANK
l = <-c // _STRING
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DNSKEY Algorithm", l}, ""
return nil, &ParseError{f, "bad " + typ + " Algorithm", l}, ""
} else {
rr.Algorithm = uint8(i)
}
s, e, c1 := endingToString(c, "bad DNSKEY PublicKey", f)
s, e, c1 := endingToString(c, "bad "+typ+" PublicKey", f)
if e != nil {
return nil, e, c1
}
@ -1487,6 +1495,27 @@ func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, c1
}
func setKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDNSKEYs(h, c, o, f, "KEY")
if r != nil {
return &KEY{*r.(*DNSKEY)}, e, s
}
return nil, e, s
}
func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDNSKEYs(h, c, o, f, "DNSKEY")
return r, e, s
}
func setCDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDNSKEYs(h, c, o, f, "CDNSKEY")
if r != nil {
return &CDNSKEY{*r.(*DNSKEY)}, e, s
}
return nil, e, s
}
func setRKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(RKEY)
rr.Hdr = h
@ -1522,44 +1551,6 @@ func setRKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, c1
}
func setDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(DS)
rr.Hdr = h
l := <-c
if l.length == 0 {
return rr, nil, l.comment
}
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DS KeyTag", l}, ""
} else {
rr.KeyTag = uint16(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
if i, ok := StringToAlgorithm[l.tokenUpper]; !ok {
return nil, &ParseError{f, "bad DS Algorithm", l}, ""
} else {
rr.Algorithm = i
}
} else {
rr.Algorithm = uint8(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DS DigestType", l}, ""
} else {
rr.DigestType = uint8(i)
}
s, e, c1 := endingToString(c, "bad DS Digest", f)
if e != nil {
return nil, e, c1
}
rr.Digest = s
return rr, nil, c1
}
func setEID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(EID)
rr.Hdr = h
@ -1632,15 +1623,15 @@ func setGPOS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, ""
}
func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(CDS)
func setDSs(h RR_Header, c chan lex, o, f, typ string) (RR, *ParseError, string) {
rr := new(DS)
rr.Hdr = h
l := <-c
if l.length == 0 {
return rr, nil, l.comment
}
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad CDS KeyTag", l}, ""
return nil, &ParseError{f, "bad " + typ + " KeyTag", l}, ""
} else {
rr.KeyTag = uint16(i)
}
@ -1648,7 +1639,7 @@ func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
if i, ok := StringToAlgorithm[l.tokenUpper]; !ok {
return nil, &ParseError{f, "bad CDS Algorithm", l}, ""
return nil, &ParseError{f, "bad " + typ + " Algorithm", l}, ""
} else {
rr.Algorithm = i
}
@ -1658,11 +1649,11 @@ func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad CDS DigestType", l}, ""
return nil, &ParseError{f, "bad " + typ + " DigestType", l}, ""
} else {
rr.DigestType = uint8(i)
}
s, e, c1 := endingToString(c, "bad CDS Digest", f)
s, e, c1 := endingToString(c, "bad "+typ+" Digest", f)
if e != nil {
return nil, e, c1
}
@ -1670,42 +1661,25 @@ func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, c1
}
func setDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDSs(h, c, o, f, "DS")
return r, e, s
}
func setDLV(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(DLV)
rr.Hdr = h
l := <-c
if l.length == 0 {
return rr, nil, l.comment
r, e, s := setDSs(h, c, o, f, "DLV")
if r != nil {
return &DLV{*r.(*DS)}, e, s
}
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DLV KeyTag", l}, ""
} else {
rr.KeyTag = uint16(i)
return nil, e, s
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
if i, ok := StringToAlgorithm[l.tokenUpper]; !ok {
return nil, &ParseError{f, "bad DLV Algorithm", l}, ""
} else {
rr.Algorithm = i
func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDSs(h, c, o, f, "DLV")
if r != nil {
return &CDS{*r.(*DS)}, e, s
}
} else {
rr.Algorithm = uint8(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DLV DigestType", l}, ""
} else {
rr.DigestType = uint8(i)
}
s, e, c1 := endingToString(c, "bad DLV Digest", f)
if e != nil {
return nil, e, c1
}
rr.Digest = s
return rr, nil, c1
return nil, e, s
}
func setTA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
@ -1873,44 +1847,6 @@ func setURI(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, c1
}
func setIPSECKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(IPSECKEY)
rr.Hdr = h
l := <-c
if l.length == 0 {
return rr, nil, l.comment
}
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad IPSECKEY Precedence", l}, ""
} else {
rr.Precedence = uint8(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad IPSECKEY GatewayType", l}, ""
} else {
rr.GatewayType = uint8(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad IPSECKEY Algorithm", l}, ""
} else {
rr.Algorithm = uint8(i)
}
<-c
l = <-c
rr.Gateway = l.token
s, e, c1 := endingToString(c, "bad IPSECKEY PublicKey", f)
if e != nil {
return nil, e, c1
}
rr.PublicKey = s
return rr, nil, c1
}
func setDHCID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
// awesome record to parse!
rr := new(DHCID)
@ -2113,16 +2049,85 @@ func setPX(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, ""
}
func setIPSECKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(IPSECKEY)
rr.Hdr = h
l := <-c
if l.length == 0 {
return rr, nil, l.comment
}
if i, err := strconv.Atoi(l.token); err != nil {
return nil, &ParseError{f, "bad IPSECKEY Precedence", l}, ""
} else {
rr.Precedence = uint8(i)
}
<-c // _BLANK
l = <-c
if i, err := strconv.Atoi(l.token); err != nil {
return nil, &ParseError{f, "bad IPSECKEY GatewayType", l}, ""
} else {
rr.GatewayType = uint8(i)
}
<-c // _BLANK
l = <-c
if i, err := strconv.Atoi(l.token); err != nil {
return nil, &ParseError{f, "bad IPSECKEY Algorithm", l}, ""
} else {
rr.Algorithm = uint8(i)
}
// Now according to GatewayType we can have different elements here
<-c // _BLANK
l = <-c
switch rr.GatewayType {
case 0:
fallthrough
case 3:
rr.GatewayName = l.token
if l.token == "@" {
rr.GatewayName = o
}
_, ok := IsDomainName(l.token)
if !ok {
return nil, &ParseError{f, "bad IPSECKEY GatewayName", l}, ""
}
if rr.GatewayName[l.length-1] != '.' {
rr.GatewayName = appendOrigin(rr.GatewayName, o)
}
case 1:
rr.GatewayA = net.ParseIP(l.token)
if rr.GatewayA == nil {
return nil, &ParseError{f, "bad IPSECKEY GatewayA", l}, ""
}
case 2:
rr.GatewayAAAA = net.ParseIP(l.token)
if rr.GatewayAAAA == nil {
return nil, &ParseError{f, "bad IPSECKEY GatewayAAAA", l}, ""
}
default:
return nil, &ParseError{f, "bad IPSECKEY GatewayType", l}, ""
}
s, e, c1 := endingToString(c, "bad IPSECKEY PublicKey", f)
if e != nil {
return nil, e, c1
}
rr.PublicKey = s
return rr, nil, c1
}
var typeToparserFunc = map[uint16]parserFunc{
TypeAAAA: parserFunc{setAAAA, false},
TypeAFSDB: parserFunc{setAFSDB, false},
TypeA: parserFunc{setA, false},
TypeCDS: parserFunc{setCDS, true},
TypeCDNSKEY: parserFunc{setCDNSKEY, true},
TypeCERT: parserFunc{setCERT, true},
TypeCNAME: parserFunc{setCNAME, false},
TypeDHCID: parserFunc{setDHCID, true},
TypeDLV: parserFunc{setDLV, true},
TypeDNAME: parserFunc{setDNAME, false},
TypeKEY: parserFunc{setKEY, true},
TypeDNSKEY: parserFunc{setDNSKEY, true},
TypeDS: parserFunc{setDS, true},
TypeEID: parserFunc{setEID, true},
@ -2158,6 +2163,7 @@ var typeToparserFunc = map[uint16]parserFunc{
TypeOPENPGPKEY: parserFunc{setOPENPGPKEY, true},
TypePTR: parserFunc{setPTR, false},
TypePX: parserFunc{setPX, false},
TypeSIG: parserFunc{setSIG, true},
TypeRKEY: parserFunc{setRKEY, true},
TypeRP: parserFunc{setRP, false},
TypeRRSIG: parserFunc{setRRSIG, true},

View File

@ -0,0 +1 @@
Imported at 75cd24fc2f2c from https://bitbucket.org/ww/goautoneg.

View File

@ -0,0 +1,13 @@
include $(GOROOT)/src/Make.inc
TARG=bitbucket.org/ww/goautoneg
GOFILES=autoneg.go
include $(GOROOT)/src/Make.pkg
format:
gofmt -w *.go
docs:
gomake clean
godoc ${TARG} > README.txt

View File

@ -0,0 +1,67 @@
PACKAGE
package goautoneg
import "bitbucket.org/ww/goautoneg"
HTTP Content-Type Autonegotiation.
The functions in this package implement the behaviour specified in
http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
Copyright (c) 2011, Open Knowledge Foundation Ltd.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
Neither the name of the Open Knowledge Foundation Ltd. nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
FUNCTIONS
func Negotiate(header string, alternatives []string) (content_type string)
Negotiate the most appropriate content_type given the accept header
and a list of alternatives.
func ParseAccept(header string) (accept []Accept)
Parse an Accept Header string returning a sorted list
of clauses
TYPES
type Accept struct {
Type, SubType string
Q float32
Params map[string]string
}
Structure to represent a clause in an HTTP Accept Header
SUBDIRECTORIES
.hg

View File

@ -0,0 +1,162 @@
/*
HTTP Content-Type Autonegotiation.
The functions in this package implement the behaviour specified in
http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
Copyright (c) 2011, Open Knowledge Foundation Ltd.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
Neither the name of the Open Knowledge Foundation Ltd. nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package goautoneg
import (
"sort"
"strconv"
"strings"
)
// Structure to represent a clause in an HTTP Accept Header
type Accept struct {
Type, SubType string
Q float64
Params map[string]string
}
// For internal use, so that we can use the sort interface
type accept_slice []Accept
func (accept accept_slice) Len() int {
slice := []Accept(accept)
return len(slice)
}
func (accept accept_slice) Less(i, j int) bool {
slice := []Accept(accept)
ai, aj := slice[i], slice[j]
if ai.Q > aj.Q {
return true
}
if ai.Type != "*" && aj.Type == "*" {
return true
}
if ai.SubType != "*" && aj.SubType == "*" {
return true
}
return false
}
func (accept accept_slice) Swap(i, j int) {
slice := []Accept(accept)
slice[i], slice[j] = slice[j], slice[i]
}
// Parse an Accept Header string returning a sorted list
// of clauses
func ParseAccept(header string) (accept []Accept) {
parts := strings.Split(header, ",")
accept = make([]Accept, 0, len(parts))
for _, part := range parts {
part := strings.Trim(part, " ")
a := Accept{}
a.Params = make(map[string]string)
a.Q = 1.0
mrp := strings.Split(part, ";")
media_range := mrp[0]
sp := strings.Split(media_range, "/")
a.Type = strings.Trim(sp[0], " ")
switch {
case len(sp) == 1 && a.Type == "*":
a.SubType = "*"
case len(sp) == 2:
a.SubType = strings.Trim(sp[1], " ")
default:
continue
}
if len(mrp) == 1 {
accept = append(accept, a)
continue
}
for _, param := range mrp[1:] {
sp := strings.SplitN(param, "=", 2)
if len(sp) != 2 {
continue
}
token := strings.Trim(sp[0], " ")
if token == "q" {
a.Q, _ = strconv.ParseFloat(sp[1], 32)
} else {
a.Params[token] = strings.Trim(sp[1], " ")
}
}
accept = append(accept, a)
}
slice := accept_slice(accept)
sort.Sort(slice)
return
}
// Negotiate the most appropriate content_type given the accept header
// and a list of alternatives.
func Negotiate(header string, alternatives []string) (content_type string) {
asp := make([][]string, 0, len(alternatives))
for _, ctype := range alternatives {
asp = append(asp, strings.SplitN(ctype, "/", 2))
}
for _, clause := range ParseAccept(header) {
for i, ctsp := range asp {
if clause.Type == ctsp[0] && clause.SubType == ctsp[1] {
content_type = alternatives[i]
return
}
if clause.Type == ctsp[0] && clause.SubType == "*" {
content_type = alternatives[i]
return
}
if clause.Type == "*" && clause.SubType == "*" {
content_type = alternatives[i]
return
}
}
}
return
}

View File

@ -0,0 +1,33 @@
package goautoneg
import (
"testing"
)
var chrome = "application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5"
func TestParseAccept(t *testing.T) {
alternatives := []string{"text/html", "image/png"}
content_type := Negotiate(chrome, alternatives)
if content_type != "image/png" {
t.Errorf("got %s expected image/png", content_type)
}
alternatives = []string{"text/html", "text/plain", "text/n3"}
content_type = Negotiate(chrome, alternatives)
if content_type != "text/html" {
t.Errorf("got %s expected text/html", content_type)
}
alternatives = []string{"text/n3", "text/plain"}
content_type = Negotiate(chrome, alternatives)
if content_type != "text/plain" {
t.Errorf("got %s expected text/plain", content_type)
}
alternatives = []string{"text/n3", "application/rdf+xml"}
content_type = Negotiate(chrome, alternatives)
if content_type != "text/n3" {
t.Errorf("got %s expected text/n3", content_type)
}
}

View File

@ -0,0 +1,63 @@
package quantile
import (
"testing"
)
func BenchmarkInsertTargeted(b *testing.B) {
b.ReportAllocs()
s := NewTargeted(Targets)
b.ResetTimer()
for i := float64(0); i < float64(b.N); i++ {
s.Insert(i)
}
}
func BenchmarkInsertTargetedSmallEpsilon(b *testing.B) {
s := NewTargeted(TargetsSmallEpsilon)
b.ResetTimer()
for i := float64(0); i < float64(b.N); i++ {
s.Insert(i)
}
}
func BenchmarkInsertBiased(b *testing.B) {
s := NewLowBiased(0.01)
b.ResetTimer()
for i := float64(0); i < float64(b.N); i++ {
s.Insert(i)
}
}
func BenchmarkInsertBiasedSmallEpsilon(b *testing.B) {
s := NewLowBiased(0.0001)
b.ResetTimer()
for i := float64(0); i < float64(b.N); i++ {
s.Insert(i)
}
}
func BenchmarkQuery(b *testing.B) {
s := NewTargeted(Targets)
for i := float64(0); i < 1e6; i++ {
s.Insert(i)
}
b.ResetTimer()
n := float64(b.N)
for i := float64(0); i < n; i++ {
s.Query(i / n)
}
}
func BenchmarkQuerySmallEpsilon(b *testing.B) {
s := NewTargeted(TargetsSmallEpsilon)
for i := float64(0); i < 1e6; i++ {
s.Insert(i)
}
b.ResetTimer()
n := float64(b.N)
for i := float64(0); i < n; i++ {
s.Query(i / n)
}
}

View File

@ -0,0 +1,112 @@
// +build go1.1
package quantile_test
import (
"bufio"
"fmt"
"github.com/bmizerany/perks/quantile"
"log"
"os"
"strconv"
"time"
)
func Example_simple() {
ch := make(chan float64)
go sendFloats(ch)
// Compute the 50th, 90th, and 99th percentile.
q := quantile.NewTargeted(0.50, 0.90, 0.99)
for v := range ch {
q.Insert(v)
}
fmt.Println("perc50:", q.Query(0.50))
fmt.Println("perc90:", q.Query(0.90))
fmt.Println("perc99:", q.Query(0.99))
fmt.Println("count:", q.Count())
// Output:
// perc50: 5
// perc90: 14
// perc99: 40
// count: 2388
}
func Example_mergeMultipleStreams() {
// Scenario:
// We have multiple database shards. On each shard, there is a process
// collecting query response times from the database logs and inserting
// them into a Stream (created via NewTargeted(0.90)), much like the
// Simple example. These processes expose a network interface for us to
// ask them to serialize and send us the results of their
// Stream.Samples so we may Merge and Query them.
//
// NOTES:
// * These sample sets are small, allowing us to get them
// across the network much faster than sending the entire list of data
// points.
//
// * For this to work correctly, we must supply the same quantiles
// a priori the process collecting the samples supplied to NewTargeted,
// even if we do not plan to query them all here.
ch := make(chan quantile.Samples)
getDBQuerySamples(ch)
q := quantile.NewTargeted(0.90)
for samples := range ch {
q.Merge(samples)
}
fmt.Println("perc90:", q.Query(0.90))
}
func Example_window() {
// Scenario: We want the 90th, 95th, and 99th percentiles for each
// minute.
ch := make(chan float64)
go sendStreamValues(ch)
tick := time.NewTicker(1 * time.Minute)
q := quantile.NewTargeted(0.90, 0.95, 0.99)
for {
select {
case t := <-tick.C:
flushToDB(t, q.Samples())
q.Reset()
case v := <-ch:
q.Insert(v)
}
}
}
func sendStreamValues(ch chan float64) {
// Use your imagination
}
func flushToDB(t time.Time, samples quantile.Samples) {
// Use your imagination
}
// This is a stub for the above example. In reality this would hit the remote
// servers via http or something like it.
func getDBQuerySamples(ch chan quantile.Samples) {}
func sendFloats(ch chan<- float64) {
f, err := os.Open("exampledata.txt")
if err != nil {
log.Fatal(err)
}
sc := bufio.NewScanner(f)
for sc.Scan() {
b := sc.Bytes()
v, err := strconv.ParseFloat(string(b), 64)
if err != nil {
log.Fatal(err)
}
ch <- v
}
if sc.Err() != nil {
log.Fatal(sc.Err())
}
close(ch)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,292 @@
// Package quantile computes approximate quantiles over an unbounded data
// stream within low memory and CPU bounds.
//
// A small amount of accuracy is traded to achieve the above properties.
//
// Multiple streams can be merged before calling Query to generate a single set
// of results. This is meaningful when the streams represent the same type of
// data. See Merge and Samples.
//
// For more detailed information about the algorithm used, see:
//
// Effective Computation of Biased Quantiles over Data Streams
//
// http://www.cs.rutgers.edu/~muthu/bquant.pdf
package quantile
import (
"math"
"sort"
)
// Sample holds an observed value and meta information for compression. JSON
// tags have been added for convenience.
type Sample struct {
Value float64 `json:",string"`
Width float64 `json:",string"`
Delta float64 `json:",string"`
}
// Samples represents a slice of samples. It implements sort.Interface.
type Samples []Sample
func (a Samples) Len() int { return len(a) }
func (a Samples) Less(i, j int) bool { return a[i].Value < a[j].Value }
func (a Samples) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
type invariant func(s *stream, r float64) float64
// NewLowBiased returns an initialized Stream for low-biased quantiles
// (e.g. 0.01, 0.1, 0.5) where the needed quantiles are not known a priori, but
// error guarantees can still be given even for the lower ranks of the data
// distribution.
//
// The provided epsilon is a relative error, i.e. the true quantile of a value
// returned by a query is guaranteed to be within (1±Epsilon)*Quantile.
//
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error
// properties.
func NewLowBiased(epsilon float64) *Stream {
ƒ := func(s *stream, r float64) float64 {
return 2 * epsilon * r
}
return newStream(ƒ)
}
// NewHighBiased returns an initialized Stream for high-biased quantiles
// (e.g. 0.01, 0.1, 0.5) where the needed quantiles are not known a priori, but
// error guarantees can still be given even for the higher ranks of the data
// distribution.
//
// The provided epsilon is a relative error, i.e. the true quantile of a value
// returned by a query is guaranteed to be within 1-(1±Epsilon)*(1-Quantile).
//
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error
// properties.
func NewHighBiased(epsilon float64) *Stream {
ƒ := func(s *stream, r float64) float64 {
return 2 * epsilon * (s.n - r)
}
return newStream(ƒ)
}
// NewTargeted returns an initialized Stream concerned with a particular set of
// quantile values that are supplied a priori. Knowing these a priori reduces
// space and computation time. The targets map maps the desired quantiles to
// their absolute errors, i.e. the true quantile of a value returned by a query
// is guaranteed to be within (Quantile±Epsilon).
//
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error properties.
func NewTargeted(targets map[float64]float64) *Stream {
ƒ := func(s *stream, r float64) float64 {
var m = math.MaxFloat64
var f float64
for quantile, epsilon := range targets {
if quantile*s.n <= r {
f = (2 * epsilon * r) / quantile
} else {
f = (2 * epsilon * (s.n - r)) / (1 - quantile)
}
if f < m {
m = f
}
}
return m
}
return newStream(ƒ)
}
// Stream computes quantiles for a stream of float64s. It is not thread-safe by
// design. Take care when using across multiple goroutines.
type Stream struct {
*stream
b Samples
sorted bool
}
func newStream(ƒ invariant) *Stream {
x := &stream{ƒ: ƒ}
return &Stream{x, make(Samples, 0, 500), true}
}
// Insert inserts v into the stream.
func (s *Stream) Insert(v float64) {
s.insert(Sample{Value: v, Width: 1})
}
func (s *Stream) insert(sample Sample) {
s.b = append(s.b, sample)
s.sorted = false
if len(s.b) == cap(s.b) {
s.flush()
}
}
// Query returns the computed qth percentiles value. If s was created with
// NewTargeted, and q is not in the set of quantiles provided a priori, Query
// will return an unspecified result.
func (s *Stream) Query(q float64) float64 {
if !s.flushed() {
// Fast path when there hasn't been enough data for a flush;
// this also yields better accuracy for small sets of data.
l := len(s.b)
if l == 0 {
return 0
}
i := int(float64(l) * q)
if i > 0 {
i -= 1
}
s.maybeSort()
return s.b[i].Value
}
s.flush()
return s.stream.query(q)
}
// Merge merges samples into the underlying streams samples. This is handy when
// merging multiple streams from separate threads, database shards, etc.
//
// ATTENTION: This method is broken and does not yield correct results. The
// underlying algorithm is not capable of merging streams correctly.
func (s *Stream) Merge(samples Samples) {
sort.Sort(samples)
s.stream.merge(samples)
}
// Reset reinitializes and clears the list reusing the samples buffer memory.
func (s *Stream) Reset() {
s.stream.reset()
s.b = s.b[:0]
}
// Samples returns stream samples held by s.
func (s *Stream) Samples() Samples {
if !s.flushed() {
return s.b
}
s.flush()
return s.stream.samples()
}
// Count returns the total number of samples observed in the stream
// since initialization.
func (s *Stream) Count() int {
return len(s.b) + s.stream.count()
}
func (s *Stream) flush() {
s.maybeSort()
s.stream.merge(s.b)
s.b = s.b[:0]
}
func (s *Stream) maybeSort() {
if !s.sorted {
s.sorted = true
sort.Sort(s.b)
}
}
func (s *Stream) flushed() bool {
return len(s.stream.l) > 0
}
type stream struct {
n float64
l []Sample
ƒ invariant
}
func (s *stream) reset() {
s.l = s.l[:0]
s.n = 0
}
func (s *stream) insert(v float64) {
s.merge(Samples{{v, 1, 0}})
}
func (s *stream) merge(samples Samples) {
// TODO(beorn7): This tries to merge not only individual samples, but
// whole summaries. The paper doesn't mention merging summaries at
// all. Unittests show that the merging is inaccurate. Find out how to
// do merges properly.
var r float64
i := 0
for _, sample := range samples {
for ; i < len(s.l); i++ {
c := s.l[i]
if c.Value > sample.Value {
// Insert at position i.
s.l = append(s.l, Sample{})
copy(s.l[i+1:], s.l[i:])
s.l[i] = Sample{
sample.Value,
sample.Width,
math.Max(sample.Delta, math.Floor(s.ƒ(s, r))-1),
// TODO(beorn7): How to calculate delta correctly?
}
i++
goto inserted
}
r += c.Width
}
s.l = append(s.l, Sample{sample.Value, sample.Width, 0})
i++
inserted:
s.n += sample.Width
r += sample.Width
}
s.compress()
}
func (s *stream) count() int {
return int(s.n)
}
func (s *stream) query(q float64) float64 {
t := math.Ceil(q * s.n)
t += math.Ceil(s.ƒ(s, t) / 2)
p := s.l[0]
var r float64
for _, c := range s.l[1:] {
r += p.Width
if r+c.Width+c.Delta > t {
return p.Value
}
p = c
}
return p.Value
}
func (s *stream) compress() {
if len(s.l) < 2 {
return
}
x := s.l[len(s.l)-1]
xi := len(s.l) - 1
r := s.n - 1 - x.Width
for i := len(s.l) - 2; i >= 0; i-- {
c := s.l[i]
if c.Width+x.Width+x.Delta <= s.ƒ(s, r) {
x.Width += c.Width
s.l[xi] = x
// Remove element at i.
copy(s.l[i:], s.l[i+1:])
s.l = s.l[:len(s.l)-1]
xi -= 1
} else {
x = c
xi = i
}
r -= c.Width
}
}
func (s *stream) samples() Samples {
samples := make(Samples, len(s.l))
copy(samples, s.l)
return samples
}

View File

@ -0,0 +1,185 @@
package quantile
import (
"math"
"math/rand"
"sort"
"testing"
)
var (
Targets = map[float64]float64{
0.01: 0.001,
0.10: 0.01,
0.50: 0.05,
0.90: 0.01,
0.99: 0.001,
}
TargetsSmallEpsilon = map[float64]float64{
0.01: 0.0001,
0.10: 0.001,
0.50: 0.005,
0.90: 0.001,
0.99: 0.0001,
}
LowQuantiles = []float64{0.01, 0.1, 0.5}
HighQuantiles = []float64{0.99, 0.9, 0.5}
)
const RelativeEpsilon = 0.01
func verifyPercsWithAbsoluteEpsilon(t *testing.T, a []float64, s *Stream) {
sort.Float64s(a)
for quantile, epsilon := range Targets {
n := float64(len(a))
k := int(quantile * n)
lower := int((quantile - epsilon) * n)
if lower < 1 {
lower = 1
}
upper := int(math.Ceil((quantile + epsilon) * n))
if upper > len(a) {
upper = len(a)
}
w, min, max := a[k-1], a[lower-1], a[upper-1]
if g := s.Query(quantile); g < min || g > max {
t.Errorf("q=%f: want %v [%f,%f], got %v", quantile, w, min, max, g)
}
}
}
func verifyLowPercsWithRelativeEpsilon(t *testing.T, a []float64, s *Stream) {
sort.Float64s(a)
for _, qu := range LowQuantiles {
n := float64(len(a))
k := int(qu * n)
lowerRank := int((1 - RelativeEpsilon) * qu * n)
upperRank := int(math.Ceil((1 + RelativeEpsilon) * qu * n))
w, min, max := a[k-1], a[lowerRank-1], a[upperRank-1]
if g := s.Query(qu); g < min || g > max {
t.Errorf("q=%f: want %v [%f,%f], got %v", qu, w, min, max, g)
}
}
}
func verifyHighPercsWithRelativeEpsilon(t *testing.T, a []float64, s *Stream) {
sort.Float64s(a)
for _, qu := range HighQuantiles {
n := float64(len(a))
k := int(qu * n)
lowerRank := int((1 - (1+RelativeEpsilon)*(1-qu)) * n)
upperRank := int(math.Ceil((1 - (1-RelativeEpsilon)*(1-qu)) * n))
w, min, max := a[k-1], a[lowerRank-1], a[upperRank-1]
if g := s.Query(qu); g < min || g > max {
t.Errorf("q=%f: want %v [%f,%f], got %v", qu, w, min, max, g)
}
}
}
func populateStream(s *Stream) []float64 {
a := make([]float64, 0, 1e5+100)
for i := 0; i < cap(a); i++ {
v := rand.NormFloat64()
// Add 5% asymmetric outliers.
if i%20 == 0 {
v = v*v + 1
}
s.Insert(v)
a = append(a, v)
}
return a
}
func TestTargetedQuery(t *testing.T) {
rand.Seed(42)
s := NewTargeted(Targets)
a := populateStream(s)
verifyPercsWithAbsoluteEpsilon(t, a, s)
}
func TestLowBiasedQuery(t *testing.T) {
rand.Seed(42)
s := NewLowBiased(RelativeEpsilon)
a := populateStream(s)
verifyLowPercsWithRelativeEpsilon(t, a, s)
}
func TestHighBiasedQuery(t *testing.T) {
rand.Seed(42)
s := NewHighBiased(RelativeEpsilon)
a := populateStream(s)
verifyHighPercsWithRelativeEpsilon(t, a, s)
}
func TestTargetedMerge(t *testing.T) {
rand.Seed(42)
s1 := NewTargeted(Targets)
s2 := NewTargeted(Targets)
a := populateStream(s1)
a = append(a, populateStream(s2)...)
s1.Merge(s2.Samples())
verifyPercsWithAbsoluteEpsilon(t, a, s1)
}
func TestLowBiasedMerge(t *testing.T) {
rand.Seed(42)
s1 := NewLowBiased(RelativeEpsilon)
s2 := NewLowBiased(RelativeEpsilon)
a := populateStream(s1)
a = append(a, populateStream(s2)...)
s1.Merge(s2.Samples())
verifyLowPercsWithRelativeEpsilon(t, a, s2)
}
func TestHighBiasedMerge(t *testing.T) {
rand.Seed(42)
s1 := NewHighBiased(RelativeEpsilon)
s2 := NewHighBiased(RelativeEpsilon)
a := populateStream(s1)
a = append(a, populateStream(s2)...)
s1.Merge(s2.Samples())
verifyHighPercsWithRelativeEpsilon(t, a, s2)
}
func TestUncompressed(t *testing.T) {
q := NewTargeted(Targets)
for i := 100; i > 0; i-- {
q.Insert(float64(i))
}
if g := q.Count(); g != 100 {
t.Errorf("want count 100, got %d", g)
}
// Before compression, Query should have 100% accuracy.
for quantile := range Targets {
w := quantile * 100
if g := q.Query(quantile); g != w {
t.Errorf("want %f, got %f", w, g)
}
}
}
func TestUncompressedSamples(t *testing.T) {
q := NewTargeted(map[float64]float64{0.99: 0.001})
for i := 1; i <= 100; i++ {
q.Insert(float64(i))
}
if g := q.Samples().Len(); g != 100 {
t.Errorf("want count 100, got %d", g)
}
}
func TestUncompressedOne(t *testing.T) {
q := NewTargeted(map[float64]float64{0.99: 0.01})
q.Insert(3.14)
if g := q.Query(0.90); g != 3.14 {
t.Error("want PI, got", g)
}
}
func TestDefaults(t *testing.T) {
if g := NewTargeted(map[float64]float64{0.99: 0.001}).Query(0.99); g != 0 {
t.Errorf("want 0, got %f", g)
}
}

View File

@ -0,0 +1,74 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"errors"
"fmt"
"mime"
"net/http"
)
// ProcessorForRequestHeader interprets a HTTP request header to determine
// what Processor should be used for the given input. If no acceptable
// Processor can be found, an error is returned.
func ProcessorForRequestHeader(header http.Header) (Processor, error) {
if header == nil {
return nil, errors.New("received illegal and nil header")
}
mediatype, params, err := mime.ParseMediaType(header.Get("Content-Type"))
if err != nil {
return nil, fmt.Errorf("invalid Content-Type header %q: %s", header.Get("Content-Type"), err)
}
switch mediatype {
case "application/vnd.google.protobuf":
if params["proto"] != "io.prometheus.client.MetricFamily" {
return nil, fmt.Errorf("unrecognized protocol message %s", params["proto"])
}
if params["encoding"] != "delimited" {
return nil, fmt.Errorf("unsupported encoding %s", params["encoding"])
}
return MetricFamilyProcessor, nil
case "text/plain":
switch params["version"] {
case "0.0.4":
return Processor004, nil
case "":
// Fallback: most recent version.
return Processor004, nil
default:
return nil, fmt.Errorf("unrecognized API version %s", params["version"])
}
case "application/json":
var prometheusAPIVersion string
if params["schema"] == "prometheus/telemetry" && params["version"] != "" {
prometheusAPIVersion = params["version"]
} else {
prometheusAPIVersion = header.Get("X-Prometheus-API-Version")
}
switch prometheusAPIVersion {
case "0.0.2":
return Processor002, nil
case "0.0.1":
return Processor001, nil
default:
return nil, fmt.Errorf("unrecognized API version %s", prometheusAPIVersion)
}
default:
return nil, fmt.Errorf("unsupported media type %q, expected %q", mediatype, "application/json")
}
}

View File

@ -0,0 +1,126 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"errors"
"net/http"
"testing"
)
func testDiscriminatorHTTPHeader(t testing.TB) {
var scenarios = []struct {
input map[string]string
output Processor
err error
}{
{
output: nil,
err: errors.New("received illegal and nil header"),
},
{
input: map[string]string{"Content-Type": "application/json", "X-Prometheus-API-Version": "0.0.0"},
output: nil,
err: errors.New("unrecognized API version 0.0.0"),
},
{
input: map[string]string{"Content-Type": "application/json", "X-Prometheus-API-Version": "0.0.1"},
output: Processor001,
err: nil,
},
{
input: map[string]string{"Content-Type": `application/json; schema="prometheus/telemetry"; version=0.0.0`},
output: nil,
err: errors.New("unrecognized API version 0.0.0"),
},
{
input: map[string]string{"Content-Type": `application/json; schema="prometheus/telemetry"; version=0.0.1`},
output: Processor001,
err: nil,
},
{
input: map[string]string{"Content-Type": `application/json; schema="prometheus/telemetry"; version=0.0.2`},
output: Processor002,
err: nil,
},
{
input: map[string]string{"Content-Type": `application/vnd.google.protobuf; proto="io.prometheus.client.MetricFamily"; encoding="delimited"`},
output: MetricFamilyProcessor,
err: nil,
},
{
input: map[string]string{"Content-Type": `application/vnd.google.protobuf; proto="illegal"; encoding="delimited"`},
output: nil,
err: errors.New("unrecognized protocol message illegal"),
},
{
input: map[string]string{"Content-Type": `application/vnd.google.protobuf; proto="io.prometheus.client.MetricFamily"; encoding="illegal"`},
output: nil,
err: errors.New("unsupported encoding illegal"),
},
{
input: map[string]string{"Content-Type": `text/plain; version=0.0.4`},
output: Processor004,
err: nil,
},
{
input: map[string]string{"Content-Type": `text/plain`},
output: Processor004,
err: nil,
},
{
input: map[string]string{"Content-Type": `text/plain; version=0.0.3`},
output: nil,
err: errors.New("unrecognized API version 0.0.3"),
},
}
for i, scenario := range scenarios {
var header http.Header
if len(scenario.input) > 0 {
header = http.Header{}
}
for key, value := range scenario.input {
header.Add(key, value)
}
actual, err := ProcessorForRequestHeader(header)
if scenario.err != err {
if scenario.err != nil && err != nil {
if scenario.err.Error() != err.Error() {
t.Errorf("%d. expected %s, got %s", i, scenario.err, err)
}
} else if scenario.err != nil || err != nil {
t.Errorf("%d. expected %s, got %s", i, scenario.err, err)
}
}
if scenario.output != actual {
t.Errorf("%d. expected %s, got %s", i, scenario.output, actual)
}
}
}
func TestDiscriminatorHTTPHeader(t *testing.T) {
testDiscriminatorHTTPHeader(t)
}
func BenchmarkDiscriminatorHTTPHeader(b *testing.B) {
for i := 0; i < b.N; i++ {
testDiscriminatorHTTPHeader(b)
}
}

View File

@ -0,0 +1,15 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package extraction decodes Prometheus clients' data streams for consumers.
package extraction

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,79 @@
[
{
"baseLabels": {
"__name__": "rpc_calls_total",
"job": "batch_job"
},
"docstring": "RPC calls.",
"metric": {
"type": "counter",
"value": [
{
"labels": {
"service": "zed"
},
"value": 25
},
{
"labels": {
"service": "bar"
},
"value": 25
},
{
"labels": {
"service": "foo"
},
"value": 25
}
]
}
},
{
"baseLabels": {
"__name__": "rpc_latency_microseconds"
},
"docstring": "RPC latency.",
"metric": {
"type": "histogram",
"value": [
{
"labels": {
"service": "foo"
},
"value": {
"0.010000": 15.890724674774395,
"0.050000": 15.890724674774395,
"0.500000": 84.63044031436561,
"0.900000": 160.21100853053224,
"0.990000": 172.49828748957728
}
},
{
"labels": {
"service": "zed"
},
"value": {
"0.010000": 0.0459814091918713,
"0.050000": 0.0459814091918713,
"0.500000": 0.6120456642749681,
"0.900000": 1.355915069887731,
"0.990000": 1.772733213161236
}
},
{
"labels": {
"service": "bar"
},
"value": {
"0.010000": 78.48563317257356,
"0.050000": 78.48563317257356,
"0.500000": 97.31798360385088,
"0.900000": 109.89202084295582,
"0.990000": 109.99626121011262
}
}
]
}
}
]

View File

@ -0,0 +1,295 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"fmt"
"io"
dto "github.com/prometheus/client_model/go"
"github.com/matttproud/golang_protobuf_extensions/ext"
"github.com/prometheus/client_golang/model"
)
type metricFamilyProcessor struct{}
// MetricFamilyProcessor decodes varint encoded record length-delimited streams
// of io.prometheus.client.MetricFamily.
//
// See http://godoc.org/github.com/matttproud/golang_protobuf_extensions/ext for
// more details.
var MetricFamilyProcessor = &metricFamilyProcessor{}
func (m *metricFamilyProcessor) ProcessSingle(i io.Reader, out Ingester, o *ProcessOptions) error {
family := &dto.MetricFamily{}
for {
family.Reset()
if _, err := ext.ReadDelimited(i, family); err != nil {
if err == io.EOF {
return nil
}
return err
}
if err := extractMetricFamily(out, o, family); err != nil {
return err
}
}
}
func extractMetricFamily(out Ingester, o *ProcessOptions, family *dto.MetricFamily) error {
switch family.GetType() {
case dto.MetricType_COUNTER:
if err := extractCounter(out, o, family); err != nil {
return err
}
case dto.MetricType_GAUGE:
if err := extractGauge(out, o, family); err != nil {
return err
}
case dto.MetricType_SUMMARY:
if err := extractSummary(out, o, family); err != nil {
return err
}
case dto.MetricType_UNTYPED:
if err := extractUntyped(out, o, family); err != nil {
return err
}
case dto.MetricType_HISTOGRAM:
if err := extractHistogram(out, o, family); err != nil {
return err
}
}
return nil
}
func extractCounter(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Counter == nil {
continue
}
sample := new(model.Sample)
samples = append(samples, sample)
if m.TimestampMs != nil {
sample.Timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
} else {
sample.Timestamp = o.Timestamp
}
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName())
sample.Value = model.SampleValue(m.Counter.GetValue())
}
return out.Ingest(samples)
}
func extractGauge(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Gauge == nil {
continue
}
sample := new(model.Sample)
samples = append(samples, sample)
if m.TimestampMs != nil {
sample.Timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
} else {
sample.Timestamp = o.Timestamp
}
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName())
sample.Value = model.SampleValue(m.Gauge.GetValue())
}
return out.Ingest(samples)
}
func extractSummary(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Summary == nil {
continue
}
timestamp := o.Timestamp
if m.TimestampMs != nil {
timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
}
for _, q := range m.Summary.Quantile {
sample := new(model.Sample)
samples = append(samples, sample)
sample.Timestamp = timestamp
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
// BUG(matt): Update other names to "quantile".
metric[model.LabelName("quantile")] = model.LabelValue(fmt.Sprint(q.GetQuantile()))
metric[model.MetricNameLabel] = model.LabelValue(f.GetName())
sample.Value = model.SampleValue(q.GetValue())
}
if m.Summary.SampleSum != nil {
sum := new(model.Sample)
sum.Timestamp = timestamp
metric := model.Metric{}
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_sum")
sum.Metric = metric
sum.Value = model.SampleValue(m.Summary.GetSampleSum())
samples = append(samples, sum)
}
if m.Summary.SampleCount != nil {
count := new(model.Sample)
count.Timestamp = timestamp
metric := model.Metric{}
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_count")
count.Metric = metric
count.Value = model.SampleValue(m.Summary.GetSampleCount())
samples = append(samples, count)
}
}
return out.Ingest(samples)
}
func extractUntyped(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Untyped == nil {
continue
}
sample := new(model.Sample)
samples = append(samples, sample)
if m.TimestampMs != nil {
sample.Timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
} else {
sample.Timestamp = o.Timestamp
}
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName())
sample.Value = model.SampleValue(m.Untyped.GetValue())
}
return out.Ingest(samples)
}
func extractHistogram(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Histogram == nil {
continue
}
timestamp := o.Timestamp
if m.TimestampMs != nil {
timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
}
for _, q := range m.Histogram.Bucket {
sample := new(model.Sample)
samples = append(samples, sample)
sample.Timestamp = timestamp
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.LabelName("le")] = model.LabelValue(fmt.Sprint(q.GetUpperBound()))
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_bucket")
sample.Value = model.SampleValue(q.GetCumulativeCount())
}
// TODO: If +Inf bucket is missing, add it.
if m.Histogram.SampleSum != nil {
sum := new(model.Sample)
sum.Timestamp = timestamp
metric := model.Metric{}
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_sum")
sum.Metric = metric
sum.Value = model.SampleValue(m.Histogram.GetSampleSum())
samples = append(samples, sum)
}
if m.Histogram.SampleCount != nil {
count := new(model.Sample)
count.Timestamp = timestamp
metric := model.Metric{}
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_count")
count.Metric = metric
count.Value = model.SampleValue(m.Histogram.GetSampleCount())
samples = append(samples, count)
}
}
return out.Ingest(samples)
}

View File

@ -0,0 +1,153 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"sort"
"strings"
"testing"
"github.com/prometheus/client_golang/model"
)
var testTime = model.Now()
type metricFamilyProcessorScenario struct {
in string
expected, actual []model.Samples
}
func (s *metricFamilyProcessorScenario) Ingest(samples model.Samples) error {
s.actual = append(s.actual, samples)
return nil
}
func (s *metricFamilyProcessorScenario) test(t *testing.T, set int) {
i := strings.NewReader(s.in)
o := &ProcessOptions{
Timestamp: testTime,
}
err := MetricFamilyProcessor.ProcessSingle(i, s, o)
if err != nil {
t.Fatalf("%d. got error: %s", set, err)
}
if len(s.expected) != len(s.actual) {
t.Fatalf("%d. expected length %d, got %d", set, len(s.expected), len(s.actual))
}
for i, expected := range s.expected {
sort.Sort(s.actual[i])
sort.Sort(expected)
if !expected.Equal(s.actual[i]) {
t.Errorf("%d.%d. expected %s, got %s", set, i, expected, s.actual[i])
}
}
}
func TestMetricFamilyProcessor(t *testing.T) {
scenarios := []metricFamilyProcessorScenario{
{
in: "",
},
{
in: "\x8f\x01\n\rrequest_count\x12\x12Number of requests\x18\x00\"0\n#\n\x0fsome_label_name\x12\x10some_label_value\x1a\t\t\x00\x00\x00\x00\x00\x00E\xc0\"6\n)\n\x12another_label_name\x12\x13another_label_value\x1a\t\t\x00\x00\x00\x00\x00\x00U@",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "some_label_name": "some_label_value"},
Value: -42,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "another_label_name": "another_label_value"},
Value: 84,
Timestamp: testTime,
},
},
},
},
{
in: "\xb9\x01\n\rrequest_count\x12\x12Number of requests\x18\x02\"O\n#\n\x0fsome_label_name\x12\x10some_label_value\"(\x1a\x12\t\xaeG\xe1z\x14\xae\xef?\x11\x00\x00\x00\x00\x00\x00E\xc0\x1a\x12\t+\x87\x16\xd9\xce\xf7\xef?\x11\x00\x00\x00\x00\x00\x00U\xc0\"A\n)\n\x12another_label_name\x12\x13another_label_value\"\x14\x1a\x12\t\x00\x00\x00\x00\x00\x00\xe0?\x11\x00\x00\x00\x00\x00\x00$@",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "some_label_name": "some_label_value", "quantile": "0.99"},
Value: -42,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "some_label_name": "some_label_value", "quantile": "0.999"},
Value: -84,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "another_label_name": "another_label_value", "quantile": "0.5"},
Value: 10,
Timestamp: testTime,
},
},
},
},
{
in: "\x8d\x01\n\x1drequest_duration_microseconds\x12\x15The response latency.\x18\x04\"S:Q\b\x85\x15\x11\xcd\xcc\xccL\x8f\xcb:A\x1a\v\b{\x11\x00\x00\x00\x00\x00\x00Y@\x1a\f\b\x9c\x03\x11\x00\x00\x00\x00\x00\x00^@\x1a\f\b\xd0\x04\x11\x00\x00\x00\x00\x00\x00b@\x1a\f\b\xf4\v\x11\x9a\x99\x99\x99\x99\x99e@\x1a\f\b\x85\x15\x11\x00\x00\x00\x00\x00\x00\xf0\u007f",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "100"},
Value: 123,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "120"},
Value: 412,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "144"},
Value: 592,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "172.8"},
Value: 1524,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "+Inf"},
Value: 2693,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_sum"},
Value: 1756047.3,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_count"},
Value: 2693,
Timestamp: testTime,
},
},
},
},
}
for i, scenario := range scenarios {
scenario.test(t, i)
}
}

View File

@ -0,0 +1,84 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"io"
"time"
"github.com/prometheus/client_golang/model"
)
// ProcessOptions dictates how the interpreted stream should be rendered for
// consumption.
type ProcessOptions struct {
// Timestamp is added to each value from the stream that has no explicit
// timestamp set.
Timestamp model.Timestamp
}
// Ingester consumes result streams in whatever way is desired by the user.
type Ingester interface {
Ingest(model.Samples) error
}
// Processor is responsible for decoding the actual message responses from
// stream into a format that can be consumed with the end result written
// to the results channel.
type Processor interface {
// ProcessSingle treats the input as a single self-contained message body and
// transforms it accordingly. It has no support for streaming.
ProcessSingle(in io.Reader, out Ingester, o *ProcessOptions) error
}
// Helper function to convert map[string]string into LabelSet.
//
// NOTE: This should be deleted when support for go 1.0.3 is removed; 1.1 is
// smart enough to unmarshal JSON objects into LabelSet directly.
func labelSet(labels map[string]string) model.LabelSet {
labelset := make(model.LabelSet, len(labels))
for k, v := range labels {
labelset[model.LabelName(k)] = model.LabelValue(v)
}
return labelset
}
// A basic interface only useful in testing contexts for dispensing the time
// in a controlled manner.
type instantProvider interface {
// The current instant.
Now() time.Time
}
// Clock is a simple means for fluently wrapping around standard Go timekeeping
// mechanisms to enhance testability without compromising code readability.
//
// It is sufficient for use on bare initialization. A provider should be
// set only for test contexts. When not provided, it emits the current
// system time.
type clock struct {
// The underlying means through which time is provided, if supplied.
Provider instantProvider
}
// Emit the current instant.
func (t *clock) Now() time.Time {
if t.Provider == nil {
return time.Now()
}
return t.Provider.Now()
}

View File

@ -0,0 +1,127 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"github.com/prometheus/client_golang/model"
)
const (
baseLabels001 = "baseLabels"
counter001 = "counter"
docstring001 = "docstring"
gauge001 = "gauge"
histogram001 = "histogram"
labels001 = "labels"
metric001 = "metric"
type001 = "type"
value001 = "value"
percentile001 = "percentile"
)
// Processor001 is responsible for decoding payloads from protocol version
// 0.0.1.
var Processor001 = &processor001{}
// processor001 is responsible for handling API version 0.0.1.
type processor001 struct{}
// entity001 represents a the JSON structure that 0.0.1 uses.
type entity001 []struct {
BaseLabels map[string]string `json:"baseLabels"`
Docstring string `json:"docstring"`
Metric struct {
MetricType string `json:"type"`
Value []struct {
Labels map[string]string `json:"labels"`
Value interface{} `json:"value"`
} `json:"value"`
} `json:"metric"`
}
func (p *processor001) ProcessSingle(in io.Reader, out Ingester, o *ProcessOptions) error {
// TODO(matt): Replace with plain-jane JSON unmarshalling.
buffer, err := ioutil.ReadAll(in)
if err != nil {
return err
}
entities := entity001{}
if err = json.Unmarshal(buffer, &entities); err != nil {
return err
}
// TODO(matt): This outer loop is a great basis for parallelization.
pendingSamples := model.Samples{}
for _, entity := range entities {
for _, value := range entity.Metric.Value {
labels := labelSet(entity.BaseLabels).Merge(labelSet(value.Labels))
switch entity.Metric.MetricType {
case gauge001, counter001:
sampleValue, ok := value.Value.(float64)
if !ok {
return fmt.Errorf("could not convert value from %s %s to float64", entity, value)
}
pendingSamples = append(pendingSamples, &model.Sample{
Metric: model.Metric(labels),
Timestamp: o.Timestamp,
Value: model.SampleValue(sampleValue),
})
break
case histogram001:
sampleValue, ok := value.Value.(map[string]interface{})
if !ok {
return fmt.Errorf("could not convert value from %q to a map[string]interface{}", value.Value)
}
for percentile, percentileValue := range sampleValue {
individualValue, ok := percentileValue.(float64)
if !ok {
return fmt.Errorf("could not convert value from %q to a float64", percentileValue)
}
childMetric := make(map[model.LabelName]model.LabelValue, len(labels)+1)
for k, v := range labels {
childMetric[k] = v
}
childMetric[model.LabelName(percentile001)] = model.LabelValue(percentile)
pendingSamples = append(pendingSamples, &model.Sample{
Metric: model.Metric(childMetric),
Timestamp: o.Timestamp,
Value: model.SampleValue(individualValue),
})
}
break
}
}
}
if len(pendingSamples) > 0 {
return out.Ingest(pendingSamples)
}
return nil
}

View File

@ -0,0 +1,186 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"errors"
"os"
"path"
"sort"
"testing"
"github.com/prometheus/prometheus/utility/test"
"github.com/prometheus/client_golang/model"
)
var test001Time = model.Now()
type testProcessor001ProcessScenario struct {
in string
expected, actual []model.Samples
err error
}
func (s *testProcessor001ProcessScenario) Ingest(samples model.Samples) error {
s.actual = append(s.actual, samples)
return nil
}
func (s *testProcessor001ProcessScenario) test(t testing.TB, set int) {
reader, err := os.Open(path.Join("fixtures", s.in))
if err != nil {
t.Fatalf("%d. couldn't open scenario input file %s: %s", set, s.in, err)
}
options := &ProcessOptions{
Timestamp: test001Time,
}
if err := Processor001.ProcessSingle(reader, s, options); !test.ErrorEqual(s.err, err) {
t.Fatalf("%d. expected err of %s, got %s", set, s.err, err)
}
if len(s.actual) != len(s.expected) {
t.Fatalf("%d. expected output length of %d, got %d", set, len(s.expected), len(s.actual))
}
for i, expected := range s.expected {
sort.Sort(s.actual[i])
sort.Sort(expected)
if !expected.Equal(s.actual[i]) {
t.Errorf("%d.%d. expected %s, got %s", set, i, expected, s.actual[i])
}
}
}
func testProcessor001Process(t testing.TB) {
var scenarios = []testProcessor001ProcessScenario{
{
in: "empty.json",
err: errors.New("unexpected end of JSON input"),
},
{
in: "test0_0_1-0_0_2.json",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{"service": "zed", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"service": "bar", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"service": "foo", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.0459814091918713,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 78.48563317257356,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 15.890724674774395,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.0459814091918713,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 78.48563317257356,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 15.890724674774395,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.6120456642749681,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 97.31798360385088,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 84.63044031436561,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 1.355915069887731,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 109.89202084295582,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 160.21100853053224,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 1.772733213161236,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 109.99626121011262,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 172.49828748957728,
Timestamp: test001Time,
},
},
},
},
}
for i, scenario := range scenarios {
scenario.test(t, i)
}
}
func TestProcessor001Process(t *testing.T) {
testProcessor001Process(t)
}
func BenchmarkProcessor001Process(b *testing.B) {
for i := 0; i < b.N; i++ {
testProcessor001Process(b)
}
}

View File

@ -0,0 +1,106 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"encoding/json"
"fmt"
"io"
"github.com/prometheus/client_golang/model"
)
// Processor002 is responsible for decoding payloads from protocol version
// 0.0.2.
var Processor002 = &processor002{}
type histogram002 struct {
Labels map[string]string `json:"labels"`
Values map[string]model.SampleValue `json:"value"`
}
type counter002 struct {
Labels map[string]string `json:"labels"`
Value model.SampleValue `json:"value"`
}
type processor002 struct{}
func (p *processor002) ProcessSingle(in io.Reader, out Ingester, o *ProcessOptions) error {
// Processor for telemetry schema version 0.0.2.
// container for telemetry data
var entities []struct {
BaseLabels map[string]string `json:"baseLabels"`
Docstring string `json:"docstring"`
Metric struct {
Type string `json:"type"`
Values json.RawMessage `json:"value"`
} `json:"metric"`
}
if err := json.NewDecoder(in).Decode(&entities); err != nil {
return err
}
pendingSamples := model.Samples{}
for _, entity := range entities {
switch entity.Metric.Type {
case "counter", "gauge":
var values []counter002
if err := json.Unmarshal(entity.Metric.Values, &values); err != nil {
return fmt.Errorf("could not extract %s value: %s", entity.Metric.Type, err)
}
for _, counter := range values {
labels := labelSet(entity.BaseLabels).Merge(labelSet(counter.Labels))
pendingSamples = append(pendingSamples, &model.Sample{
Metric: model.Metric(labels),
Timestamp: o.Timestamp,
Value: counter.Value,
})
}
case "histogram":
var values []histogram002
if err := json.Unmarshal(entity.Metric.Values, &values); err != nil {
return fmt.Errorf("could not extract %s value: %s", entity.Metric.Type, err)
}
for _, histogram := range values {
for percentile, value := range histogram.Values {
labels := labelSet(entity.BaseLabels).Merge(labelSet(histogram.Labels))
labels[model.LabelName("percentile")] = model.LabelValue(percentile)
pendingSamples = append(pendingSamples, &model.Sample{
Metric: model.Metric(labels),
Timestamp: o.Timestamp,
Value: value,
})
}
}
default:
return fmt.Errorf("unknown metric type %q", entity.Metric.Type)
}
}
if len(pendingSamples) > 0 {
return out.Ingest(pendingSamples)
}
return nil
}

View File

@ -0,0 +1,226 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"bytes"
"errors"
"io/ioutil"
"os"
"path"
"runtime"
"sort"
"testing"
"github.com/prometheus/prometheus/utility/test"
"github.com/prometheus/client_golang/model"
)
var test002Time = model.Now()
type testProcessor002ProcessScenario struct {
in string
expected, actual []model.Samples
err error
}
func (s *testProcessor002ProcessScenario) Ingest(samples model.Samples) error {
s.actual = append(s.actual, samples)
return nil
}
func (s *testProcessor002ProcessScenario) test(t testing.TB, set int) {
reader, err := os.Open(path.Join("fixtures", s.in))
if err != nil {
t.Fatalf("%d. couldn't open scenario input file %s: %s", set, s.in, err)
}
options := &ProcessOptions{
Timestamp: test002Time,
}
if err := Processor002.ProcessSingle(reader, s, options); !test.ErrorEqual(s.err, err) {
t.Fatalf("%d. expected err of %s, got %s", set, s.err, err)
}
if len(s.actual) != len(s.expected) {
t.Fatalf("%d. expected output length of %d, got %d", set, len(s.expected), len(s.actual))
}
for i, expected := range s.expected {
sort.Sort(s.actual[i])
sort.Sort(expected)
if !expected.Equal(s.actual[i]) {
t.Fatalf("%d.%d. expected %s, got %s", set, i, expected, s.actual[i])
}
}
}
func testProcessor002Process(t testing.TB) {
var scenarios = []testProcessor002ProcessScenario{
{
in: "empty.json",
err: errors.New("EOF"),
},
{
in: "test0_0_1-0_0_2.json",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{"service": "zed", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"service": "bar", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"service": "foo", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.0459814091918713,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 78.48563317257356,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 15.890724674774395,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.0459814091918713,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 78.48563317257356,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 15.890724674774395,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.6120456642749681,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 97.31798360385088,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 84.63044031436561,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 1.355915069887731,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 109.89202084295582,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 160.21100853053224,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 1.772733213161236,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 109.99626121011262,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 172.49828748957728,
Timestamp: test002Time,
},
},
},
},
}
for i, scenario := range scenarios {
scenario.test(t, i)
}
}
func TestProcessor002Process(t *testing.T) {
testProcessor002Process(t)
}
func BenchmarkProcessor002Process(b *testing.B) {
b.StopTimer()
pre := runtime.MemStats{}
runtime.ReadMemStats(&pre)
b.StartTimer()
for i := 0; i < b.N; i++ {
testProcessor002Process(b)
}
post := runtime.MemStats{}
runtime.ReadMemStats(&post)
allocated := post.TotalAlloc - pre.TotalAlloc
b.Logf("Allocated %d at %f per cycle with %d cycles.", allocated, float64(allocated)/float64(b.N), b.N)
}
func BenchmarkProcessor002ParseOnly(b *testing.B) {
b.StopTimer()
data, err := ioutil.ReadFile("fixtures/test0_0_1-0_0_2-large.json")
if err != nil {
b.Fatal(err)
}
ing := fakeIngester{}
b.StartTimer()
for i := 0; i < b.N; i++ {
if err := Processor002.ProcessSingle(bytes.NewReader(data), ing, &ProcessOptions{}); err != nil {
b.Fatal(err)
}
}
}
type fakeIngester struct{}
func (i fakeIngester) Ingest(model.Samples) error {
return nil
}

View File

@ -0,0 +1,40 @@
// Copyright 2014 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"io"
"github.com/prometheus/client_golang/text"
)
type processor004 struct{}
// Processor004 s responsible for decoding payloads from the text based variety
// of protocol version 0.0.4.
var Processor004 = &processor004{}
func (t *processor004) ProcessSingle(i io.Reader, out Ingester, o *ProcessOptions) error {
var parser text.Parser
metricFamilies, err := parser.TextToMetricFamilies(i)
if err != nil {
return err
}
for _, metricFamily := range metricFamilies {
if err := extractMetricFamily(out, o, metricFamily); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,100 @@
// Copyright 2014 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"sort"
"strings"
"testing"
"github.com/prometheus/client_golang/model"
)
var (
ts = model.Now()
in = `
# Only a quite simple scenario with two metric families.
# More complicated tests of the parser itself can be found in the text package.
# TYPE mf2 counter
mf2 3
mf1{label="value1"} -3.14 123456
mf1{label="value2"} 42
mf2 4
`
out = map[model.LabelValue]model.Samples{
"mf1": model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "mf1", "label": "value1"},
Value: -3.14,
Timestamp: 123456,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "mf1", "label": "value2"},
Value: 42,
Timestamp: ts,
},
},
"mf2": model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "mf2"},
Value: 3,
Timestamp: ts,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "mf2"},
Value: 4,
Timestamp: ts,
},
},
}
)
type testIngester struct {
results []model.Samples
}
func (i *testIngester) Ingest(s model.Samples) error {
i.results = append(i.results, s)
return nil
}
func TestTextProcessor(t *testing.T) {
var ingester testIngester
i := strings.NewReader(in)
o := &ProcessOptions{
Timestamp: ts,
}
err := Processor004.ProcessSingle(i, &ingester, o)
if err != nil {
t.Fatal(err)
}
if expected, got := len(out), len(ingester.results); expected != got {
t.Fatalf("Expected length %d, got %d", expected, got)
}
for _, r := range ingester.results {
expected, ok := out[r[0].Metric[model.MetricNameLabel]]
if !ok {
t.Fatalf(
"Unexpected metric name %q",
r[0].Metric[model.MetricNameLabel],
)
}
sort.Sort(expected)
sort.Sort(r)
if !expected.Equal(r) {
t.Errorf("expected %s, got %s", expected, r)
}
}
}

View File

@ -0,0 +1,110 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"fmt"
"strconv"
)
// Fingerprint provides a hash-capable representation of a Metric.
// For our purposes, FNV-1A 64-bit is used.
type Fingerprint uint64
func (f Fingerprint) String() string {
return fmt.Sprintf("%016x", uint64(f))
}
// Less implements sort.Interface.
func (f Fingerprint) Less(o Fingerprint) bool {
return f < o
}
// Equal implements sort.Interface.
func (f Fingerprint) Equal(o Fingerprint) bool {
return f == o
}
// LoadFromString transforms a string representation into a Fingerprint.
func (f *Fingerprint) LoadFromString(s string) error {
num, err := strconv.ParseUint(s, 16, 64)
if err != nil {
return err
}
*f = Fingerprint(num)
return nil
}
// Fingerprints represents a collection of Fingerprint subject to a given
// natural sorting scheme. It implements sort.Interface.
type Fingerprints []Fingerprint
// Len implements sort.Interface.
func (f Fingerprints) Len() int {
return len(f)
}
// Less implements sort.Interface.
func (f Fingerprints) Less(i, j int) bool {
return f[i] < f[j]
}
// Swap implements sort.Interface.
func (f Fingerprints) Swap(i, j int) {
f[i], f[j] = f[j], f[i]
}
// FingerprintSet is a set of Fingerprints.
type FingerprintSet map[Fingerprint]struct{}
// Equal returns true if both sets contain the same elements (and not more).
func (s FingerprintSet) Equal(o FingerprintSet) bool {
if len(s) != len(o) {
return false
}
for k := range s {
if _, ok := o[k]; !ok {
return false
}
}
return true
}
// Intersection returns the elements contained in both sets.
func (s FingerprintSet) Intersection(o FingerprintSet) FingerprintSet {
myLength, otherLength := len(s), len(o)
if myLength == 0 || otherLength == 0 {
return FingerprintSet{}
}
subSet := s
superSet := o
if otherLength < myLength {
subSet = o
superSet = s
}
out := FingerprintSet{}
for k := range subSet {
if _, ok := superSet[k]; ok {
out[k] = struct{}{}
}
}
return out
}

View File

@ -0,0 +1,63 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"strings"
)
const (
// ExporterLabelPrefix is the label name prefix to prepend if a
// synthetic label is already present in the exported metrics.
ExporterLabelPrefix LabelName = "exporter_"
// MetricNameLabel is the label name indicating the metric name of a
// timeseries.
MetricNameLabel LabelName = "__name__"
// ReservedLabelPrefix is a prefix which is not legal in user-supplied
// label names.
ReservedLabelPrefix = "__"
// JobLabel is the label name indicating the job from which a timeseries
// was scraped.
JobLabel LabelName = "job"
)
// A LabelName is a key for a LabelSet or Metric. It has a value associated
// therewith.
type LabelName string
// LabelNames is a sortable LabelName slice. In implements sort.Interface.
type LabelNames []LabelName
func (l LabelNames) Len() int {
return len(l)
}
func (l LabelNames) Less(i, j int) bool {
return l[i] < l[j]
}
func (l LabelNames) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}
func (l LabelNames) String() string {
labelStrings := make([]string, 0, len(l))
for _, label := range l {
labelStrings = append(labelStrings, string(label))
}
return strings.Join(labelStrings, ", ")
}

View File

@ -0,0 +1,55 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"sort"
"testing"
)
func testLabelNames(t testing.TB) {
var scenarios = []struct {
in LabelNames
out LabelNames
}{
{
in: LabelNames{"ZZZ", "zzz"},
out: LabelNames{"ZZZ", "zzz"},
},
{
in: LabelNames{"aaa", "AAA"},
out: LabelNames{"AAA", "aaa"},
},
}
for i, scenario := range scenarios {
sort.Sort(scenario.in)
for j, expected := range scenario.out {
if expected != scenario.in[j] {
t.Errorf("%d.%d expected %s, got %s", i, j, expected, scenario.in[j])
}
}
}
}
func TestLabelNames(t *testing.T) {
testLabelNames(t)
}
func BenchmarkLabelNames(b *testing.B) {
for i := 0; i < b.N; i++ {
testLabelNames(b)
}
}

View File

@ -0,0 +1,64 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"fmt"
"sort"
"strings"
)
// A LabelSet is a collection of LabelName and LabelValue pairs. The LabelSet
// may be fully-qualified down to the point where it may resolve to a single
// Metric in the data store or not. All operations that occur within the realm
// of a LabelSet can emit a vector of Metric entities to which the LabelSet may
// match.
type LabelSet map[LabelName]LabelValue
// Merge is a helper function to non-destructively merge two label sets.
func (l LabelSet) Merge(other LabelSet) LabelSet {
result := make(LabelSet, len(l))
for k, v := range l {
result[k] = v
}
for k, v := range other {
result[k] = v
}
return result
}
func (l LabelSet) String() string {
labelStrings := make([]string, 0, len(l))
for label, value := range l {
labelStrings = append(labelStrings, fmt.Sprintf("%s=%q", label, value))
}
switch len(labelStrings) {
case 0:
return ""
default:
sort.Strings(labelStrings)
return fmt.Sprintf("{%s}", strings.Join(labelStrings, ", "))
}
}
// MergeFromMetric merges Metric into this LabelSet.
func (l LabelSet) MergeFromMetric(m Metric) {
for k, v := range m {
l[k] = v
}
}

View File

@ -0,0 +1,36 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"sort"
)
// A LabelValue is an associated value for a LabelName.
type LabelValue string
// LabelValues is a sortable LabelValue slice. It implements sort.Interface.
type LabelValues []LabelValue
func (l LabelValues) Len() int {
return len(l)
}
func (l LabelValues) Less(i, j int) bool {
return sort.StringsAreSorted([]string{string(l[i]), string(l[j])})
}
func (l LabelValues) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}

View File

@ -0,0 +1,55 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"sort"
"testing"
)
func testLabelValues(t testing.TB) {
var scenarios = []struct {
in LabelValues
out LabelValues
}{
{
in: LabelValues{"ZZZ", "zzz"},
out: LabelValues{"ZZZ", "zzz"},
},
{
in: LabelValues{"aaa", "AAA"},
out: LabelValues{"AAA", "aaa"},
},
}
for i, scenario := range scenarios {
sort.Sort(scenario.in)
for j, expected := range scenario.out {
if expected != scenario.in[j] {
t.Errorf("%d.%d expected %s, got %s", i, j, expected, scenario.in[j])
}
}
}
}
func TestLabelValues(t *testing.T) {
testLabelValues(t)
}
func BenchmarkLabelValues(b *testing.B) {
for i := 0; i < b.N; i++ {
testLabelValues(b)
}
}

Some files were not shown because too many files have changed in this diff Show More