Merge pull request #527 from prometheus/nogodep-build

Copy vendored deps manually instead of using Godeps.
pull/537/head
Julius Volz 2015-02-17 17:17:45 +01:00
commit b0c8b56603
192 changed files with 20695 additions and 1064 deletions

2
.gitignore vendored
View File

@ -26,7 +26,7 @@ _cgo_*
core core
*-stamp *-stamp
prometheus /prometheus
benchmark.txt benchmark.txt
.#* .#*

56
Godeps/Godeps.json generated
View File

@ -1,27 +1,65 @@
{ {
"ImportPath": "github.com/prometheus/prometheus", "ImportPath": "github.com/prometheus/prometheus",
"GoVersion": "go1.4", "GoVersion": "go1.4.1",
"Deps": [ "Deps": [
{
"ImportPath": "code.google.com/p/goprotobuf/proto",
"Comment": "go.r60-152",
"Rev": "36be16571e14f67e114bb0af619e5de2c1591679"
},
{ {
"ImportPath": "github.com/golang/glog", "ImportPath": "github.com/golang/glog",
"Rev": "44145f04b68cf362d9c4df2182967c2275eaefed" "Rev": "44145f04b68cf362d9c4df2182967c2275eaefed"
}, },
{
"ImportPath": "github.com/golang/protobuf/proto",
"Rev": "5677a0e3d5e89854c9974e1256839ee23f8233ca"
},
{ {
"ImportPath": "github.com/matttproud/golang_protobuf_extensions/ext", "ImportPath": "github.com/matttproud/golang_protobuf_extensions/ext",
"Rev": "7a864a042e844af638df17ebbabf8183dace556a" "Rev": "ba7d65ac66e9da93a714ca18f6d1bc7a0c09100c"
}, },
{ {
"ImportPath": "github.com/miekg/dns", "ImportPath": "github.com/miekg/dns",
"Rev": "6b75215519f9916839204d80413bb178b94ef769" "Rev": "b65f52f3f0dd1afa25cbbf63f8e7eb15fb5c0641"
},
{
"ImportPath": "github.com/prometheus/client_golang/_vendor/goautoneg",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/_vendor/perks/quantile",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/extraction",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/model",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/prometheus",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_golang/text",
"Comment": "0.1.0-24-g4627d59",
"Rev": "4627d59e8a09c330c5ccfe7414baca28d8df847d"
},
{
"ImportPath": "github.com/prometheus/client_model/go",
"Comment": "model-0.0.2-12-gfa8ad6f",
"Rev": "fa8ad6fec33561be4280a8f0514318c79d7f6cb6"
},
{
"ImportPath": "github.com/prometheus/procfs",
"Rev": "92faa308558161acab0ada1db048e9996ecec160"
}, },
{ {
"ImportPath": "github.com/syndtr/goleveldb/leveldb", "ImportPath": "github.com/syndtr/goleveldb/leveldb",
"Rev": "63c9e642efad852f49e20a6f90194cae112fd2ac" "Rev": "e9e2c8f6d3b9c313fb4acaac5ab06285bcf30b04"
}, },
{ {
"ImportPath": "github.com/syndtr/gosnappy/snappy", "ImportPath": "github.com/syndtr/gosnappy/snappy",

View File

@ -1,7 +1,7 @@
# Go support for Protocol Buffers - Google's data interchange format # Go support for Protocol Buffers - Google's data interchange format
# #
# Copyright 2010 The Go Authors. All rights reserved. # Copyright 2010 The Go Authors. All rights reserved.
# http://code.google.com/p/goprotobuf/ # https://github.com/golang/protobuf
# #
# Redistribution and use in source and binary forms, with or without # Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are # modification, are permitted provided that the following conditions are
@ -37,4 +37,7 @@ test: install generate-test-pbs
generate-test-pbs: generate-test-pbs:
make install && cd testdata && make make install
make -C testdata
make -C proto3_proto
make

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -45,7 +45,7 @@ import (
"time" "time"
. "./testdata" . "./testdata"
. "code.google.com/p/goprotobuf/proto" . "github.com/golang/protobuf/proto"
) )
var globalO *Buffer var globalO *Buffer
@ -1833,6 +1833,86 @@ func fuzzUnmarshal(t *testing.T, data []byte) {
Unmarshal(data, pb) Unmarshal(data, pb)
} }
func TestMapFieldMarshal(t *testing.T) {
m := &MessageWithMap{
NameMapping: map[int32]string{
1: "Rob",
4: "Ian",
8: "Dave",
},
}
b, err := Marshal(m)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
// b should be the concatenation of these three byte sequences in some order.
parts := []string{
"\n\a\b\x01\x12\x03Rob",
"\n\a\b\x04\x12\x03Ian",
"\n\b\b\x08\x12\x04Dave",
}
ok := false
for i := range parts {
for j := range parts {
if j == i {
continue
}
for k := range parts {
if k == i || k == j {
continue
}
try := parts[i] + parts[j] + parts[k]
if bytes.Equal(b, []byte(try)) {
ok = true
break
}
}
}
}
if !ok {
t.Fatalf("Incorrect Marshal output.\n got %q\nwant %q (or a permutation of that)", b, parts[0]+parts[1]+parts[2])
}
t.Logf("FYI b: %q", b)
(new(Buffer)).DebugPrint("Dump of b", b)
}
func TestMapFieldRoundTrips(t *testing.T) {
m := &MessageWithMap{
NameMapping: map[int32]string{
1: "Rob",
4: "Ian",
8: "Dave",
},
MsgMapping: map[int64]*FloatingPoint{
0x7001: &FloatingPoint{F: Float64(2.0)},
},
ByteMapping: map[bool][]byte{
false: []byte("that's not right!"),
true: []byte("aye, 'tis true!"),
},
}
b, err := Marshal(m)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
t.Logf("FYI b: %q", b)
m2 := new(MessageWithMap)
if err := Unmarshal(b, m2); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
for _, pair := range [][2]interface{}{
{m.NameMapping, m2.NameMapping},
{m.MsgMapping, m2.MsgMapping},
{m.ByteMapping, m2.ByteMapping},
} {
if !reflect.DeepEqual(pair[0], pair[1]) {
t.Errorf("Map did not survive a round trip.\ninitial: %v\n final: %v", pair[0], pair[1])
}
}
}
// Benchmarks // Benchmarks
func testMsg() *GoTest { func testMsg() *GoTest {

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2011 The Go Authors. All rights reserved. // Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer deep copy. // Protocol buffer deep copy and merge.
// TODO: MessageSet and RawMessage. // TODO: MessageSet and RawMessage.
package proto package proto
@ -113,6 +113,29 @@ func mergeAny(out, in reflect.Value) {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64: reflect.String, reflect.Uint32, reflect.Uint64:
out.Set(in) out.Set(in)
case reflect.Map:
if in.Len() == 0 {
return
}
if out.IsNil() {
out.Set(reflect.MakeMap(in.Type()))
}
// For maps with value types of *T or []byte we need to deep copy each value.
elemKind := in.Type().Elem().Kind()
for _, key := range in.MapKeys() {
var val reflect.Value
switch elemKind {
case reflect.Ptr:
val = reflect.New(in.Type().Elem().Elem())
mergeAny(val, in.MapIndex(key))
case reflect.Slice:
val = in.MapIndex(key)
val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
default:
val = in.MapIndex(key)
}
out.SetMapIndex(key, val)
}
case reflect.Ptr: case reflect.Ptr:
if in.IsNil() { if in.IsNil() {
return return
@ -125,6 +148,14 @@ func mergeAny(out, in reflect.Value) {
if in.IsNil() { if in.IsNil() {
return return
} }
if in.Type().Elem().Kind() == reflect.Uint8 {
// []byte is a scalar bytes field, not a repeated field.
// Make a deep copy.
// Append to []byte{} instead of []byte(nil) so that we never end up
// with a nil result.
out.SetBytes(append([]byte{}, in.Bytes()...))
return
}
n := in.Len() n := in.Len()
if out.IsNil() { if out.IsNil() {
out.Set(reflect.MakeSlice(in.Type(), 0, n)) out.Set(reflect.MakeSlice(in.Type(), 0, n))
@ -133,9 +164,6 @@ func mergeAny(out, in reflect.Value) {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64: reflect.String, reflect.Uint32, reflect.Uint64:
out.Set(reflect.AppendSlice(out, in)) out.Set(reflect.AppendSlice(out, in))
case reflect.Uint8:
// []byte is a scalar bytes field.
out.Set(in)
default: default:
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
x := reflect.Indirect(reflect.New(in.Type().Elem())) x := reflect.Indirect(reflect.New(in.Type().Elem()))

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2011 The Go Authors. All rights reserved. // Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -34,7 +34,7 @@ package proto_test
import ( import (
"testing" "testing"
"code.google.com/p/goprotobuf/proto" "github.com/golang/protobuf/proto"
pb "./testdata" pb "./testdata"
) )
@ -79,6 +79,22 @@ func TestClone(t *testing.T) {
if proto.Equal(m, cloneTestMessage) { if proto.Equal(m, cloneTestMessage) {
t.Error("Mutating clone changed the original") t.Error("Mutating clone changed the original")
} }
// Byte fields and repeated fields should be copied.
if &m.Pet[0] == &cloneTestMessage.Pet[0] {
t.Error("Pet: repeated field not copied")
}
if &m.Others[0] == &cloneTestMessage.Others[0] {
t.Error("Others: repeated field not copied")
}
if &m.Others[0].Value[0] == &cloneTestMessage.Others[0].Value[0] {
t.Error("Others[0].Value: bytes field not copied")
}
if &m.RepBytes[0] == &cloneTestMessage.RepBytes[0] {
t.Error("RepBytes: repeated field not copied")
}
if &m.RepBytes[0][0] == &cloneTestMessage.RepBytes[0][0] {
t.Error("RepBytes[0]: bytes field not copied")
}
} }
func TestCloneNil(t *testing.T) { func TestCloneNil(t *testing.T) {
@ -173,6 +189,31 @@ var mergeTests = []struct {
dst: &pb.OtherMessage{Value: []byte("bar")}, dst: &pb.OtherMessage{Value: []byte("bar")},
want: &pb.OtherMessage{Value: []byte("foo")}, want: &pb.OtherMessage{Value: []byte("foo")},
}, },
{
src: &pb.MessageWithMap{
NameMapping: map[int32]string{6: "Nigel"},
MsgMapping: map[int64]*pb.FloatingPoint{
0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
},
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
},
dst: &pb.MessageWithMap{
NameMapping: map[int32]string{
6: "Bruce", // should be overwritten
7: "Andrew",
},
},
want: &pb.MessageWithMap{
NameMapping: map[int32]string{
6: "Nigel",
7: "Andrew",
},
MsgMapping: map[int64]*pb.FloatingPoint{
0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
},
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
},
},
} }
func TestMerge(t *testing.T) { func TestMerge(t *testing.T) {

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -178,7 +178,7 @@ func (p *Buffer) DecodeZigzag32() (x uint64, err error) {
func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) { func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
n, err := p.DecodeVarint() n, err := p.DecodeVarint()
if err != nil { if err != nil {
return return nil, err
} }
nb := int(n) nb := int(n)
@ -362,7 +362,7 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
} }
tag := int(u >> 3) tag := int(u >> 3)
if tag <= 0 { if tag <= 0 {
return fmt.Errorf("proto: %s: illegal tag %d", st, tag) return fmt.Errorf("proto: %s: illegal tag %d (wire type %d)", st, tag, wire)
} }
fieldnum, ok := prop.decoderTags.get(tag) fieldnum, ok := prop.decoderTags.get(tag)
if !ok { if !ok {
@ -465,6 +465,15 @@ func (o *Buffer) dec_bool(p *Properties, base structPointer) error {
return nil return nil
} }
func (o *Buffer) dec_proto3_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
*structPointer_BoolVal(base, p.field) = u != 0
return nil
}
// Decode an int32. // Decode an int32.
func (o *Buffer) dec_int32(p *Properties, base structPointer) error { func (o *Buffer) dec_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o) u, err := p.valDec(o)
@ -475,6 +484,15 @@ func (o *Buffer) dec_int32(p *Properties, base structPointer) error {
return nil return nil
} }
func (o *Buffer) dec_proto3_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word32Val_Set(structPointer_Word32Val(base, p.field), uint32(u))
return nil
}
// Decode an int64. // Decode an int64.
func (o *Buffer) dec_int64(p *Properties, base structPointer) error { func (o *Buffer) dec_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o) u, err := p.valDec(o)
@ -485,15 +503,31 @@ func (o *Buffer) dec_int64(p *Properties, base structPointer) error {
return nil return nil
} }
func (o *Buffer) dec_proto3_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word64Val_Set(structPointer_Word64Val(base, p.field), o, u)
return nil
}
// Decode a string. // Decode a string.
func (o *Buffer) dec_string(p *Properties, base structPointer) error { func (o *Buffer) dec_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes() s, err := o.DecodeStringBytes()
if err != nil { if err != nil {
return err return err
} }
sp := new(string) *structPointer_String(base, p.field) = &s
*sp = s return nil
*structPointer_String(base, p.field) = sp }
func (o *Buffer) dec_proto3_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
*structPointer_StringVal(base, p.field) = s
return nil return nil
} }
@ -632,6 +666,72 @@ func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error {
return nil return nil
} }
// Decode a map field.
func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
raw, err := o.DecodeRawBytes(false)
if err != nil {
return err
}
oi := o.index // index at the end of this map entry
o.index -= len(raw) // move buffer back to start of map entry
mptr := structPointer_Map(base, p.field, p.mtype) // *map[K]V
if mptr.Elem().IsNil() {
mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem()))
}
v := mptr.Elem() // map[K]V
// Prepare addressable doubly-indirect placeholders for the key and value types.
// See enc_new_map for why.
keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K
keybase := toStructPointer(keyptr.Addr()) // **K
var valbase structPointer
var valptr reflect.Value
switch p.mtype.Elem().Kind() {
case reflect.Slice:
// []byte
var dummy []byte
valptr = reflect.ValueOf(&dummy) // *[]byte
valbase = toStructPointer(valptr) // *[]byte
case reflect.Ptr:
// message; valptr is **Msg; need to allocate the intermediate pointer
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
valptr.Set(reflect.New(valptr.Type().Elem()))
valbase = toStructPointer(valptr)
default:
// everything else
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
valbase = toStructPointer(valptr.Addr()) // **V
}
// Decode.
// This parses a restricted wire format, namely the encoding of a message
// with two fields. See enc_new_map for the format.
for o.index < oi {
// tagcode for key and value properties are always a single byte
// because they have tags 1 and 2.
tagcode := o.buf[o.index]
o.index++
switch tagcode {
case p.mkeyprop.tagcode[0]:
if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil {
return err
}
case p.mvalprop.tagcode[0]:
if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil {
return err
}
default:
// TODO: Should we silently skip this instead?
return fmt.Errorf("proto: bad map data tag %d", raw[0])
}
}
v.SetMapIndex(keyptr.Elem(), valptr.Elem())
return nil
}
// Decode a group. // Decode a group.
func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error { func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error {
bas := structPointer_GetStructPointer(base, p.field) bas := structPointer_GetStructPointer(base, p.field)

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -247,7 +247,7 @@ func (p *Buffer) Marshal(pb Message) error {
return ErrNil return ErrNil
} }
if err == nil { if err == nil {
err = p.enc_struct(t.Elem(), GetProperties(t.Elem()), base) err = p.enc_struct(GetProperties(t.Elem()), base)
} }
if collectStats { if collectStats {
@ -271,7 +271,7 @@ func Size(pb Message) (n int) {
return 0 return 0
} }
if err == nil { if err == nil {
n = size_struct(t.Elem(), GetProperties(t.Elem()), base) n = size_struct(GetProperties(t.Elem()), base)
} }
if collectStats { if collectStats {
@ -298,6 +298,16 @@ func (o *Buffer) enc_bool(p *Properties, base structPointer) error {
return nil return nil
} }
func (o *Buffer) enc_proto3_bool(p *Properties, base structPointer) error {
v := *structPointer_BoolVal(base, p.field)
if !v {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
p.valEnc(o, 1)
return nil
}
func size_bool(p *Properties, base structPointer) int { func size_bool(p *Properties, base structPointer) int {
v := *structPointer_Bool(base, p.field) v := *structPointer_Bool(base, p.field)
if v == nil { if v == nil {
@ -306,6 +316,14 @@ func size_bool(p *Properties, base structPointer) int {
return len(p.tagcode) + 1 // each bool takes exactly one byte return len(p.tagcode) + 1 // each bool takes exactly one byte
} }
func size_proto3_bool(p *Properties, base structPointer) int {
v := *structPointer_BoolVal(base, p.field)
if !v {
return 0
}
return len(p.tagcode) + 1 // each bool takes exactly one byte
}
// Encode an int32. // Encode an int32.
func (o *Buffer) enc_int32(p *Properties, base structPointer) error { func (o *Buffer) enc_int32(p *Properties, base structPointer) error {
v := structPointer_Word32(base, p.field) v := structPointer_Word32(base, p.field)
@ -318,6 +336,17 @@ func (o *Buffer) enc_int32(p *Properties, base structPointer) error {
return nil return nil
} }
func (o *Buffer) enc_proto3_int32(p *Properties, base structPointer) error {
v := structPointer_Word32Val(base, p.field)
x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range
if x == 0 {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
p.valEnc(o, uint64(x))
return nil
}
func size_int32(p *Properties, base structPointer) (n int) { func size_int32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32(base, p.field) v := structPointer_Word32(base, p.field)
if word32_IsNil(v) { if word32_IsNil(v) {
@ -329,6 +358,17 @@ func size_int32(p *Properties, base structPointer) (n int) {
return return
} }
func size_proto3_int32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32Val(base, p.field)
x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range
if x == 0 {
return 0
}
n += len(p.tagcode)
n += p.valSize(uint64(x))
return
}
// Encode a uint32. // Encode a uint32.
// Exactly the same as int32, except for no sign extension. // Exactly the same as int32, except for no sign extension.
func (o *Buffer) enc_uint32(p *Properties, base structPointer) error { func (o *Buffer) enc_uint32(p *Properties, base structPointer) error {
@ -342,6 +382,17 @@ func (o *Buffer) enc_uint32(p *Properties, base structPointer) error {
return nil return nil
} }
func (o *Buffer) enc_proto3_uint32(p *Properties, base structPointer) error {
v := structPointer_Word32Val(base, p.field)
x := word32Val_Get(v)
if x == 0 {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
p.valEnc(o, uint64(x))
return nil
}
func size_uint32(p *Properties, base structPointer) (n int) { func size_uint32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32(base, p.field) v := structPointer_Word32(base, p.field)
if word32_IsNil(v) { if word32_IsNil(v) {
@ -353,6 +404,17 @@ func size_uint32(p *Properties, base structPointer) (n int) {
return return
} }
func size_proto3_uint32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32Val(base, p.field)
x := word32Val_Get(v)
if x == 0 {
return 0
}
n += len(p.tagcode)
n += p.valSize(uint64(x))
return
}
// Encode an int64. // Encode an int64.
func (o *Buffer) enc_int64(p *Properties, base structPointer) error { func (o *Buffer) enc_int64(p *Properties, base structPointer) error {
v := structPointer_Word64(base, p.field) v := structPointer_Word64(base, p.field)
@ -365,6 +427,17 @@ func (o *Buffer) enc_int64(p *Properties, base structPointer) error {
return nil return nil
} }
func (o *Buffer) enc_proto3_int64(p *Properties, base structPointer) error {
v := structPointer_Word64Val(base, p.field)
x := word64Val_Get(v)
if x == 0 {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
p.valEnc(o, x)
return nil
}
func size_int64(p *Properties, base structPointer) (n int) { func size_int64(p *Properties, base structPointer) (n int) {
v := structPointer_Word64(base, p.field) v := structPointer_Word64(base, p.field)
if word64_IsNil(v) { if word64_IsNil(v) {
@ -376,6 +449,17 @@ func size_int64(p *Properties, base structPointer) (n int) {
return return
} }
func size_proto3_int64(p *Properties, base structPointer) (n int) {
v := structPointer_Word64Val(base, p.field)
x := word64Val_Get(v)
if x == 0 {
return 0
}
n += len(p.tagcode)
n += p.valSize(x)
return
}
// Encode a string. // Encode a string.
func (o *Buffer) enc_string(p *Properties, base structPointer) error { func (o *Buffer) enc_string(p *Properties, base structPointer) error {
v := *structPointer_String(base, p.field) v := *structPointer_String(base, p.field)
@ -388,6 +472,16 @@ func (o *Buffer) enc_string(p *Properties, base structPointer) error {
return nil return nil
} }
func (o *Buffer) enc_proto3_string(p *Properties, base structPointer) error {
v := *structPointer_StringVal(base, p.field)
if v == "" {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
o.EncodeStringBytes(v)
return nil
}
func size_string(p *Properties, base structPointer) (n int) { func size_string(p *Properties, base structPointer) (n int) {
v := *structPointer_String(base, p.field) v := *structPointer_String(base, p.field)
if v == nil { if v == nil {
@ -399,6 +493,16 @@ func size_string(p *Properties, base structPointer) (n int) {
return return
} }
func size_proto3_string(p *Properties, base structPointer) (n int) {
v := *structPointer_StringVal(base, p.field)
if v == "" {
return 0
}
n += len(p.tagcode)
n += sizeStringBytes(v)
return
}
// All protocol buffer fields are nillable, but be careful. // All protocol buffer fields are nillable, but be careful.
func isNil(v reflect.Value) bool { func isNil(v reflect.Value) bool {
switch v.Kind() { switch v.Kind() {
@ -429,7 +533,7 @@ func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error {
} }
o.buf = append(o.buf, p.tagcode...) o.buf = append(o.buf, p.tagcode...)
return o.enc_len_struct(p.stype, p.sprop, structp, &state) return o.enc_len_struct(p.sprop, structp, &state)
} }
func size_struct_message(p *Properties, base structPointer) int { func size_struct_message(p *Properties, base structPointer) int {
@ -448,7 +552,7 @@ func size_struct_message(p *Properties, base structPointer) int {
} }
n0 := len(p.tagcode) n0 := len(p.tagcode)
n1 := size_struct(p.stype, p.sprop, structp) n1 := size_struct(p.sprop, structp)
n2 := sizeVarint(uint64(n1)) // size of encoded length n2 := sizeVarint(uint64(n1)) // size of encoded length
return n0 + n1 + n2 return n0 + n1 + n2
} }
@ -462,7 +566,7 @@ func (o *Buffer) enc_struct_group(p *Properties, base structPointer) error {
} }
o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup)) o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
err := o.enc_struct(p.stype, p.sprop, b) err := o.enc_struct(p.sprop, b)
if err != nil && !state.shouldContinue(err, nil) { if err != nil && !state.shouldContinue(err, nil) {
return err return err
} }
@ -477,7 +581,7 @@ func size_struct_group(p *Properties, base structPointer) (n int) {
} }
n += sizeVarint(uint64((p.Tag << 3) | WireStartGroup)) n += sizeVarint(uint64((p.Tag << 3) | WireStartGroup))
n += size_struct(p.stype, p.sprop, b) n += size_struct(p.sprop, b)
n += sizeVarint(uint64((p.Tag << 3) | WireEndGroup)) n += sizeVarint(uint64((p.Tag << 3) | WireEndGroup))
return return
} }
@ -551,6 +655,16 @@ func (o *Buffer) enc_slice_byte(p *Properties, base structPointer) error {
return nil return nil
} }
func (o *Buffer) enc_proto3_slice_byte(p *Properties, base structPointer) error {
s := *structPointer_Bytes(base, p.field)
if len(s) == 0 {
return ErrNil
}
o.buf = append(o.buf, p.tagcode...)
o.EncodeRawBytes(s)
return nil
}
func size_slice_byte(p *Properties, base structPointer) (n int) { func size_slice_byte(p *Properties, base structPointer) (n int) {
s := *structPointer_Bytes(base, p.field) s := *structPointer_Bytes(base, p.field)
if s == nil { if s == nil {
@ -561,6 +675,16 @@ func size_slice_byte(p *Properties, base structPointer) (n int) {
return return
} }
func size_proto3_slice_byte(p *Properties, base structPointer) (n int) {
s := *structPointer_Bytes(base, p.field)
if len(s) == 0 {
return 0
}
n += len(p.tagcode)
n += sizeRawBytes(s)
return
}
// Encode a slice of int32s ([]int32). // Encode a slice of int32s ([]int32).
func (o *Buffer) enc_slice_int32(p *Properties, base structPointer) error { func (o *Buffer) enc_slice_int32(p *Properties, base structPointer) error {
s := structPointer_Word32Slice(base, p.field) s := structPointer_Word32Slice(base, p.field)
@ -831,7 +955,7 @@ func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) err
} }
o.buf = append(o.buf, p.tagcode...) o.buf = append(o.buf, p.tagcode...)
err := o.enc_len_struct(p.stype, p.sprop, structp, &state) err := o.enc_len_struct(p.sprop, structp, &state)
if err != nil && !state.shouldContinue(err, nil) { if err != nil && !state.shouldContinue(err, nil) {
if err == ErrNil { if err == ErrNil {
return ErrRepeatedHasNil return ErrRepeatedHasNil
@ -861,7 +985,7 @@ func size_slice_struct_message(p *Properties, base structPointer) (n int) {
continue continue
} }
n0 := size_struct(p.stype, p.sprop, structp) n0 := size_struct(p.sprop, structp)
n1 := sizeVarint(uint64(n0)) // size of encoded length n1 := sizeVarint(uint64(n0)) // size of encoded length
n += n0 + n1 n += n0 + n1
} }
@ -882,7 +1006,7 @@ func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error
o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup)) o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
err := o.enc_struct(p.stype, p.sprop, b) err := o.enc_struct(p.sprop, b)
if err != nil && !state.shouldContinue(err, nil) { if err != nil && !state.shouldContinue(err, nil) {
if err == ErrNil { if err == ErrNil {
@ -908,7 +1032,7 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) {
return // return size up to this point return // return size up to this point
} }
n += size_struct(p.stype, p.sprop, b) n += size_struct(p.sprop, b)
} }
return return
} }
@ -945,12 +1069,112 @@ func size_map(p *Properties, base structPointer) int {
return sizeExtensionMap(v) return sizeExtensionMap(v)
} }
// Encode a map field.
func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
var state errorState // XXX: or do we need to plumb this through?
/*
A map defined as
map<key_type, value_type> map_field = N;
is encoded in the same way as
message MapFieldEntry {
key_type key = 1;
value_type value = 2;
}
repeated MapFieldEntry map_field = N;
*/
v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
if v.Len() == 0 {
return nil
}
keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype)
enc := func() error {
if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil {
return err
}
if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil {
return err
}
return nil
}
keys := v.MapKeys()
sort.Sort(mapKeys(keys))
for _, key := range keys {
val := v.MapIndex(key)
keycopy.Set(key)
valcopy.Set(val)
o.buf = append(o.buf, p.tagcode...)
if err := o.enc_len_thing(enc, &state); err != nil {
return err
}
}
return nil
}
func size_new_map(p *Properties, base structPointer) int {
v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype)
n := 0
for _, key := range v.MapKeys() {
val := v.MapIndex(key)
keycopy.Set(key)
valcopy.Set(val)
// Tag codes are two bytes per map entry.
n += 2
n += p.mkeyprop.size(p.mkeyprop, keybase)
n += p.mvalprop.size(p.mvalprop, valbase)
}
return n
}
// mapEncodeScratch returns a new reflect.Value matching the map's value type,
// and a structPointer suitable for passing to an encoder or sizer.
func mapEncodeScratch(mapType reflect.Type) (keycopy, valcopy reflect.Value, keybase, valbase structPointer) {
// Prepare addressable doubly-indirect placeholders for the key and value types.
// This is needed because the element-type encoders expect **T, but the map iteration produces T.
keycopy = reflect.New(mapType.Key()).Elem() // addressable K
keyptr := reflect.New(reflect.PtrTo(keycopy.Type())).Elem() // addressable *K
keyptr.Set(keycopy.Addr()) //
keybase = toStructPointer(keyptr.Addr()) // **K
// Value types are more varied and require special handling.
switch mapType.Elem().Kind() {
case reflect.Slice:
// []byte
var dummy []byte
valcopy = reflect.ValueOf(&dummy).Elem() // addressable []byte
valbase = toStructPointer(valcopy.Addr())
case reflect.Ptr:
// message; the generated field type is map[K]*Msg (so V is *Msg),
// so we only need one level of indirection.
valcopy = reflect.New(mapType.Elem()).Elem() // addressable V
valbase = toStructPointer(valcopy.Addr())
default:
// everything else
valcopy = reflect.New(mapType.Elem()).Elem() // addressable V
valptr := reflect.New(reflect.PtrTo(valcopy.Type())).Elem() // addressable *V
valptr.Set(valcopy.Addr()) //
valbase = toStructPointer(valptr.Addr()) // **V
}
return
}
// Encode a struct. // Encode a struct.
func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structPointer) error { func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
var state errorState var state errorState
// Encode fields in tag order so that decoders may use optimizations // Encode fields in tag order so that decoders may use optimizations
// that depend on the ordering. // that depend on the ordering.
// http://code.google.com/apis/protocolbuffers/docs/encoding.html#order // https://developers.google.com/protocol-buffers/docs/encoding#order
for _, i := range prop.order { for _, i := range prop.order {
p := prop.Prop[i] p := prop.Prop[i]
if p.enc != nil { if p.enc != nil {
@ -978,7 +1202,7 @@ func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structP
return state.err return state.err
} }
func size_struct(t reflect.Type, prop *StructProperties, base structPointer) (n int) { func size_struct(prop *StructProperties, base structPointer) (n int) {
for _, i := range prop.order { for _, i := range prop.order {
p := prop.Prop[i] p := prop.Prop[i]
if p.size != nil { if p.size != nil {
@ -998,11 +1222,16 @@ func size_struct(t reflect.Type, prop *StructProperties, base structPointer) (n
var zeroes [20]byte // longer than any conceivable sizeVarint var zeroes [20]byte // longer than any conceivable sizeVarint
// Encode a struct, preceded by its encoded length (as a varint). // Encode a struct, preceded by its encoded length (as a varint).
func (o *Buffer) enc_len_struct(t reflect.Type, prop *StructProperties, base structPointer, state *errorState) error { func (o *Buffer) enc_len_struct(prop *StructProperties, base structPointer, state *errorState) error {
return o.enc_len_thing(func() error { return o.enc_struct(prop, base) }, state)
}
// Encode something, preceded by its encoded length (as a varint).
func (o *Buffer) enc_len_thing(enc func() error, state *errorState) error {
iLen := len(o.buf) iLen := len(o.buf)
o.buf = append(o.buf, 0, 0, 0, 0) // reserve four bytes for length o.buf = append(o.buf, 0, 0, 0, 0) // reserve four bytes for length
iMsg := len(o.buf) iMsg := len(o.buf)
err := o.enc_struct(t, prop, base) err := enc()
if err != nil && !state.shouldContinue(err, nil) { if err != nil && !state.shouldContinue(err, nil) {
return err return err
} }

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2011 The Go Authors. All rights reserved. // Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -57,7 +57,7 @@ Equality is defined in this way:
although represented by []byte, is not a repeated field) although represented by []byte, is not a repeated field)
- Two unset fields are equal. - Two unset fields are equal.
- Two unknown field sets are equal if their current - Two unknown field sets are equal if their current
encoded state is equal. (TODO) encoded state is equal.
- Two extension sets are equal iff they have corresponding - Two extension sets are equal iff they have corresponding
elements that are pairwise equal. elements that are pairwise equal.
- Every other combination of things are not equal. - Every other combination of things are not equal.
@ -154,6 +154,21 @@ func equalAny(v1, v2 reflect.Value) bool {
return v1.Float() == v2.Float() return v1.Float() == v2.Float()
case reflect.Int32, reflect.Int64: case reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int() return v1.Int() == v2.Int()
case reflect.Map:
if v1.Len() != v2.Len() {
return false
}
for _, key := range v1.MapKeys() {
val2 := v2.MapIndex(key)
if !val2.IsValid() {
// This key was not found in the second map.
return false
}
if !equalAny(v1.MapIndex(key), val2) {
return false
}
}
return true
case reflect.Ptr: case reflect.Ptr:
return equalAny(v1.Elem(), v2.Elem()) return equalAny(v1.Elem(), v2.Elem())
case reflect.Slice: case reflect.Slice:

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2011 The Go Authors. All rights reserved. // Copyright 2011 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -35,7 +35,7 @@ import (
"testing" "testing"
pb "./testdata" pb "./testdata"
. "code.google.com/p/goprotobuf/proto" . "github.com/golang/protobuf/proto"
) )
// Four identical base messages. // Four identical base messages.
@ -155,6 +155,31 @@ var EqualTests = []struct {
}, },
true, true,
}, },
{
"map same",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
true,
},
{
"map different entry",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{2: "Rob"}},
false,
},
{
"map different key only",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{2: "Ken"}},
false,
},
{
"map different value only",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}},
false,
},
} }
func TestEqual(t *testing.T) { func TestEqual(t *testing.T) {

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -227,7 +227,8 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, er
return nil, err return nil, err
} }
e, ok := pb.ExtensionMap()[extension.Field] emap := pb.ExtensionMap()
e, ok := emap[extension.Field]
if !ok { if !ok {
return nil, ErrMissingExtension return nil, ErrMissingExtension
} }
@ -252,6 +253,7 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, er
e.value = v e.value = v
e.desc = extension e.desc = extension
e.enc = nil e.enc = nil
emap[extension.Field] = e
return e.value, nil return e.value, nil
} }

View File

@ -0,0 +1,137 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"testing"
pb "./testdata"
"github.com/golang/protobuf/proto"
)
func TestGetExtensionsWithMissingExtensions(t *testing.T) {
msg := &pb.MyMessage{}
ext1 := &pb.Ext{}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
t.Fatalf("Could not set ext1: %s", ext1)
}
exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
pb.E_Ext_More,
pb.E_Ext_Text,
})
if err != nil {
t.Fatalf("GetExtensions() failed: %s", err)
}
if exts[0] != ext1 {
t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
}
if exts[1] != nil {
t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
}
}
func TestGetExtensionStability(t *testing.T) {
check := func(m *pb.MyMessage) bool {
ext1, err := proto.GetExtension(m, pb.E_Ext_More)
if err != nil {
t.Fatalf("GetExtension() failed: %s", err)
}
ext2, err := proto.GetExtension(m, pb.E_Ext_More)
if err != nil {
t.Fatalf("GetExtension() failed: %s", err)
}
return ext1 == ext2
}
msg := &pb.MyMessage{Count: proto.Int32(4)}
ext0 := &pb.Ext{}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
t.Fatalf("Could not set ext1: %s", ext0)
}
if !check(msg) {
t.Errorf("GetExtension() not stable before marshaling")
}
bb, err := proto.Marshal(msg)
if err != nil {
t.Fatalf("Marshal() failed: %s", err)
}
msg1 := &pb.MyMessage{}
err = proto.Unmarshal(bb, msg1)
if err != nil {
t.Fatalf("Unmarshal() failed: %s", err)
}
if !check(msg1) {
t.Errorf("GetExtension() not stable after unmarshaling")
}
}
func TestExtensionsRoundTrip(t *testing.T) {
msg := &pb.MyMessage{}
ext1 := &pb.Ext{
Data: proto.String("hi"),
}
ext2 := &pb.Ext{
Data: proto.String("there"),
}
exists := proto.HasExtension(msg, pb.E_Ext_More)
if exists {
t.Error("Extension More present unexpectedly")
}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
t.Error(err)
}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
t.Error(err)
}
e, err := proto.GetExtension(msg, pb.E_Ext_More)
if err != nil {
t.Error(err)
}
x, ok := e.(*pb.Ext)
if !ok {
t.Errorf("e has type %T, expected testdata.Ext", e)
} else if *x.Data != "there" {
t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
}
proto.ClearExtension(msg, pb.E_Ext_More)
if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
t.Errorf("got %v, expected ErrMissingExtension", e)
}
if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
t.Error("expected bad extension error, got nil")
}
if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
t.Error("expected extension err")
}
if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
t.Error("expected some sort of type mismatch error, got nil")
}
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -39,7 +39,7 @@
- Names are turned from camel_case to CamelCase for export. - Names are turned from camel_case to CamelCase for export.
- There are no methods on v to set fields; just treat - There are no methods on v to set fields; just treat
them as structure fields. them as structure fields.
- There are getters that return a field's value if set, - There are getters that return a field's value if set,
and return the field's default value if unset. and return the field's default value if unset.
The getters work even if the receiver is a nil message. The getters work even if the receiver is a nil message.
@ -50,17 +50,16 @@
That is, optional or required field int32 f becomes F *int32. That is, optional or required field int32 f becomes F *int32.
- Repeated fields are slices. - Repeated fields are slices.
- Helper functions are available to aid the setting of fields. - Helper functions are available to aid the setting of fields.
Helpers for getting values are superseded by the msg.Foo = proto.String("hello") // set field
GetFoo methods and their use is deprecated.
msg.Foo = proto.String("hello") // set field
- Constants are defined to hold the default values of all fields that - Constants are defined to hold the default values of all fields that
have them. They have the form Default_StructName_FieldName. have them. They have the form Default_StructName_FieldName.
Because the getter methods handle defaulted values, Because the getter methods handle defaulted values,
direct use of these constants should be rare. direct use of these constants should be rare.
- Enums are given type names and maps from names to values. - Enums are given type names and maps from names to values.
Enum values are prefixed with the enum's type name. Enum types have Enum values are prefixed by the enclosing message's name, or by the
a String method, and a Enum method to assist in message construction. enum's type name if it is a top-level enum. Enum types have a String
- Nested groups and enums have type names prefixed with the name of method, and a Enum method to assist in message construction.
- Nested messages, groups and enums have type names prefixed with the name of
the surrounding message type. the surrounding message type.
- Extensions are given descriptor names that start with E_, - Extensions are given descriptor names that start with E_,
followed by an underscore-delimited list of the nested messages followed by an underscore-delimited list of the nested messages
@ -74,7 +73,7 @@
package example; package example;
enum FOO { X = 17; }; enum FOO { X = 17; }
message Test { message Test {
required string label = 1; required string label = 1;
@ -89,7 +88,8 @@
package example package example
import "code.google.com/p/goprotobuf/proto" import proto "github.com/golang/protobuf/proto"
import math "math"
type FOO int32 type FOO int32
const ( const (
@ -110,6 +110,14 @@
func (x FOO) String() string { func (x FOO) String() string {
return proto.EnumName(FOO_name, int32(x)) return proto.EnumName(FOO_name, int32(x))
} }
func (x *FOO) UnmarshalJSON(data []byte) error {
value, err := proto.UnmarshalJSONEnum(FOO_value, data)
if err != nil {
return err
}
*x = FOO(value)
return nil
}
type Test struct { type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"` Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
@ -118,41 +126,41 @@
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
XXX_unrecognized []byte `json:"-"` XXX_unrecognized []byte `json:"-"`
} }
func (this *Test) Reset() { *this = Test{} } func (m *Test) Reset() { *m = Test{} }
func (this *Test) String() string { return proto.CompactTextString(this) } func (m *Test) String() string { return proto.CompactTextString(m) }
func (*Test) ProtoMessage() {}
const Default_Test_Type int32 = 77 const Default_Test_Type int32 = 77
func (this *Test) GetLabel() string { func (m *Test) GetLabel() string {
if this != nil && this.Label != nil { if m != nil && m.Label != nil {
return *this.Label return *m.Label
} }
return "" return ""
} }
func (this *Test) GetType() int32 { func (m *Test) GetType() int32 {
if this != nil && this.Type != nil { if m != nil && m.Type != nil {
return *this.Type return *m.Type
} }
return Default_Test_Type return Default_Test_Type
} }
func (this *Test) GetOptionalgroup() *Test_OptionalGroup { func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
if this != nil { if m != nil {
return this.Optionalgroup return m.Optionalgroup
} }
return nil return nil
} }
type Test_OptionalGroup struct { type Test_OptionalGroup struct {
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"` RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
XXX_unrecognized []byte `json:"-"`
} }
func (this *Test_OptionalGroup) Reset() { *this = Test_OptionalGroup{} } func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
func (this *Test_OptionalGroup) String() string { return proto.CompactTextString(this) } func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
func (this *Test_OptionalGroup) GetRequiredField() string { func (m *Test_OptionalGroup) GetRequiredField() string {
if this != nil && this.RequiredField != nil { if m != nil && m.RequiredField != nil {
return *this.RequiredField return *m.RequiredField
} }
return "" return ""
} }
@ -168,15 +176,15 @@
import ( import (
"log" "log"
"code.google.com/p/goprotobuf/proto" "github.com/golang/protobuf/proto"
"./example.pb" pb "./example.pb"
) )
func main() { func main() {
test := &example.Test{ test := &pb.Test{
Label: proto.String("hello"), Label: proto.String("hello"),
Type: proto.Int32(17), Type: proto.Int32(17),
Optionalgroup: &example.Test_OptionalGroup{ Optionalgroup: &pb.Test_OptionalGroup{
RequiredField: proto.String("good bye"), RequiredField: proto.String("good bye"),
}, },
} }
@ -184,7 +192,7 @@
if err != nil { if err != nil {
log.Fatal("marshaling error: ", err) log.Fatal("marshaling error: ", err)
} }
newTest := new(example.Test) newTest := &pb.Test{}
err = proto.Unmarshal(data, newTest) err = proto.Unmarshal(data, newTest)
if err != nil { if err != nil {
log.Fatal("unmarshaling error: ", err) log.Fatal("unmarshaling error: ", err)
@ -323,9 +331,7 @@ func Float64(v float64) *float64 {
// Uint32 is a helper routine that allocates a new uint32 value // Uint32 is a helper routine that allocates a new uint32 value
// to store v and returns a pointer to it. // to store v and returns a pointer to it.
func Uint32(v uint32) *uint32 { func Uint32(v uint32) *uint32 {
p := new(uint32) return &v
*p = v
return p
} }
// Uint64 is a helper routine that allocates a new uint64 value // Uint64 is a helper routine that allocates a new uint64 value
@ -738,3 +744,16 @@ func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
return dm return dm
} }
// Map fields may have key types of non-float scalars, strings and enums.
// The easiest way to sort them in some deterministic order is to use fmt.
// If this turns out to be inefficient we can always consider other options,
// such as doing a Schwartzian transform.
type mapKeys []reflect.Value
func (s mapKeys) Len() int { return len(s) }
func (s mapKeys) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s mapKeys) Less(i, j int) bool {
return fmt.Sprint(s[i].Interface()) < fmt.Sprint(s[j].Interface())
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -36,7 +36,10 @@ package proto
*/ */
import ( import (
"bytes"
"encoding/json"
"errors" "errors"
"fmt"
"reflect" "reflect"
"sort" "sort"
) )
@ -211,6 +214,61 @@ func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
return nil return nil
} }
// MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
// It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) {
var b bytes.Buffer
b.WriteByte('{')
// Process the map in key order for deterministic output.
ids := make([]int32, 0, len(m))
for id := range m {
ids = append(ids, id)
}
sort.Sort(int32Slice(ids)) // int32Slice defined in text.go
for i, id := range ids {
ext := m[id]
if i > 0 {
b.WriteByte(',')
}
msd, ok := messageSetMap[id]
if !ok {
// Unknown type; we can't render it, so skip it.
continue
}
fmt.Fprintf(&b, `"[%s]":`, msd.name)
x := ext.value
if x == nil {
x = reflect.New(msd.t.Elem()).Interface()
if err := Unmarshal(ext.enc, x.(Message)); err != nil {
return nil, err
}
}
d, err := json.Marshal(x)
if err != nil {
return nil, err
}
b.Write(d)
}
b.WriteByte('}')
return b.Bytes(), nil
}
// UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
// It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error {
// Common-case fast path.
if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) {
return nil
}
// This is fairly tricky, and it's not clear that it is needed.
return errors.New("TODO: UnmarshalMessageSetJSON not yet implemented")
}
// A global registry of types that can be used in a MessageSet. // A global registry of types that can be used in a MessageSet.
var messageSetMap = make(map[int32]messageSetDesc) var messageSetMap = make(map[int32]messageSetDesc)
@ -221,9 +279,9 @@ type messageSetDesc struct {
} }
// RegisterMessageSetType is called from the generated code. // RegisterMessageSetType is called from the generated code.
func RegisterMessageSetType(i messageTypeIder, name string) { func RegisterMessageSetType(m Message, fieldNum int32, name string) {
messageSetMap[i.MessageTypeId()] = messageSetDesc{ messageSetMap[fieldNum] = messageSetDesc{
t: reflect.TypeOf(i), t: reflect.TypeOf(m),
name: name, name: name,
} }
} }

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2014 The Go Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2012 The Go Authors. All rights reserved. // Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build appengine,!appenginevm // +build appengine
// This file contains an implementation of proto field accesses using package reflect. // This file contains an implementation of proto field accesses using package reflect.
// It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can // It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can
@ -114,6 +114,11 @@ func structPointer_Bool(p structPointer, f field) **bool {
return structPointer_ifield(p, f).(**bool) return structPointer_ifield(p, f).(**bool)
} }
// BoolVal returns the address of a bool field in the struct.
func structPointer_BoolVal(p structPointer, f field) *bool {
return structPointer_ifield(p, f).(*bool)
}
// BoolSlice returns the address of a []bool field in the struct. // BoolSlice returns the address of a []bool field in the struct.
func structPointer_BoolSlice(p structPointer, f field) *[]bool { func structPointer_BoolSlice(p structPointer, f field) *[]bool {
return structPointer_ifield(p, f).(*[]bool) return structPointer_ifield(p, f).(*[]bool)
@ -124,6 +129,11 @@ func structPointer_String(p structPointer, f field) **string {
return structPointer_ifield(p, f).(**string) return structPointer_ifield(p, f).(**string)
} }
// StringVal returns the address of a string field in the struct.
func structPointer_StringVal(p structPointer, f field) *string {
return structPointer_ifield(p, f).(*string)
}
// StringSlice returns the address of a []string field in the struct. // StringSlice returns the address of a []string field in the struct.
func structPointer_StringSlice(p structPointer, f field) *[]string { func structPointer_StringSlice(p structPointer, f field) *[]string {
return structPointer_ifield(p, f).(*[]string) return structPointer_ifield(p, f).(*[]string)
@ -134,6 +144,11 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return structPointer_ifield(p, f).(*map[int32]Extension) return structPointer_ifield(p, f).(*map[int32]Extension)
} }
// Map returns the reflect.Value for the address of a map field in the struct.
func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
return structPointer_field(p, f).Addr()
}
// SetStructPointer writes a *struct field in the struct. // SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
structPointer_field(p, f).Set(q.v) structPointer_field(p, f).Set(q.v)
@ -235,6 +250,49 @@ func structPointer_Word32(p structPointer, f field) word32 {
return word32{structPointer_field(p, f)} return word32{structPointer_field(p, f)}
} }
// A word32Val represents a field of type int32, uint32, float32, or enum.
// That is, v.Type() is int32, uint32, float32, or enum and v is assignable.
type word32Val struct {
v reflect.Value
}
// Set sets *p to x.
func word32Val_Set(p word32Val, x uint32) {
switch p.v.Type() {
case int32Type:
p.v.SetInt(int64(x))
return
case uint32Type:
p.v.SetUint(uint64(x))
return
case float32Type:
p.v.SetFloat(float64(math.Float32frombits(x)))
return
}
// must be enum
p.v.SetInt(int64(int32(x)))
}
// Get gets the bits pointed at by p, as a uint32.
func word32Val_Get(p word32Val) uint32 {
elem := p.v
switch elem.Kind() {
case reflect.Int32:
return uint32(elem.Int())
case reflect.Uint32:
return uint32(elem.Uint())
case reflect.Float32:
return math.Float32bits(float32(elem.Float()))
}
panic("unreachable")
}
// Word32Val returns a reference to a int32, uint32, float32, or enum field in the struct.
func structPointer_Word32Val(p structPointer, f field) word32Val {
return word32Val{structPointer_field(p, f)}
}
// A word32Slice is a slice of 32-bit values. // A word32Slice is a slice of 32-bit values.
// That is, v.Type() is []int32, []uint32, []float32, or []enum. // That is, v.Type() is []int32, []uint32, []float32, or []enum.
type word32Slice struct { type word32Slice struct {
@ -339,6 +397,43 @@ func structPointer_Word64(p structPointer, f field) word64 {
return word64{structPointer_field(p, f)} return word64{structPointer_field(p, f)}
} }
// word64Val is like word32Val but for 64-bit values.
type word64Val struct {
v reflect.Value
}
func word64Val_Set(p word64Val, o *Buffer, x uint64) {
switch p.v.Type() {
case int64Type:
p.v.SetInt(int64(x))
return
case uint64Type:
p.v.SetUint(x)
return
case float64Type:
p.v.SetFloat(math.Float64frombits(x))
return
}
panic("unreachable")
}
func word64Val_Get(p word64Val) uint64 {
elem := p.v
switch elem.Kind() {
case reflect.Int64:
return uint64(elem.Int())
case reflect.Uint64:
return elem.Uint()
case reflect.Float64:
return math.Float64bits(elem.Float())
}
panic("unreachable")
}
func structPointer_Word64Val(p structPointer, f field) word64Val {
return word64Val{structPointer_field(p, f)}
}
type word64Slice struct { type word64Slice struct {
v reflect.Value v reflect.Value
} }

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2012 The Go Authors. All rights reserved. // Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build !appengine appenginevm // +build !appengine
// This file contains the implementation of the proto field accesses using package unsafe. // This file contains the implementation of the proto field accesses using package unsafe.
@ -100,6 +100,11 @@ func structPointer_Bool(p structPointer, f field) **bool {
return (**bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) return (**bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
} }
// BoolVal returns the address of a bool field in the struct.
func structPointer_BoolVal(p structPointer, f field) *bool {
return (*bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// BoolSlice returns the address of a []bool field in the struct. // BoolSlice returns the address of a []bool field in the struct.
func structPointer_BoolSlice(p structPointer, f field) *[]bool { func structPointer_BoolSlice(p structPointer, f field) *[]bool {
return (*[]bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) return (*[]bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
@ -110,6 +115,11 @@ func structPointer_String(p structPointer, f field) **string {
return (**string)(unsafe.Pointer(uintptr(p) + uintptr(f))) return (**string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
} }
// StringVal returns the address of a string field in the struct.
func structPointer_StringVal(p structPointer, f field) *string {
return (*string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// StringSlice returns the address of a []string field in the struct. // StringSlice returns the address of a []string field in the struct.
func structPointer_StringSlice(p structPointer, f field) *[]string { func structPointer_StringSlice(p structPointer, f field) *[]string {
return (*[]string)(unsafe.Pointer(uintptr(p) + uintptr(f))) return (*[]string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
@ -120,6 +130,11 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f))) return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
} }
// Map returns the reflect.Value for the address of a map field in the struct.
func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f)))
}
// SetStructPointer writes a *struct field in the struct. // SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
*(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q
@ -170,6 +185,24 @@ func structPointer_Word32(p structPointer, f field) word32 {
return word32((**uint32)(unsafe.Pointer(uintptr(p) + uintptr(f)))) return word32((**uint32)(unsafe.Pointer(uintptr(p) + uintptr(f))))
} }
// A word32Val is the address of a 32-bit value field.
type word32Val *uint32
// Set sets *p to x.
func word32Val_Set(p word32Val, x uint32) {
*p = x
}
// Get gets the value pointed at by p.
func word32Val_Get(p word32Val) uint32 {
return *p
}
// Word32Val returns the address of a *int32, *uint32, *float32, or *enum field in the struct.
func structPointer_Word32Val(p structPointer, f field) word32Val {
return word32Val((*uint32)(unsafe.Pointer(uintptr(p) + uintptr(f))))
}
// A word32Slice is a slice of 32-bit values. // A word32Slice is a slice of 32-bit values.
type word32Slice []uint32 type word32Slice []uint32
@ -206,6 +239,21 @@ func structPointer_Word64(p structPointer, f field) word64 {
return word64((**uint64)(unsafe.Pointer(uintptr(p) + uintptr(f)))) return word64((**uint64)(unsafe.Pointer(uintptr(p) + uintptr(f))))
} }
// word64Val is like word32Val but for 64-bit values.
type word64Val *uint64
func word64Val_Set(p word64Val, o *Buffer, x uint64) {
*p = x
}
func word64Val_Get(p word64Val) uint64 {
return *p
}
func structPointer_Word64Val(p structPointer, f field) word64Val {
return word64Val((*uint64)(unsafe.Pointer(uintptr(p) + uintptr(f))))
}
// word64Slice is like word32Slice but for 64-bit values. // word64Slice is like word32Slice but for 64-bit values.
type word64Slice []uint64 type word64Slice []uint64

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -155,6 +155,7 @@ type Properties struct {
Repeated bool Repeated bool
Packed bool // relevant for repeated primitives only Packed bool // relevant for repeated primitives only
Enum string // set for enum types only Enum string // set for enum types only
proto3 bool // whether this is known to be a proto3 field; set for []byte only
Default string // default value Default string // default value
HasDefault bool // whether an explicit default was provided HasDefault bool // whether an explicit default was provided
@ -170,6 +171,10 @@ type Properties struct {
isMarshaler bool isMarshaler bool
isUnmarshaler bool isUnmarshaler bool
mtype reflect.Type // set for map types only
mkeyprop *Properties // set for map types only
mvalprop *Properties // set for map types only
size sizer size sizer
valSize valueSizer // set for bool and numeric types only valSize valueSizer // set for bool and numeric types only
@ -200,6 +205,9 @@ func (p *Properties) String() string {
if p.OrigName != p.Name { if p.OrigName != p.Name {
s += ",name=" + p.OrigName s += ",name=" + p.OrigName
} }
if p.proto3 {
s += ",proto3"
}
if len(p.Enum) > 0 { if len(p.Enum) > 0 {
s += ",enum=" + p.Enum s += ",enum=" + p.Enum
} }
@ -274,6 +282,8 @@ func (p *Properties) Parse(s string) {
p.OrigName = f[5:] p.OrigName = f[5:]
case strings.HasPrefix(f, "enum="): case strings.HasPrefix(f, "enum="):
p.Enum = f[5:] p.Enum = f[5:]
case f == "proto3":
p.proto3 = true
case strings.HasPrefix(f, "def="): case strings.HasPrefix(f, "def="):
p.HasDefault = true p.HasDefault = true
p.Default = f[4:] // rest of string p.Default = f[4:] // rest of string
@ -293,19 +303,50 @@ func logNoSliceEnc(t1, t2 reflect.Type) {
var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem() var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem()
// Initialize the fields for encoding and decoding. // Initialize the fields for encoding and decoding.
func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) { func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lockGetProp bool) {
p.enc = nil p.enc = nil
p.dec = nil p.dec = nil
p.size = nil p.size = nil
switch t1 := typ; t1.Kind() { switch t1 := typ; t1.Kind() {
default: default:
fmt.Fprintf(os.Stderr, "proto: no coders for %T\n", t1) fmt.Fprintf(os.Stderr, "proto: no coders for %v\n", t1)
// proto3 scalar types
case reflect.Bool:
p.enc = (*Buffer).enc_proto3_bool
p.dec = (*Buffer).dec_proto3_bool
p.size = size_proto3_bool
case reflect.Int32:
p.enc = (*Buffer).enc_proto3_int32
p.dec = (*Buffer).dec_proto3_int32
p.size = size_proto3_int32
case reflect.Uint32:
p.enc = (*Buffer).enc_proto3_uint32
p.dec = (*Buffer).dec_proto3_int32 // can reuse
p.size = size_proto3_uint32
case reflect.Int64, reflect.Uint64:
p.enc = (*Buffer).enc_proto3_int64
p.dec = (*Buffer).dec_proto3_int64
p.size = size_proto3_int64
case reflect.Float32:
p.enc = (*Buffer).enc_proto3_uint32 // can just treat them as bits
p.dec = (*Buffer).dec_proto3_int32
p.size = size_proto3_uint32
case reflect.Float64:
p.enc = (*Buffer).enc_proto3_int64 // can just treat them as bits
p.dec = (*Buffer).dec_proto3_int64
p.size = size_proto3_int64
case reflect.String:
p.enc = (*Buffer).enc_proto3_string
p.dec = (*Buffer).dec_proto3_string
p.size = size_proto3_string
case reflect.Ptr: case reflect.Ptr:
switch t2 := t1.Elem(); t2.Kind() { switch t2 := t1.Elem(); t2.Kind() {
default: default:
fmt.Fprintf(os.Stderr, "proto: no encoder function for %T -> %T\n", t1, t2) fmt.Fprintf(os.Stderr, "proto: no encoder function for %v -> %v\n", t1, t2)
break break
case reflect.Bool: case reflect.Bool:
p.enc = (*Buffer).enc_bool p.enc = (*Buffer).enc_bool
@ -399,6 +440,10 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
p.enc = (*Buffer).enc_slice_byte p.enc = (*Buffer).enc_slice_byte
p.dec = (*Buffer).dec_slice_byte p.dec = (*Buffer).dec_slice_byte
p.size = size_slice_byte p.size = size_slice_byte
if p.proto3 {
p.enc = (*Buffer).enc_proto3_slice_byte
p.size = size_proto3_slice_byte
}
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
switch t2.Bits() { switch t2.Bits() {
case 32: case 32:
@ -461,6 +506,23 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
p.size = size_slice_slice_byte p.size = size_slice_slice_byte
} }
} }
case reflect.Map:
p.enc = (*Buffer).enc_new_map
p.dec = (*Buffer).dec_new_map
p.size = size_new_map
p.mtype = t1
p.mkeyprop = &Properties{}
p.mkeyprop.init(reflect.PtrTo(p.mtype.Key()), "Key", f.Tag.Get("protobuf_key"), nil, lockGetProp)
p.mvalprop = &Properties{}
vtype := p.mtype.Elem()
if vtype.Kind() != reflect.Ptr && vtype.Kind() != reflect.Slice {
// The value type is not a message (*T) or bytes ([]byte),
// so we need encoders for the pointer to this type.
vtype = reflect.PtrTo(vtype)
}
p.mvalprop.init(vtype, "Value", f.Tag.Get("protobuf_val"), nil, lockGetProp)
} }
// precalculate tag code // precalculate tag code
@ -529,7 +591,7 @@ func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructF
return return
} }
p.Parse(tag) p.Parse(tag)
p.setEncAndDec(typ, lockGetProp) p.setEncAndDec(typ, f, lockGetProp)
} }
var ( var (
@ -538,7 +600,11 @@ var (
) )
// GetProperties returns the list of properties for the type represented by t. // GetProperties returns the list of properties for the type represented by t.
// t must represent a generated struct type of a protocol message.
func GetProperties(t reflect.Type) *StructProperties { func GetProperties(t reflect.Type) *StructProperties {
if t.Kind() != reflect.Struct {
panic("proto: type must have kind struct")
}
mutex.Lock() mutex.Lock()
sprop := getPropertiesLocked(t) sprop := getPropertiesLocked(t)
mutex.Unlock() mutex.Unlock()

View File

@ -0,0 +1,44 @@
# Go support for Protocol Buffers - Google's data interchange format
#
# Copyright 2014 The Go Authors. All rights reserved.
# https://github.com/golang/protobuf
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include ../../Make.protobuf
all: regenerate
regenerate:
rm -f proto3.pb.go
make proto3.pb.go
# The following rules are just aids to development. Not needed for typical testing.
diff: regenerate
git diff proto3.pb.go

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2014 The Go Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -29,32 +29,30 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test syntax = "proto3";
import ( package proto3_proto;
"testing"
pb "./testdata" message Message {
"code.google.com/p/goprotobuf/proto" enum Humour {
) UNKNOWN = 0;
PUNS = 1;
SLAPSTICK = 2;
BILL_BAILEY = 3;
}
func TestGetExtensionsWithMissingExtensions(t *testing.T) { string name = 1;
msg := &pb.MyMessage{} Humour hilarity = 2;
ext1 := &pb.Ext{} uint32 height_in_cm = 3;
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { bytes data = 4;
t.Fatalf("Could not set ext1: %s", ext1) int64 result_count = 7;
} bool true_scotsman = 8;
exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{ float score = 9;
pb.E_Ext_More,
pb.E_Ext_Text, repeated uint64 key = 5;
}) Nested nested = 6;
if err != nil { }
t.Fatalf("GetExtensions() failed: %s", err)
} message Nested {
if exts[0] != ext1 { string bunny = 1;
t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
}
if exts[1] != nil {
t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
}
} }

View File

@ -0,0 +1,93 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"testing"
pb "./proto3_proto"
"github.com/golang/protobuf/proto"
)
func TestProto3ZeroValues(t *testing.T) {
tests := []struct {
desc string
m proto.Message
}{
{"zero message", &pb.Message{}},
{"empty bytes field", &pb.Message{Data: []byte{}}},
}
for _, test := range tests {
b, err := proto.Marshal(test.m)
if err != nil {
t.Errorf("%s: proto.Marshal: %v", test.desc, err)
continue
}
if len(b) > 0 {
t.Errorf("%s: Encoding is non-empty: %q", test.desc, b)
}
}
}
func TestRoundTripProto3(t *testing.T) {
m := &pb.Message{
Name: "David", // (2 | 1<<3): 0x0a 0x05 "David"
Hilarity: pb.Message_PUNS, // (0 | 2<<3): 0x10 0x01
HeightInCm: 178, // (0 | 3<<3): 0x18 0xb2 0x01
Data: []byte("roboto"), // (2 | 4<<3): 0x20 0x06 "roboto"
ResultCount: 47, // (0 | 7<<3): 0x38 0x2f
TrueScotsman: true, // (0 | 8<<3): 0x40 0x01
Score: 8.1, // (5 | 9<<3): 0x4d <8.1>
Key: []uint64{1, 0xdeadbeef},
Nested: &pb.Nested{
Bunny: "Monty",
},
}
t.Logf(" m: %v", m)
b, err := proto.Marshal(m)
if err != nil {
t.Fatalf("proto.Marshal: %v", err)
}
t.Logf(" b: %q", b)
m2 := new(pb.Message)
if err := proto.Unmarshal(b, m2); err != nil {
t.Fatalf("proto.Unmarshal: %v", err)
}
t.Logf("m2: %v", m2)
if !proto.Equal(m, m2) {
t.Errorf("proto.Equal returned false:\n m: %v\nm2: %v", m, m2)
}
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2012 The Go Authors. All rights reserved. // Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2012 The Go Authors. All rights reserved. // Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -35,8 +35,9 @@ import (
"log" "log"
"testing" "testing"
proto3pb "./proto3_proto"
pb "./testdata" pb "./testdata"
. "code.google.com/p/goprotobuf/proto" . "github.com/golang/protobuf/proto"
) )
var messageWithExtension1 = &pb.MyMessage{Count: Int32(7)} var messageWithExtension1 = &pb.MyMessage{Count: Int32(7)}
@ -102,6 +103,20 @@ var SizeTests = []struct {
{"unrecognized", &pb.MoreRepeated{XXX_unrecognized: []byte{13<<3 | 0, 4}}}, {"unrecognized", &pb.MoreRepeated{XXX_unrecognized: []byte{13<<3 | 0, 4}}},
{"extension (unencoded)", messageWithExtension1}, {"extension (unencoded)", messageWithExtension1},
{"extension (encoded)", messageWithExtension3}, {"extension (encoded)", messageWithExtension3},
// proto3 message
{"proto3 empty", &proto3pb.Message{}},
{"proto3 bool", &proto3pb.Message{TrueScotsman: true}},
{"proto3 int64", &proto3pb.Message{ResultCount: 1}},
{"proto3 uint32", &proto3pb.Message{HeightInCm: 123}},
{"proto3 float", &proto3pb.Message{Score: 12.6}},
{"proto3 string", &proto3pb.Message{Name: "Snezana"}},
{"proto3 bytes", &proto3pb.Message{Data: []byte("wowsa")}},
{"proto3 bytes, empty", &proto3pb.Message{Data: []byte{}}},
{"proto3 enum", &proto3pb.Message{Hilarity: proto3pb.Message_PUNS}},
{"map field", &pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob", 7: "Andrew"}}},
{"map field with message", &pb.MessageWithMap{MsgMapping: map[int64]*pb.FloatingPoint{0x7001: &pb.FloatingPoint{F: Float64(2.0)}}}},
{"map field with bytes", &pb.MessageWithMap{ByteMapping: map[bool][]byte{true: []byte("this time for sure")}}},
} }
func TestSize(t *testing.T) { func TestSize(t *testing.T) {

View File

@ -1,7 +1,7 @@
# Go support for Protocol Buffers - Google's data interchange format # Go support for Protocol Buffers - Google's data interchange format
# #
# Copyright 2010 The Go Authors. All rights reserved. # Copyright 2010 The Go Authors. All rights reserved.
# http://code.google.com/p/goprotobuf/ # https://github.com/golang/protobuf
# #
# Redistribution and use in source and binary forms, with or without # Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are # modification, are permitted provided that the following conditions are
@ -37,11 +37,11 @@ all: regenerate
regenerate: regenerate:
rm -f test.pb.go rm -f test.pb.go
make test.pb.go make test.pb.go
# The following rules are just aids to development. Not needed for typical testing. # The following rules are just aids to development. Not needed for typical testing.
diff: regenerate diff: regenerate
hg diff test.pb.go git diff test.pb.go
restore: restore:
cp test.pb.go.golden test.pb.go cp test.pb.go.golden test.pb.go

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2012 The Go Authors. All rights reserved. // Copyright 2012 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are

View File

@ -33,10 +33,11 @@ It has these top-level messages:
GroupOld GroupOld
GroupNew GroupNew
FloatingPoint FloatingPoint
MessageWithMap
*/ */
package testdata package testdata
import proto "code.google.com/p/goprotobuf/proto" import proto "github.com/golang/protobuf/proto"
import math "math" import math "math"
// Reference imports to suppress errors if they are not otherwise used. // Reference imports to suppress errors if they are not otherwise used.
@ -1416,6 +1417,12 @@ func (m *MyMessageSet) Marshal() ([]byte, error) {
func (m *MyMessageSet) Unmarshal(buf []byte) error { func (m *MyMessageSet) Unmarshal(buf []byte) error {
return proto.UnmarshalMessageSet(buf, m.ExtensionMap()) return proto.UnmarshalMessageSet(buf, m.ExtensionMap())
} }
func (m *MyMessageSet) MarshalJSON() ([]byte, error) {
return proto.MarshalMessageSetJSON(m.XXX_extensions)
}
func (m *MyMessageSet) UnmarshalJSON(buf []byte) error {
return proto.UnmarshalMessageSetJSON(buf, m.XXX_extensions)
}
// ensure MyMessageSet satisfies proto.Marshaler and proto.Unmarshaler // ensure MyMessageSet satisfies proto.Marshaler and proto.Unmarshaler
var _ proto.Marshaler = (*MyMessageSet)(nil) var _ proto.Marshaler = (*MyMessageSet)(nil)
@ -1879,6 +1886,38 @@ func (m *FloatingPoint) GetF() float64 {
return 0 return 0
} }
type MessageWithMap struct {
NameMapping map[int32]string `protobuf:"bytes,1,rep,name=name_mapping" json:"name_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
MsgMapping map[int64]*FloatingPoint `protobuf:"bytes,2,rep,name=msg_mapping" json:"msg_mapping,omitempty" protobuf_key:"zigzag64,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
ByteMapping map[bool][]byte `protobuf:"bytes,3,rep,name=byte_mapping" json:"byte_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
XXX_unrecognized []byte `json:"-"`
}
func (m *MessageWithMap) Reset() { *m = MessageWithMap{} }
func (m *MessageWithMap) String() string { return proto.CompactTextString(m) }
func (*MessageWithMap) ProtoMessage() {}
func (m *MessageWithMap) GetNameMapping() map[int32]string {
if m != nil {
return m.NameMapping
}
return nil
}
func (m *MessageWithMap) GetMsgMapping() map[int64]*FloatingPoint {
if m != nil {
return m.MsgMapping
}
return nil
}
func (m *MessageWithMap) GetByteMapping() map[bool][]byte {
if m != nil {
return m.ByteMapping
}
return nil
}
var E_Greeting = &proto.ExtensionDesc{ var E_Greeting = &proto.ExtensionDesc{
ExtendedType: (*MyMessage)(nil), ExtendedType: (*MyMessage)(nil),
ExtensionType: ([]string)(nil), ExtensionType: ([]string)(nil),

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -426,3 +426,9 @@ message GroupNew {
message FloatingPoint { message FloatingPoint {
required double f = 1; required double f = 1;
} }
message MessageWithMap {
map<int32, string> name_mapping = 1;
map<sint64, FloatingPoint> msg_mapping = 2;
map<bool, bytes> byte_mapping = 3;
}

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -36,6 +36,7 @@ package proto
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"encoding"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -74,13 +75,6 @@ type textWriter struct {
w writer w writer
} }
// textMarshaler is implemented by Messages that can marshal themsleves.
// It is identical to encoding.TextMarshaler, introduced in go 1.2,
// which will eventually replace it.
type textMarshaler interface {
MarshalText() (text []byte, err error)
}
func (w *textWriter) WriteString(s string) (n int, err error) { func (w *textWriter) WriteString(s string) (n int, err error) {
if !strings.Contains(s, "\n") { if !strings.Contains(s, "\n") {
if !w.compact && w.complete { if !w.compact && w.complete {
@ -250,6 +244,100 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
} }
continue continue
} }
if fv.Kind() == reflect.Map {
// Map fields are rendered as a repeated struct with key/value fields.
keys := fv.MapKeys() // TODO: should we sort these for deterministic output?
sort.Sort(mapKeys(keys))
for _, key := range keys {
val := fv.MapIndex(key)
if err := writeName(w, props); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte(' '); err != nil {
return err
}
}
// open struct
if err := w.WriteByte('<'); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte('\n'); err != nil {
return err
}
}
w.indent()
// key
if _, err := w.WriteString("key:"); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte(' '); err != nil {
return err
}
}
if err := writeAny(w, key, props.mkeyprop); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
return err
}
// value
if _, err := w.WriteString("value:"); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte(' '); err != nil {
return err
}
}
if err := writeAny(w, val, props.mvalprop); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
return err
}
// close struct
w.unindent()
if err := w.WriteByte('>'); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
return err
}
}
continue
}
if props.proto3 && fv.Kind() == reflect.Slice && fv.Len() == 0 {
// empty bytes field
continue
}
if fv.Kind() != reflect.Ptr && fv.Kind() != reflect.Slice {
// proto3 non-repeated scalar field; skip if zero value
switch fv.Kind() {
case reflect.Bool:
if !fv.Bool() {
continue
}
case reflect.Int32, reflect.Int64:
if fv.Int() == 0 {
continue
}
case reflect.Uint32, reflect.Uint64:
if fv.Uint() == 0 {
continue
}
case reflect.Float32, reflect.Float64:
if fv.Float() == 0 {
continue
}
case reflect.String:
if fv.String() == "" {
continue
}
}
}
if err := writeName(w, props); err != nil { if err := writeName(w, props); err != nil {
return err return err
@ -358,7 +446,7 @@ func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
} }
} }
w.indent() w.indent()
if tm, ok := v.Interface().(textMarshaler); ok { if tm, ok := v.Interface().(encoding.TextMarshaler); ok {
text, err := tm.MarshalText() text, err := tm.MarshalText()
if err != nil { if err != nil {
return err return err
@ -653,7 +741,7 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
compact: compact, compact: compact,
} }
if tm, ok := pb.(textMarshaler); ok { if tm, ok := pb.(encoding.TextMarshaler); ok {
text, err := tm.MarshalText() text, err := tm.MarshalText()
if err != nil { if err != nil {
return err return err

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -35,6 +35,7 @@ package proto
// TODO: message sets. // TODO: message sets.
import ( import (
"encoding"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -43,13 +44,6 @@ import (
"unicode/utf8" "unicode/utf8"
) )
// textUnmarshaler is implemented by Messages that can unmarshal themsleves.
// It is identical to encoding.TextUnmarshaler, introduced in go 1.2,
// which will eventually replace it.
type textUnmarshaler interface {
UnmarshalText(text []byte) error
}
type ParseError struct { type ParseError struct {
Message string Message string
Line int // 1-based line number Line int // 1-based line number
@ -361,8 +355,20 @@ func (p *textParser) next() *token {
return &p.cur return &p.cur
} }
// Return an error indicating which required field was not set. func (p *textParser) consumeToken(s string) error {
func (p *textParser) missingRequiredFieldError(sv reflect.Value) *ParseError { tok := p.next()
if tok.err != nil {
return tok.err
}
if tok.value != s {
p.back()
return p.errorf("expected %q, found %q", s, tok.value)
}
return nil
}
// Return a RequiredNotSetError indicating which required field was not set.
func (p *textParser) missingRequiredFieldError(sv reflect.Value) *RequiredNotSetError {
st := sv.Type() st := sv.Type()
sprops := GetProperties(st) sprops := GetProperties(st)
for i := 0; i < st.NumField(); i++ { for i := 0; i < st.NumField(); i++ {
@ -372,10 +378,10 @@ func (p *textParser) missingRequiredFieldError(sv reflect.Value) *ParseError {
props := sprops.Prop[i] props := sprops.Prop[i]
if props.Required { if props.Required {
return p.errorf("message %v missing required field %q", st, props.OrigName) return &RequiredNotSetError{fmt.Sprintf("%v.%v", st, props.OrigName)}
} }
} }
return p.errorf("message %v missing required field", st) // should not happen return &RequiredNotSetError{fmt.Sprintf("%v.<unknown field name>", st)} // should not happen
} }
// Returns the index in the struct for the named field, as well as the parsed tag properties. // Returns the index in the struct for the named field, as well as the parsed tag properties.
@ -415,6 +421,10 @@ func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseEr
if typ.Elem().Kind() != reflect.Ptr { if typ.Elem().Kind() != reflect.Ptr {
break break
} }
} else if typ.Kind() == reflect.String {
// The proto3 exception is for a string field,
// which requires a colon.
break
} }
needColon = false needColon = false
} }
@ -426,9 +436,11 @@ func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseEr
return nil return nil
} }
func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError { func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
st := sv.Type() st := sv.Type()
reqCount := GetProperties(st).reqCount reqCount := GetProperties(st).reqCount
var reqFieldErr error
fieldSet := make(map[string]bool)
// A struct is a sequence of "name: value", terminated by one of // A struct is a sequence of "name: value", terminated by one of
// '>' or '}', or the end of the input. A name may also be // '>' or '}', or the end of the input. A name may also be
// "[extension]". // "[extension]".
@ -489,7 +501,10 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError
ext = reflect.New(typ.Elem()).Elem() ext = reflect.New(typ.Elem()).Elem()
} }
if err := p.readAny(ext, props); err != nil { if err := p.readAny(ext, props); err != nil {
return err if _, ok := err.(*RequiredNotSetError); !ok {
return err
}
reqFieldErr = err
} }
ep := sv.Addr().Interface().(extendableProto) ep := sv.Addr().Interface().(extendableProto)
if !rep { if !rep {
@ -507,17 +522,71 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError
} }
} else { } else {
// This is a normal, non-extension field. // This is a normal, non-extension field.
fi, props, ok := structFieldByName(st, tok.value) name := tok.value
fi, props, ok := structFieldByName(st, name)
if !ok { if !ok {
return p.errorf("unknown field name %q in %v", tok.value, st) return p.errorf("unknown field name %q in %v", name, st)
} }
dst := sv.Field(fi) dst := sv.Field(fi)
isDstNil := isNil(dst)
if dst.Kind() == reflect.Map {
// Consume any colon.
if err := p.checkForColon(props, dst.Type()); err != nil {
return err
}
// Construct the map if it doesn't already exist.
if dst.IsNil() {
dst.Set(reflect.MakeMap(dst.Type()))
}
key := reflect.New(dst.Type().Key()).Elem()
val := reflect.New(dst.Type().Elem()).Elem()
// The map entry should be this sequence of tokens:
// < key : KEY value : VALUE >
// Technically the "key" and "value" could come in any order,
// but in practice they won't.
tok := p.next()
var terminator string
switch tok.value {
case "<":
terminator = ">"
case "{":
terminator = "}"
default:
return p.errorf("expected '{' or '<', found %q", tok.value)
}
if err := p.consumeToken("key"); err != nil {
return err
}
if err := p.consumeToken(":"); err != nil {
return err
}
if err := p.readAny(key, props.mkeyprop); err != nil {
return err
}
if err := p.consumeToken("value"); err != nil {
return err
}
if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
return err
}
if err := p.readAny(val, props.mvalprop); err != nil {
return err
}
if err := p.consumeToken(terminator); err != nil {
return err
}
dst.SetMapIndex(key, val)
continue
}
// Check that it's not already set if it's not a repeated field. // Check that it's not already set if it's not a repeated field.
if !props.Repeated && !isDstNil { if !props.Repeated && fieldSet[name] {
return p.errorf("non-repeated field %q was repeated", tok.value) return p.errorf("non-repeated field %q was repeated", name)
} }
if err := p.checkForColon(props, st.Field(fi).Type); err != nil { if err := p.checkForColon(props, st.Field(fi).Type); err != nil {
@ -525,11 +594,13 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError
} }
// Parse into the field. // Parse into the field.
fieldSet[name] = true
if err := p.readAny(dst, props); err != nil { if err := p.readAny(dst, props); err != nil {
return err if _, ok := err.(*RequiredNotSetError); !ok {
} return err
}
if props.Required { reqFieldErr = err
} else if props.Required {
reqCount-- reqCount--
} }
} }
@ -547,10 +618,10 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError
if reqCount > 0 { if reqCount > 0 {
return p.missingRequiredFieldError(sv) return p.missingRequiredFieldError(sv)
} }
return nil return reqFieldErr
} }
func (p *textParser) readAny(v reflect.Value, props *Properties) *ParseError { func (p *textParser) readAny(v reflect.Value, props *Properties) error {
tok := p.next() tok := p.next()
if tok.err != nil { if tok.err != nil {
return tok.err return tok.err
@ -652,7 +723,7 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) *ParseError {
default: default:
return p.errorf("expected '{' or '<', found %q", tok.value) return p.errorf("expected '{' or '<', found %q", tok.value)
} }
// TODO: Handle nested messages which implement textUnmarshaler. // TODO: Handle nested messages which implement encoding.TextUnmarshaler.
return p.readStruct(fv, terminator) return p.readStruct(fv, terminator)
case reflect.Uint32: case reflect.Uint32:
if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil { if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil {
@ -670,8 +741,10 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) *ParseError {
// UnmarshalText reads a protocol buffer in Text format. UnmarshalText resets pb // UnmarshalText reads a protocol buffer in Text format. UnmarshalText resets pb
// before starting to unmarshal, so any existing data in pb is always removed. // before starting to unmarshal, so any existing data in pb is always removed.
// If a required field is not set and no other error occurs,
// UnmarshalText returns *RequiredNotSetError.
func UnmarshalText(s string, pb Message) error { func UnmarshalText(s string, pb Message) error {
if um, ok := pb.(textUnmarshaler); ok { if um, ok := pb.(encoding.TextUnmarshaler); ok {
err := um.UnmarshalText([]byte(s)) err := um.UnmarshalText([]byte(s))
return err return err
} }

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -36,8 +36,9 @@ import (
"reflect" "reflect"
"testing" "testing"
proto3pb "./proto3_proto"
. "./testdata" . "./testdata"
. "code.google.com/p/goprotobuf/proto" . "github.com/golang/protobuf/proto"
) )
type UnmarshalTextTest struct { type UnmarshalTextTest struct {
@ -294,8 +295,11 @@ var unMarshalTextTests = []UnmarshalTextTest{
// Missing required field // Missing required field
{ {
in: ``, in: `name: "Pawel"`,
err: `line 1.0: message testdata.MyMessage missing required field "count"`, err: `proto: required field "testdata.MyMessage.count" not set`,
out: &MyMessage{
Name: String("Pawel"),
},
}, },
// Repeated non-repeated field // Repeated non-repeated field
@ -408,6 +412,9 @@ func TestUnmarshalText(t *testing.T) {
} else if err.Error() != test.err { } else if err.Error() != test.err {
t.Errorf("Test %d: Incorrect error.\nHave: %v\nWant: %v", t.Errorf("Test %d: Incorrect error.\nHave: %v\nWant: %v",
i, err.Error(), test.err) i, err.Error(), test.err)
} else if _, ok := err.(*RequiredNotSetError); ok && test.out != nil && !reflect.DeepEqual(pb, test.out) {
t.Errorf("Test %d: Incorrect populated \nHave: %v\nWant: %v",
i, pb, test.out)
} }
} }
} }
@ -437,6 +444,48 @@ func TestRepeatedEnum(t *testing.T) {
} }
} }
func TestProto3TextParsing(t *testing.T) {
m := new(proto3pb.Message)
const in = `name: "Wallace" true_scotsman: true`
want := &proto3pb.Message{
Name: "Wallace",
TrueScotsman: true,
}
if err := UnmarshalText(in, m); err != nil {
t.Fatal(err)
}
if !Equal(m, want) {
t.Errorf("\n got %v\nwant %v", m, want)
}
}
func TestMapParsing(t *testing.T) {
m := new(MessageWithMap)
const in = `name_mapping:<key:1234 value:"Feist"> name_mapping:<key:1 value:"Beatles">` +
`msg_mapping:<key:-4 value:<f: 2.0>>` +
`msg_mapping<key:-2 value<f: 4.0>>` + // no colon after "value"
`byte_mapping:<key:true value:"so be it">`
want := &MessageWithMap{
NameMapping: map[int32]string{
1: "Beatles",
1234: "Feist",
},
MsgMapping: map[int64]*FloatingPoint{
-4: {F: Float64(2.0)},
-2: {F: Float64(4.0)},
},
ByteMapping: map[bool][]byte{
true: []byte("so be it"),
},
}
if err := UnmarshalText(in, m); err != nil {
t.Fatal(err)
}
if !Equal(m, want) {
t.Errorf("\n got %v\nwant %v", m, want)
}
}
var benchInput string var benchInput string
func init() { func init() {

View File

@ -1,7 +1,7 @@
// Go support for Protocol Buffers - Google's data interchange format // Go support for Protocol Buffers - Google's data interchange format
// //
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // https://github.com/golang/protobuf
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -39,8 +39,9 @@ import (
"strings" "strings"
"testing" "testing"
"code.google.com/p/goprotobuf/proto" "github.com/golang/protobuf/proto"
proto3pb "./proto3_proto"
pb "./testdata" pb "./testdata"
) )
@ -406,3 +407,30 @@ Message <nil>
t.Errorf(" got: %s\nwant: %s", s, want) t.Errorf(" got: %s\nwant: %s", s, want)
} }
} }
func TestProto3Text(t *testing.T) {
tests := []struct {
m proto.Message
want string
}{
// zero message
{&proto3pb.Message{}, ``},
// zero message except for an empty byte slice
{&proto3pb.Message{Data: []byte{}}, ``},
// trivial case
{&proto3pb.Message{Name: "Rob", HeightInCm: 175}, `name:"Rob" height_in_cm:175`},
// empty map
{&pb.MessageWithMap{}, ``},
// non-empty map; current map format is the same as a repeated struct
{
&pb.MessageWithMap{NameMapping: map[int32]string{1234: "Feist"}},
`name_mapping:<key:1234 value:"Feist" >`,
},
}
for _, test := range tests {
got := strings.TrimSpace(test.m.String())
if got != test.want {
t.Errorf("\n got %s\nwant %s", got, test.want)
}
}
}

View File

@ -21,8 +21,8 @@ import (
"testing" "testing"
"testing/quick" "testing/quick"
. "code.google.com/p/goprotobuf/proto" . "github.com/golang/protobuf/proto"
. "code.google.com/p/goprotobuf/proto/testdata" . "github.com/golang/protobuf/proto/testdata"
) )
func TestWriteDelimited(t *testing.T) { func TestWriteDelimited(t *testing.T) {

View File

@ -19,7 +19,7 @@ import (
"errors" "errors"
"io" "io"
"code.google.com/p/goprotobuf/proto" "github.com/golang/protobuf/proto"
) )
var errInvalidVarint = errors.New("invalid varint32 encountered") var errInvalidVarint = errors.New("invalid varint32 encountered")

View File

@ -18,7 +18,7 @@ import (
"encoding/binary" "encoding/binary"
"io" "io"
"code.google.com/p/goprotobuf/proto" "github.com/golang/protobuf/proto"
) )
// WriteDelimited encodes and dumps a message to the provided writer prefixed // WriteDelimited encodes and dumps a message to the provided writer prefixed

View File

@ -1,5 +1,5 @@
// Copyright 2010 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// http://code.google.com/p/goprotobuf/ // http://github.com/golang/protobuf/
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are // modification, are permitted provided that the following conditions are
@ -30,11 +30,11 @@
package ext package ext
import ( import (
. "code.google.com/p/goprotobuf/proto" . "github.com/golang/protobuf/proto"
. "code.google.com/p/goprotobuf/proto/testdata" . "github.com/golang/protobuf/proto/testdata"
) )
// FROM https://code.google.com/p/goprotobuf/source/browse/proto/all_test.go. // FROM https://github.com/golang/protobuf/blob/master/proto/all_test.go.
func initGoTestField() *GoTestField { func initGoTestField() *GoTestField {
f := new(GoTestField) f := new(GoTestField)

View File

@ -16,4 +16,6 @@ script:
- GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go build - GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go build
# only test on linux # only test on linux
- if [ $BUILD_GOOS == "linux" ]; then GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go test -bench=.; fi # also specify -short; the crypto tests fail in weird ways *sometimes*
# See issue #151
- if [ $BUILD_GOOS == "linux" ]; then GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go test -short -bench=.; fi

View File

@ -34,6 +34,7 @@ A not-so-up-to-date-list-that-may-be-actually-current:
* https://github.com/skynetservices/skydns * https://github.com/skynetservices/skydns
* https://github.com/DevelopersPL/godnsagent * https://github.com/DevelopersPL/godnsagent
* https://github.com/duedil-ltd/discodns * https://github.com/duedil-ltd/discodns
* https://github.com/StalkR/dns-reverse-proxy
Send pull request if you want to be listed here. Send pull request if you want to be listed here.
@ -49,7 +50,7 @@ Send pull request if you want to be listed here.
* DNSSEC: signing, validating and key generation for DSA, RSA and ECDSA; * DNSSEC: signing, validating and key generation for DSA, RSA and ECDSA;
* EDNS0, NSID; * EDNS0, NSID;
* AXFR/IXFR; * AXFR/IXFR;
* TSIG; * TSIG, SIG(0);
* DNS name compression; * DNS name compression;
* Depends only on the standard library. * Depends only on the standard library.
@ -95,7 +96,8 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
* 3225 - DO bit (DNSSEC OK) * 3225 - DO bit (DNSSEC OK)
* 340{1,2,3} - NAPTR record * 340{1,2,3} - NAPTR record
* 3445 - Limiting the scope of (DNS)KEY * 3445 - Limiting the scope of (DNS)KEY
* 3597 - Unkown RRs * 3597 - Unknown RRs
* 4025 - IPSECKEY
* 403{3,4,5} - DNSSEC + validation functions * 403{3,4,5} - DNSSEC + validation functions
* 4255 - SSHFP record * 4255 - SSHFP record
* 4343 - Case insensitivity * 4343 - Case insensitivity
@ -112,6 +114,7 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
* 5936 - AXFR * 5936 - AXFR
* 5966 - TCP implementation recommendations * 5966 - TCP implementation recommendations
* 6605 - ECDSA * 6605 - ECDSA
* 6725 - IANA Registry Update
* 6742 - ILNP DNS * 6742 - ILNP DNS
* 6891 - EDNS0 update * 6891 - EDNS0 update
* 6895 - DNS IANA considerations * 6895 - DNS IANA considerations
@ -131,10 +134,8 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
## TODO ## TODO
* privatekey.Precompute() when signing? * privatekey.Precompute() when signing?
* Last remaining RRs: APL, ATMA, A6 and NXT; * Last remaining RRs: APL, ATMA, A6 and NXT and IPSECKEY;
* Missing in parsing: ISDN, UNSPEC, ATMA; * Missing in parsing: ISDN, UNSPEC, ATMA;
* CAA parsing is broken; * CAA parsing is broken;
* NSEC(3) cover/match/closest enclose; * NSEC(3) cover/match/closest enclose;
* Replies with TC bit are not parsed to the end; * Replies with TC bit are not parsed to the end;
* SIG(0);
* Create IsMsg to validate a message before fully parsing it.

View File

@ -9,7 +9,7 @@ import (
"time" "time"
) )
const dnsTimeout time.Duration = 2 * 1e9 const dnsTimeout time.Duration = 2 * time.Second
const tcpIdleTimeout time.Duration = 8 * time.Second const tcpIdleTimeout time.Duration = 8 * time.Second
// A Conn represents a connection to a DNS server. // A Conn represents a connection to a DNS server.
@ -26,9 +26,9 @@ type Conn struct {
type Client struct { type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
UDPSize uint16 // minimum receive buffer for UDP messages UDPSize uint16 // minimum receive buffer for UDP messages
DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9 DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9 ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9 WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
group singleflight group singleflight

View File

@ -118,54 +118,7 @@ Loop:
} }
} }
/* // ExampleUpdateLeaseTSIG shows how to update a lease signed with TSIG.
func TestClientTsigAXFR(t *testing.T) {
m := new(Msg)
m.SetAxfr("example.nl.")
m.SetTsig("axfr.", HmacMD5, 300, time.Now().Unix())
tr := new(Transfer)
tr.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
if a, err := tr.In(m, "176.58.119.54:53"); err != nil {
t.Log("failed to setup axfr: " + err.Error())
t.Fatal()
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("error %s\n", ex.Error.Error())
t.Fail()
break
}
for _, rr := range ex.RR {
t.Logf("%s\n", rr.String())
}
}
}
}
func TestClientAXFRMultipleEnvelopes(t *testing.T) {
m := new(Msg)
m.SetAxfr("nlnetlabs.nl.")
tr := new(Transfer)
if a, err := tr.In(m, "213.154.224.1:53"); err != nil {
t.Log("Failed to setup axfr" + err.Error())
t.Fail()
return
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("Error %s\n", ex.Error.Error())
t.Fail()
break
}
}
}
}
*/
// ExapleUpdateLeaseTSIG shows how to update a lease signed with TSIG.
func ExampleUpdateLeaseTSIG(t *testing.T) { func ExampleUpdateLeaseTSIG(t *testing.T) {
m := new(Msg) m := new(Msg)
m.SetUpdate("t.local.ip6.io.") m.SetUpdate("t.local.ip6.io.")

View File

@ -26,14 +26,19 @@ func ClientConfigFromFile(resolvconf string) (*ClientConfig, error) {
} }
defer file.Close() defer file.Close()
c := new(ClientConfig) c := new(ClientConfig)
b := bufio.NewReader(file) scanner := bufio.NewScanner(file)
c.Servers = make([]string, 0) c.Servers = make([]string, 0)
c.Search = make([]string, 0) c.Search = make([]string, 0)
c.Port = "53" c.Port = "53"
c.Ndots = 1 c.Ndots = 1
c.Timeout = 5 c.Timeout = 5
c.Attempts = 2 c.Attempts = 2
for line, ok := b.ReadString('\n'); ok == nil; line, ok = b.ReadString('\n') {
for scanner.Scan() {
if err := scanner.Err(); err != nil {
return nil, err
}
line := scanner.Text()
f := strings.Fields(line) f := strings.Fields(line)
if len(f) < 1 { if len(f) < 1 {
continue continue

View File

@ -0,0 +1,55 @@
package dns
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
)
const normal string = `
# Comment
domain somedomain.com
nameserver 10.28.10.2
nameserver 11.28.10.1
`
const missingNewline string = `
domain somedomain.com
nameserver 10.28.10.2
nameserver 11.28.10.1` // <- NOTE: NO newline.
func testConfig(t *testing.T, data string) {
tempDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatalf("TempDir: %v", err)
}
defer os.RemoveAll(tempDir)
path := filepath.Join(tempDir, "resolv.conf")
if err := ioutil.WriteFile(path, []byte(data), 0644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
cc, err := ClientConfigFromFile(path)
if err != nil {
t.Errorf("error parsing resolv.conf: %s", err)
}
if l := len(cc.Servers); l != 2 {
t.Errorf("incorrect number of nameservers detected: %d", l)
}
if l := len(cc.Search); l != 1 {
t.Errorf("domain directive not parsed correctly: %v", cc.Search)
} else {
if cc.Search[0] != "somedomain.com" {
t.Errorf("domain is unexpected: %v", cc.Search[0])
}
}
}
func TestNameserver(t *testing.T) {
testConfig(t, normal)
}
func TestMissingFinalNewLine(t *testing.T) {
testConfig(t, missingNewline)
}

View File

@ -73,13 +73,15 @@ func (dns *Msg) SetUpdate(z string) *Msg {
} }
// SetIxfr creates message for requesting an IXFR. // SetIxfr creates message for requesting an IXFR.
func (dns *Msg) SetIxfr(z string, serial uint32) *Msg { func (dns *Msg) SetIxfr(z string, serial uint32, ns, mbox string) *Msg {
dns.Id = Id() dns.Id = Id()
dns.Question = make([]Question, 1) dns.Question = make([]Question, 1)
dns.Ns = make([]RR, 1) dns.Ns = make([]RR, 1)
s := new(SOA) s := new(SOA)
s.Hdr = RR_Header{z, TypeSOA, ClassINET, defaultTtl, 0} s.Hdr = RR_Header{z, TypeSOA, ClassINET, defaultTtl, 0}
s.Serial = serial s.Serial = serial
s.Ns = ns
s.Mbox = mbox
dns.Question[0] = Question{z, TypeIXFR, ClassINET} dns.Question[0] = Question{z, TypeIXFR, ClassINET}
dns.Ns[0] = s dns.Ns[0] = s
return dns return dns
@ -169,7 +171,7 @@ func IsMsg(buf []byte) error {
return errors.New("dns: bad message header") return errors.New("dns: bad message header")
} }
// Header: Opcode // Header: Opcode
// TODO(miek): more checks here, e.g. check all header bits.
return nil return nil
} }

View File

@ -93,9 +93,7 @@
// spaces, semicolons and the at symbol are escaped. // spaces, semicolons and the at symbol are escaped.
package dns package dns
import ( import "strconv"
"strconv"
)
const ( const (
year68 = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits. year68 = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits.
@ -124,7 +122,7 @@ type RR interface {
String() string String() string
// copy returns a copy of the RR // copy returns a copy of the RR
copy() RR copy() RR
// len returns the length (in octects) of the uncompressed RR in wire format. // len returns the length (in octets) of the uncompressed RR in wire format.
len() int len() int
} }

View File

@ -426,11 +426,21 @@ func BenchmarkUnpackDomainName(b *testing.B) {
} }
} }
func BenchmarkUnpackDomainNameUnprintable(b *testing.B) {
name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123."
buf := make([]byte, len(name1)+1)
_, _ = PackDomainName(name1, buf, 0, nil, false)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackDomainName(buf, 0)
}
}
func TestToRFC3597(t *testing.T) { func TestToRFC3597(t *testing.T) {
a, _ := NewRR("miek.nl. IN A 10.0.1.1") a, _ := NewRR("miek.nl. IN A 10.0.1.1")
x := new(RFC3597) x := new(RFC3597)
x.ToRFC3597(a) x.ToRFC3597(a)
if x.String() != `miek.nl. 3600 IN A \# 4 0a000101` { if x.String() != `miek.nl. 3600 CLASS1 TYPE1 \# 4 0a000101` {
t.Fail() t.Fail()
} }
} }
@ -499,3 +509,72 @@ func TestCopy(t *testing.T) {
t.Fatalf("Copy() failed %s != %s", rr.String(), rr1.String()) t.Fatalf("Copy() failed %s != %s", rr.String(), rr1.String())
} }
} }
func TestMsgCopy(t *testing.T) {
m := new(Msg)
m.SetQuestion("miek.nl.", TypeA)
rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Answer = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1")
m.Ns = []RR{rr}
m1 := m.Copy()
if m.String() != m1.String() {
t.Fatalf("Msg.Copy() failed %s != %s", m.String(), m1.String())
}
m1.Answer[0], _ = NewRR("somethingelse.nl. 2311 IN A 127.0.0.1")
if m.String() == m1.String() {
t.Fatalf("Msg.Copy() failed; change to copy changed template %s", m.String())
}
rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.2")
m1.Answer = append(m1.Answer, rr)
if m1.Ns[0].String() == m1.Answer[1].String() {
t.Fatalf("Msg.Copy() failed; append changed underlying array %s", m1.Ns[0].String())
}
}
func BenchmarkCopy(b *testing.B) {
b.ReportAllocs()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeA)
rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Answer = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1")
m.Ns = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Extra = []RR{rr}
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.Copy()
}
}
func TestPackIPSECKEY(t *testing.T) {
tests := []string{
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 0 2 . AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.3 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.1.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 3 2 mygateway.example.com. AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"0.d.4.0.3.0.e.f.f.f.3.f.0.1.2.0 7200 IN IPSECKEY ( 10 2 2 2001:0DB8:0:8002::2000:1 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
}
buf := make([]byte, 1024)
for _, t1 := range tests {
rr, _ := NewRR(t1)
off, e := PackRR(rr, buf, 0, nil, false)
if e != nil {
t.Logf("failed to pack IPSECKEY %s: %s\n", e, t1)
t.Fail()
continue
}
rr, _, e = UnpackRR(buf[:off], 0)
if e != nil {
t.Logf("failed to unpack IPSECKEY %s: %s\n", e, t1)
t.Fail()
}
t.Logf("%s\n", rr)
}
}

View File

@ -20,7 +20,6 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/md5" "crypto/md5"
"crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
@ -36,31 +35,34 @@ import (
// DNSSEC encryption algorithm codes. // DNSSEC encryption algorithm codes.
const ( const (
RSAMD5 = 1 _ uint8 = iota
DH = 2 RSAMD5
DSA = 3 DH
ECC = 4 DSA
RSASHA1 = 5 _ // Skip 4, RFC 6725, section 2.1
DSANSEC3SHA1 = 6 RSASHA1
RSASHA1NSEC3SHA1 = 7 DSANSEC3SHA1
RSASHA256 = 8 RSASHA1NSEC3SHA1
RSASHA512 = 10 RSASHA256
ECCGOST = 12 _ // Skip 9, RFC 6725, section 2.1
ECDSAP256SHA256 = 13 RSASHA512
ECDSAP384SHA384 = 14 _ // Skip 11, RFC 6725, section 2.1
INDIRECT = 252 ECCGOST
PRIVATEDNS = 253 // Private (experimental keys) ECDSAP256SHA256
PRIVATEOID = 254 ECDSAP384SHA384
INDIRECT uint8 = 252
PRIVATEDNS uint8 = 253 // Private (experimental keys)
PRIVATEOID uint8 = 254
) )
// DNSSEC hashing algorithm codes. // DNSSEC hashing algorithm codes.
const ( const (
_ = iota _ uint8 = iota
SHA1 // RFC 4034 SHA1 // RFC 4034
SHA256 // RFC 4509 SHA256 // RFC 4509
GOST94 // RFC 5933 GOST94 // RFC 5933
SHA384 // Experimental SHA384 // Experimental
SHA512 // Experimental SHA512 // Experimental
) )
// DNSKEY flag values. // DNSKEY flag values.
@ -94,6 +96,10 @@ type dnskeyWireFmt struct {
/* Nothing is left out */ /* Nothing is left out */
} }
func divRoundUp(a, b int) int {
return (a + b - 1) / b
}
// KeyTag calculates the keytag (or key-id) of the DNSKEY. // KeyTag calculates the keytag (or key-id) of the DNSKEY.
func (k *DNSKEY) KeyTag() uint16 { func (k *DNSKEY) KeyTag() uint16 {
if k == nil { if k == nil {
@ -105,7 +111,7 @@ func (k *DNSKEY) KeyTag() uint16 {
// Look at the bottom two bytes of the modules, which the last // Look at the bottom two bytes of the modules, which the last
// item in the pubkey. We could do this faster by looking directly // item in the pubkey. We could do this faster by looking directly
// at the base64 values. But I'm lazy. // at the base64 values. But I'm lazy.
modulus, _ := packBase64([]byte(k.PublicKey)) modulus, _ := fromBase64([]byte(k.PublicKey))
if len(modulus) > 1 { if len(modulus) > 1 {
x, _ := unpackUint16(modulus, len(modulus)-2) x, _ := unpackUint16(modulus, len(modulus)-2)
keytag = int(x) keytag = int(x)
@ -136,7 +142,7 @@ func (k *DNSKEY) KeyTag() uint16 {
} }
// ToDS converts a DNSKEY record to a DS record. // ToDS converts a DNSKEY record to a DS record.
func (k *DNSKEY) ToDS(h int) *DS { func (k *DNSKEY) ToDS(h uint8) *DS {
if k == nil { if k == nil {
return nil return nil
} }
@ -146,7 +152,7 @@ func (k *DNSKEY) ToDS(h int) *DS {
ds.Hdr.Rrtype = TypeDS ds.Hdr.Rrtype = TypeDS
ds.Hdr.Ttl = k.Hdr.Ttl ds.Hdr.Ttl = k.Hdr.Ttl
ds.Algorithm = k.Algorithm ds.Algorithm = k.Algorithm
ds.DigestType = uint8(h) ds.DigestType = h
ds.KeyTag = k.KeyTag() ds.KeyTag = k.KeyTag()
keywire := new(dnskeyWireFmt) keywire := new(dnskeyWireFmt)
@ -249,60 +255,36 @@ func (rr *RRSIG) Sign(k PrivateKey, rrset []RR) error {
} }
signdata = append(signdata, wire...) signdata = append(signdata, wire...)
var sighash []byte
var h hash.Hash var h hash.Hash
var ch crypto.Hash // Only need for RSA
switch rr.Algorithm { switch rr.Algorithm {
case DSA, DSANSEC3SHA1: case DSA, DSANSEC3SHA1:
// Implicit in the ParameterSizes // TODO: this seems bugged, will panic
case RSASHA1, RSASHA1NSEC3SHA1: case RSASHA1, RSASHA1NSEC3SHA1:
h = sha1.New() h = sha1.New()
ch = crypto.SHA1
case RSASHA256, ECDSAP256SHA256: case RSASHA256, ECDSAP256SHA256:
h = sha256.New() h = sha256.New()
ch = crypto.SHA256
case ECDSAP384SHA384: case ECDSAP384SHA384:
h = sha512.New384() h = sha512.New384()
case RSASHA512: case RSASHA512:
h = sha512.New() h = sha512.New()
ch = crypto.SHA512
case RSAMD5: case RSAMD5:
fallthrough // Deprecated in RFC 6725 fallthrough // Deprecated in RFC 6725
default: default:
return ErrAlg return ErrAlg
} }
io.WriteString(h, string(signdata))
sighash = h.Sum(nil)
switch p := k.(type) { _, err = h.Write(signdata)
case *dsa.PrivateKey: if err != nil {
r1, s1, err := dsa.Sign(rand.Reader, p, sighash) return err
if err != nil {
return err
}
signature := []byte{0x4D} // T value, here the ASCII M for Miek (not used in DNSSEC)
signature = append(signature, r1.Bytes()...)
signature = append(signature, s1.Bytes()...)
rr.Signature = unpackBase64(signature)
case *rsa.PrivateKey:
// We can use nil as rand.Reader here (says AGL)
signature, err := rsa.SignPKCS1v15(nil, p, ch, sighash)
if err != nil {
return err
}
rr.Signature = unpackBase64(signature)
case *ecdsa.PrivateKey:
r1, s1, err := ecdsa.Sign(rand.Reader, p, sighash)
if err != nil {
return err
}
signature := r1.Bytes()
signature = append(signature, s1.Bytes()...)
rr.Signature = unpackBase64(signature)
default:
// Not given the correct key
return ErrKeyAlg
} }
sighash := h.Sum(nil)
signature, err := k.Sign(sighash, rr.Algorithm)
if err != nil {
return err
}
rr.Signature = toBase64(signature)
return nil return nil
} }
@ -395,7 +377,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
sighash := h.Sum(nil) sighash := h.Sum(nil)
return rsa.VerifyPKCS1v15(pubkey, ch, sighash, sigbuf) return rsa.VerifyPKCS1v15(pubkey, ch, sighash, sigbuf)
case ECDSAP256SHA256, ECDSAP384SHA384: case ECDSAP256SHA256, ECDSAP384SHA384:
pubkey := k.publicKeyCurve() pubkey := k.publicKeyECDSA()
if pubkey == nil { if pubkey == nil {
return ErrKey return ErrKey
} }
@ -441,41 +423,16 @@ func (rr *RRSIG) ValidityPeriod(t time.Time) bool {
// Return the signatures base64 encodedig sigdata as a byte slice. // Return the signatures base64 encodedig sigdata as a byte slice.
func (s *RRSIG) sigBuf() []byte { func (s *RRSIG) sigBuf() []byte {
sigbuf, err := packBase64([]byte(s.Signature)) sigbuf, err := fromBase64([]byte(s.Signature))
if err != nil { if err != nil {
return nil return nil
} }
return sigbuf return sigbuf
} }
// setPublicKeyInPrivate sets the public key in the private key.
func (k *DNSKEY) setPublicKeyInPrivate(p PrivateKey) bool {
switch t := p.(type) {
case *dsa.PrivateKey:
x := k.publicKeyDSA()
if x == nil {
return false
}
t.PublicKey = *x
case *rsa.PrivateKey:
x := k.publicKeyRSA()
if x == nil {
return false
}
t.PublicKey = *x
case *ecdsa.PrivateKey:
x := k.publicKeyCurve()
if x == nil {
return false
}
t.PublicKey = *x
}
return true
}
// publicKeyRSA returns the RSA public key from a DNSKEY record. // publicKeyRSA returns the RSA public key from a DNSKEY record.
func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey { func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
keybuf, err := packBase64([]byte(k.PublicKey)) keybuf, err := fromBase64([]byte(k.PublicKey))
if err != nil { if err != nil {
return nil return nil
} }
@ -511,9 +468,9 @@ func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
return pubkey return pubkey
} }
// publicKeyCurve returns the Curve public key from the DNSKEY record. // publicKeyECDSA returns the Curve public key from the DNSKEY record.
func (k *DNSKEY) publicKeyCurve() *ecdsa.PublicKey { func (k *DNSKEY) publicKeyECDSA() *ecdsa.PublicKey {
keybuf, err := packBase64([]byte(k.PublicKey)) keybuf, err := fromBase64([]byte(k.PublicKey))
if err != nil { if err != nil {
return nil return nil
} }
@ -540,97 +497,29 @@ func (k *DNSKEY) publicKeyCurve() *ecdsa.PublicKey {
} }
func (k *DNSKEY) publicKeyDSA() *dsa.PublicKey { func (k *DNSKEY) publicKeyDSA() *dsa.PublicKey {
keybuf, err := packBase64([]byte(k.PublicKey)) keybuf, err := fromBase64([]byte(k.PublicKey))
if err != nil { if err != nil {
return nil return nil
} }
if len(keybuf) < 22 { // TODO: check if len(keybuf) < 22 {
return nil return nil
} }
t := int(keybuf[0]) t, keybuf := int(keybuf[0]), keybuf[1:]
size := 64 + t*8 size := 64 + t*8
q, keybuf := keybuf[:20], keybuf[20:]
if len(keybuf) != 3*size {
return nil
}
p, keybuf := keybuf[:size], keybuf[size:]
g, y := keybuf[:size], keybuf[size:]
pubkey := new(dsa.PublicKey) pubkey := new(dsa.PublicKey)
pubkey.Parameters.Q = big.NewInt(0) pubkey.Parameters.Q = big.NewInt(0).SetBytes(q)
pubkey.Parameters.Q.SetBytes(keybuf[1:21]) // +/- 1 ? pubkey.Parameters.P = big.NewInt(0).SetBytes(p)
pubkey.Parameters.P = big.NewInt(0) pubkey.Parameters.G = big.NewInt(0).SetBytes(g)
pubkey.Parameters.P.SetBytes(keybuf[22 : 22+size]) pubkey.Y = big.NewInt(0).SetBytes(y)
pubkey.Parameters.G = big.NewInt(0)
pubkey.Parameters.G.SetBytes(keybuf[22+size+1 : 22+size*2])
pubkey.Y = big.NewInt(0)
pubkey.Y.SetBytes(keybuf[22+size*2+1 : 22+size*3])
return pubkey return pubkey
} }
// Set the public key (the value E and N)
func (k *DNSKEY) setPublicKeyRSA(_E int, _N *big.Int) bool {
if _E == 0 || _N == nil {
return false
}
buf := exponentToBuf(_E)
buf = append(buf, _N.Bytes()...)
k.PublicKey = unpackBase64(buf)
return true
}
// Set the public key for Elliptic Curves
func (k *DNSKEY) setPublicKeyCurve(_X, _Y *big.Int) bool {
if _X == nil || _Y == nil {
return false
}
buf := curveToBuf(_X, _Y)
// Check the length of the buffer, either 64 or 92 bytes
k.PublicKey = unpackBase64(buf)
return true
}
// Set the public key for DSA
func (k *DNSKEY) setPublicKeyDSA(_Q, _P, _G, _Y *big.Int) bool {
if _Q == nil || _P == nil || _G == nil || _Y == nil {
return false
}
buf := dsaToBuf(_Q, _P, _G, _Y)
k.PublicKey = unpackBase64(buf)
return true
}
// Set the public key (the values E and N) for RSA
// RFC 3110: Section 2. RSA Public KEY Resource Records
func exponentToBuf(_E int) []byte {
var buf []byte
i := big.NewInt(int64(_E))
if len(i.Bytes()) < 256 {
buf = make([]byte, 1)
buf[0] = uint8(len(i.Bytes()))
} else {
buf = make([]byte, 3)
buf[0] = 0
buf[1] = uint8(len(i.Bytes()) >> 8)
buf[2] = uint8(len(i.Bytes()))
}
buf = append(buf, i.Bytes()...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func curveToBuf(_X, _Y *big.Int) []byte {
buf := _X.Bytes()
buf = append(buf, _Y.Bytes()...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func dsaToBuf(_Q, _P, _G, _Y *big.Int) []byte {
t := byte((len(_G.Bytes()) - 64) / 8)
buf := []byte{t}
buf = append(buf, _Q.Bytes()...)
buf = append(buf, _P.Bytes()...)
buf = append(buf, _G.Bytes()...)
buf = append(buf, _Y.Bytes()...)
return buf
}
type wireSlice [][]byte type wireSlice [][]byte
func (p wireSlice) Len() int { return len(p) } func (p wireSlice) Len() int { return len(p) }

View File

@ -0,0 +1,155 @@
package dns
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"math/big"
)
// Generate generates a DNSKEY of the given bit size.
// The public part is put inside the DNSKEY record.
// The Algorithm in the key must be set as this will define
// what kind of DNSKEY will be generated.
// The ECDSA algorithms imply a fixed keysize, in that case
// bits should be set to the size of the algorithm.
func (r *DNSKEY) Generate(bits int) (PrivateKey, error) {
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
if bits != 1024 {
return nil, ErrKeySize
}
case RSAMD5, RSASHA1, RSASHA256, RSASHA1NSEC3SHA1:
if bits < 512 || bits > 4096 {
return nil, ErrKeySize
}
case RSASHA512:
if bits < 1024 || bits > 4096 {
return nil, ErrKeySize
}
case ECDSAP256SHA256:
if bits != 256 {
return nil, ErrKeySize
}
case ECDSAP384SHA384:
if bits != 384 {
return nil, ErrKeySize
}
}
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
params := new(dsa.Parameters)
if err := dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160); err != nil {
return nil, err
}
priv := new(dsa.PrivateKey)
priv.PublicKey.Parameters = *params
err := dsa.GenerateKey(priv, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyDSA(params.Q, params.P, params.G, priv.PublicKey.Y)
return (*DSAPrivateKey)(priv), nil
case RSAMD5, RSASHA1, RSASHA256, RSASHA512, RSASHA1NSEC3SHA1:
priv, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
r.setPublicKeyRSA(priv.PublicKey.E, priv.PublicKey.N)
return (*RSAPrivateKey)(priv), nil
case ECDSAP256SHA256, ECDSAP384SHA384:
var c elliptic.Curve
switch r.Algorithm {
case ECDSAP256SHA256:
c = elliptic.P256()
case ECDSAP384SHA384:
c = elliptic.P384()
}
priv, err := ecdsa.GenerateKey(c, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyECDSA(priv.PublicKey.X, priv.PublicKey.Y)
return (*ECDSAPrivateKey)(priv), nil
default:
return nil, ErrAlg
}
}
// Set the public key (the value E and N)
func (k *DNSKEY) setPublicKeyRSA(_E int, _N *big.Int) bool {
if _E == 0 || _N == nil {
return false
}
buf := exponentToBuf(_E)
buf = append(buf, _N.Bytes()...)
k.PublicKey = toBase64(buf)
return true
}
// Set the public key for Elliptic Curves
func (k *DNSKEY) setPublicKeyECDSA(_X, _Y *big.Int) bool {
if _X == nil || _Y == nil {
return false
}
var intlen int
switch k.Algorithm {
case ECDSAP256SHA256:
intlen = 32
case ECDSAP384SHA384:
intlen = 48
}
k.PublicKey = toBase64(curveToBuf(_X, _Y, intlen))
return true
}
// Set the public key for DSA
func (k *DNSKEY) setPublicKeyDSA(_Q, _P, _G, _Y *big.Int) bool {
if _Q == nil || _P == nil || _G == nil || _Y == nil {
return false
}
buf := dsaToBuf(_Q, _P, _G, _Y)
k.PublicKey = toBase64(buf)
return true
}
// Set the public key (the values E and N) for RSA
// RFC 3110: Section 2. RSA Public KEY Resource Records
func exponentToBuf(_E int) []byte {
var buf []byte
i := big.NewInt(int64(_E))
if len(i.Bytes()) < 256 {
buf = make([]byte, 1)
buf[0] = uint8(len(i.Bytes()))
} else {
buf = make([]byte, 3)
buf[0] = 0
buf[1] = uint8(len(i.Bytes()) >> 8)
buf[2] = uint8(len(i.Bytes()))
}
buf = append(buf, i.Bytes()...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func curveToBuf(_X, _Y *big.Int, intlen int) []byte {
buf := intToBytes(_X, intlen)
buf = append(buf, intToBytes(_Y, intlen)...)
return buf
}
// Set the public key for X and Y for Curve. The two
// values are just concatenated.
func dsaToBuf(_Q, _P, _G, _Y *big.Int) []byte {
t := divRoundUp(divRoundUp(_G.BitLen(), 8)-64, 8)
buf := []byte{byte(t)}
buf = append(buf, intToBytes(_Q, 20)...)
buf = append(buf, intToBytes(_P, 64+t*8)...)
buf = append(buf, intToBytes(_G, 64+t*8)...)
buf = append(buf, intToBytes(_Y, 64+t*8)...)
return buf
}

View File

@ -18,8 +18,8 @@ func (k *DNSKEY) NewPrivateKey(s string) (PrivateKey, error) {
// ReadPrivateKey reads a private key from the io.Reader q. The string file is // ReadPrivateKey reads a private key from the io.Reader q. The string file is
// only used in error reporting. // only used in error reporting.
// The public key must be // The public key must be known, because some cryptographic algorithms embed
// known, because some cryptographic algorithms embed the public inside the privatekey. // the public inside the privatekey.
func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (PrivateKey, error) { func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (PrivateKey, error) {
m, e := parseKey(q, file) m, e := parseKey(q, file)
if m == nil { if m == nil {
@ -34,14 +34,16 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (PrivateKey, error) {
// TODO(mg): check if the pubkey matches the private key // TODO(mg): check if the pubkey matches the private key
switch m["algorithm"] { switch m["algorithm"] {
case "3 (DSA)": case "3 (DSA)":
p, e := readPrivateKeyDSA(m) priv, e := readPrivateKeyDSA(m)
if e != nil { if e != nil {
return nil, e return nil, e
} }
if !k.setPublicKeyInPrivate(p) { pub := k.publicKeyDSA()
return nil, ErrPrivKey if pub == nil {
return nil, ErrKey
} }
return p, e priv.PublicKey = *pub
return (*DSAPrivateKey)(priv), e
case "1 (RSAMD5)": case "1 (RSAMD5)":
fallthrough fallthrough
case "5 (RSASHA1)": case "5 (RSASHA1)":
@ -51,44 +53,44 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (PrivateKey, error) {
case "8 (RSASHA256)": case "8 (RSASHA256)":
fallthrough fallthrough
case "10 (RSASHA512)": case "10 (RSASHA512)":
p, e := readPrivateKeyRSA(m) priv, e := readPrivateKeyRSA(m)
if e != nil { if e != nil {
return nil, e return nil, e
} }
if !k.setPublicKeyInPrivate(p) { pub := k.publicKeyRSA()
return nil, ErrPrivKey if pub == nil {
return nil, ErrKey
} }
return p, e priv.PublicKey = *pub
return (*RSAPrivateKey)(priv), e
case "12 (ECC-GOST)": case "12 (ECC-GOST)":
p, e := readPrivateKeyGOST(m) return nil, ErrPrivKey
if e != nil {
return nil, e
}
// setPublicKeyInPrivate(p)
return p, e
case "13 (ECDSAP256SHA256)": case "13 (ECDSAP256SHA256)":
fallthrough fallthrough
case "14 (ECDSAP384SHA384)": case "14 (ECDSAP384SHA384)":
p, e := readPrivateKeyECDSA(m) priv, e := readPrivateKeyECDSA(m)
if e != nil { if e != nil {
return nil, e return nil, e
} }
if !k.setPublicKeyInPrivate(p) { pub := k.publicKeyECDSA()
return nil, ErrPrivKey if pub == nil {
return nil, ErrKey
} }
return p, e priv.PublicKey = *pub
return (*ECDSAPrivateKey)(priv), e
default:
return nil, ErrPrivKey
} }
return nil, ErrPrivKey
} }
// Read a private key (file) string and create a public key. Return the private key. // Read a private key (file) string and create a public key. Return the private key.
func readPrivateKeyRSA(m map[string]string) (PrivateKey, error) { func readPrivateKeyRSA(m map[string]string) (*rsa.PrivateKey, error) {
p := new(rsa.PrivateKey) p := new(rsa.PrivateKey)
p.Primes = []*big.Int{nil, nil} p.Primes = []*big.Int{nil, nil}
for k, v := range m { for k, v := range m {
switch k { switch k {
case "modulus", "publicexponent", "privateexponent", "prime1", "prime2": case "modulus", "publicexponent", "privateexponent", "prime1", "prime2":
v1, err := packBase64([]byte(v)) v1, err := fromBase64([]byte(v))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -119,13 +121,13 @@ func readPrivateKeyRSA(m map[string]string) (PrivateKey, error) {
return p, nil return p, nil
} }
func readPrivateKeyDSA(m map[string]string) (PrivateKey, error) { func readPrivateKeyDSA(m map[string]string) (*dsa.PrivateKey, error) {
p := new(dsa.PrivateKey) p := new(dsa.PrivateKey)
p.X = big.NewInt(0) p.X = big.NewInt(0)
for k, v := range m { for k, v := range m {
switch k { switch k {
case "private_value(x)": case "private_value(x)":
v1, err := packBase64([]byte(v)) v1, err := fromBase64([]byte(v))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -137,14 +139,14 @@ func readPrivateKeyDSA(m map[string]string) (PrivateKey, error) {
return p, nil return p, nil
} }
func readPrivateKeyECDSA(m map[string]string) (PrivateKey, error) { func readPrivateKeyECDSA(m map[string]string) (*ecdsa.PrivateKey, error) {
p := new(ecdsa.PrivateKey) p := new(ecdsa.PrivateKey)
p.D = big.NewInt(0) p.D = big.NewInt(0)
// TODO: validate that the required flags are present // TODO: validate that the required flags are present
for k, v := range m { for k, v := range m {
switch k { switch k {
case "privatekey": case "privatekey":
v1, err := packBase64([]byte(v)) v1, err := fromBase64([]byte(v))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -156,11 +158,6 @@ func readPrivateKeyECDSA(m map[string]string) (PrivateKey, error) {
return p, nil return p, nil
} }
func readPrivateKeyGOST(m map[string]string) (PrivateKey, error) {
// TODO(miek)
return nil, nil
}
// parseKey reads a private key from r. It returns a map[string]string, // parseKey reads a private key from r. It returns a map[string]string,
// with the key-value pairs, or an error when the file is not correct. // with the key-value pairs, or an error when the file is not correct.
func parseKey(r io.Reader, file string) (map[string]string, error) { func parseKey(r io.Reader, file string) (map[string]string, error) {

View File

@ -0,0 +1,143 @@
package dns
import (
"crypto"
"crypto/dsa"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"math/big"
"strconv"
)
const _FORMAT = "Private-key-format: v1.3\n"
type PrivateKey interface {
Sign([]byte, uint8) ([]byte, error)
String(uint8) string
}
// PrivateKeyString converts a PrivateKey to a string. This string has the same
// format as the private-key-file of BIND9 (Private-key-format: v1.3).
// It needs some info from the key (the algorithm), so its a method of the
// DNSKEY and calls PrivateKey.String(alg).
func (r *DNSKEY) PrivateKeyString(p PrivateKey) string {
return p.String(r.Algorithm)
}
type RSAPrivateKey rsa.PrivateKey
func (p *RSAPrivateKey) Sign(hashed []byte, alg uint8) ([]byte, error) {
var hash crypto.Hash
switch alg {
case RSASHA1, RSASHA1NSEC3SHA1:
hash = crypto.SHA1
case RSASHA256:
hash = crypto.SHA256
case RSASHA512:
hash = crypto.SHA512
default:
return nil, ErrAlg
}
return rsa.SignPKCS1v15(nil, (*rsa.PrivateKey)(p), hash, hashed)
}
func (p *RSAPrivateKey) String(alg uint8) string {
algorithm := strconv.Itoa(int(alg)) + " (" + AlgorithmToString[alg] + ")"
modulus := toBase64(p.PublicKey.N.Bytes())
e := big.NewInt(int64(p.PublicKey.E))
publicExponent := toBase64(e.Bytes())
privateExponent := toBase64(p.D.Bytes())
prime1 := toBase64(p.Primes[0].Bytes())
prime2 := toBase64(p.Primes[1].Bytes())
// Calculate Exponent1/2 and Coefficient as per: http://en.wikipedia.org/wiki/RSA#Using_the_Chinese_remainder_algorithm
// and from: http://code.google.com/p/go/issues/detail?id=987
one := big.NewInt(1)
p_1 := big.NewInt(0).Sub(p.Primes[0], one)
q_1 := big.NewInt(0).Sub(p.Primes[1], one)
exp1 := big.NewInt(0).Mod(p.D, p_1)
exp2 := big.NewInt(0).Mod(p.D, q_1)
coeff := big.NewInt(0).ModInverse(p.Primes[1], p.Primes[0])
exponent1 := toBase64(exp1.Bytes())
exponent2 := toBase64(exp2.Bytes())
coefficient := toBase64(coeff.Bytes())
return _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Modulus: " + modulus + "\n" +
"PublicExponent: " + publicExponent + "\n" +
"PrivateExponent: " + privateExponent + "\n" +
"Prime1: " + prime1 + "\n" +
"Prime2: " + prime2 + "\n" +
"Exponent1: " + exponent1 + "\n" +
"Exponent2: " + exponent2 + "\n" +
"Coefficient: " + coefficient + "\n"
}
type ECDSAPrivateKey ecdsa.PrivateKey
func (p *ECDSAPrivateKey) Sign(hashed []byte, alg uint8) ([]byte, error) {
var intlen int
switch alg {
case ECDSAP256SHA256:
intlen = 32
case ECDSAP384SHA384:
intlen = 48
default:
return nil, ErrAlg
}
r1, s1, err := ecdsa.Sign(rand.Reader, (*ecdsa.PrivateKey)(p), hashed)
if err != nil {
return nil, err
}
signature := intToBytes(r1, intlen)
signature = append(signature, intToBytes(s1, intlen)...)
return signature, nil
}
func (p *ECDSAPrivateKey) String(alg uint8) string {
algorithm := strconv.Itoa(int(alg)) + " (" + AlgorithmToString[alg] + ")"
var intlen int
switch alg {
case ECDSAP256SHA256:
intlen = 32
case ECDSAP384SHA384:
intlen = 48
}
private := toBase64(intToBytes(p.D, intlen))
return _FORMAT +
"Algorithm: " + algorithm + "\n" +
"PrivateKey: " + private + "\n"
}
type DSAPrivateKey dsa.PrivateKey
func (p *DSAPrivateKey) Sign(hashed []byte, alg uint8) ([]byte, error) {
r1, s1, err := dsa.Sign(rand.Reader, (*dsa.PrivateKey)(p), hashed)
if err != nil {
return nil, err
}
t := divRoundUp(divRoundUp(p.PublicKey.Y.BitLen(), 8)-64, 8)
signature := []byte{byte(t)}
signature = append(signature, intToBytes(r1, 20)...)
signature = append(signature, intToBytes(s1, 20)...)
return signature, nil
}
func (p *DSAPrivateKey) String(alg uint8) string {
algorithm := strconv.Itoa(int(alg)) + " (" + AlgorithmToString[alg] + ")"
T := divRoundUp(divRoundUp(p.PublicKey.Parameters.G.BitLen(), 8)-64, 8)
prime := toBase64(intToBytes(p.PublicKey.Parameters.P, 64+T*8))
subprime := toBase64(intToBytes(p.PublicKey.Parameters.Q, 20))
base := toBase64(intToBytes(p.PublicKey.Parameters.G, 64+T*8))
priv := toBase64(intToBytes(p.X, 20))
pub := toBase64(intToBytes(p.PublicKey.Y, 64+T*8))
return _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Prime(p): " + prime + "\n" +
"Subprime(q): " + subprime + "\n" +
"Base(g): " + base + "\n" +
"Private_value(x): " + priv + "\n" +
"Public_value(y): " + pub + "\n"
}

View File

@ -1,7 +1,6 @@
package dns package dns
import ( import (
"crypto/rsa"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -34,6 +33,9 @@ func getSoa() *SOA {
} }
func TestGenerateEC(t *testing.T) { func TestGenerateEC(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
key := new(DNSKEY) key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl." key.Hdr.Name = "miek.nl."
@ -48,6 +50,9 @@ func TestGenerateEC(t *testing.T) {
} }
func TestGenerateDSA(t *testing.T) { func TestGenerateDSA(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
key := new(DNSKEY) key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl." key.Hdr.Name = "miek.nl."
@ -62,6 +67,9 @@ func TestGenerateDSA(t *testing.T) {
} }
func TestGenerateRSA(t *testing.T) { func TestGenerateRSA(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
key := new(DNSKEY) key := new(DNSKEY)
key.Hdr.Rrtype = TypeDNSKEY key.Hdr.Rrtype = TypeDNSKEY
key.Hdr.Name = "miek.nl." key.Hdr.Name = "miek.nl."
@ -240,12 +248,13 @@ func Test65534(t *testing.T) {
} }
func TestDnskey(t *testing.T) { func TestDnskey(t *testing.T) {
// f, _ := os.Open("t/Kmiek.nl.+010+05240.key") pubkey, err := ReadRR(strings.NewReader(`
pubkey, _ := ReadRR(strings.NewReader(`
miek.nl. IN DNSKEY 256 3 10 AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL ;{id = 5240 (zsk), size = 1024b} miek.nl. IN DNSKEY 256 3 10 AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL ;{id = 5240 (zsk), size = 1024b}
`), "Kmiek.nl.+010+05240.key") `), "Kmiek.nl.+010+05240.key")
privkey, _ := pubkey.(*DNSKEY).ReadPrivateKey(strings.NewReader(` if err != nil {
Private-key-format: v1.2 t.Fatal(err)
}
privStr := `Private-key-format: v1.3
Algorithm: 10 (RSASHA512) Algorithm: 10 (RSASHA512)
Modulus: m4wK7YV26AeROtdiCXmqLG9wPDVoMOW8vjr/EkpscEAdjXp81RvZvrlzCSjYmz9onFRgltmTl3AINnFh+t9tlW0M9C5zejxBoKFXELv8ljPYAdz2oe+pDWPhWsfvVFYg2VCjpViPM38EakyE5mhk4TDOnUd+w4TeU1hyhZTWyYs= Modulus: m4wK7YV26AeROtdiCXmqLG9wPDVoMOW8vjr/EkpscEAdjXp81RvZvrlzCSjYmz9onFRgltmTl3AINnFh+t9tlW0M9C5zejxBoKFXELv8ljPYAdz2oe+pDWPhWsfvVFYg2VCjpViPM38EakyE5mhk4TDOnUd+w4TeU1hyhZTWyYs=
PublicExponent: AQAB PublicExponent: AQAB
@ -255,13 +264,21 @@ Prime2: xA1bF8M0RTIQ6+A11AoVG6GIR/aPGg5sogRkIZ7ID/sF6g9HMVU/CM2TqVEBJLRPp73cv6Ze
Exponent1: xzkblyZ96bGYxTVZm2/vHMOXswod4KWIyMoOepK6B/ZPcZoIT6omLCgtypWtwHLfqyCz3MK51Nc0G2EGzg8rFQ== Exponent1: xzkblyZ96bGYxTVZm2/vHMOXswod4KWIyMoOepK6B/ZPcZoIT6omLCgtypWtwHLfqyCz3MK51Nc0G2EGzg8rFQ==
Exponent2: Pu5+mCEb7T5F+kFNZhQadHUklt0JUHbi3hsEvVoHpEGSw3BGDQrtIflDde0/rbWHgDPM4WQY+hscd8UuTXrvLw== Exponent2: Pu5+mCEb7T5F+kFNZhQadHUklt0JUHbi3hsEvVoHpEGSw3BGDQrtIflDde0/rbWHgDPM4WQY+hscd8UuTXrvLw==
Coefficient: UuRoNqe7YHnKmQzE6iDWKTMIWTuoqqrFAmXPmKQnC+Y+BQzOVEHUo9bXdDnoI9hzXP1gf8zENMYwYLeWpuYlFQ== Coefficient: UuRoNqe7YHnKmQzE6iDWKTMIWTuoqqrFAmXPmKQnC+Y+BQzOVEHUo9bXdDnoI9hzXP1gf8zENMYwYLeWpuYlFQ==
`), "Kmiek.nl.+010+05240.private") `
privkey, err := pubkey.(*DNSKEY).ReadPrivateKey(strings.NewReader(privStr),
"Kmiek.nl.+010+05240.private")
if err != nil {
t.Fatal(err)
}
if pubkey.(*DNSKEY).PublicKey != "AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL" { if pubkey.(*DNSKEY).PublicKey != "AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL" {
t.Log("pubkey is not what we've read") t.Log("pubkey is not what we've read")
t.Fail() t.Fail()
} }
// Coefficient looks fishy... if pubkey.(*DNSKEY).PrivateKeyString(privkey) != privStr {
t.Logf("%s", pubkey.(*DNSKEY).PrivateKeyString(privkey)) t.Log("privkey is not what we've read")
t.Logf("%v", pubkey.(*DNSKEY).PrivateKeyString(privkey))
t.Fail()
}
} }
func TestTag(t *testing.T) { func TestTag(t *testing.T) {
@ -283,6 +300,9 @@ func TestTag(t *testing.T) {
} }
func TestKeyRSA(t *testing.T) { func TestKeyRSA(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
key := new(DNSKEY) key := new(DNSKEY)
key.Hdr.Name = "miek.nl." key.Hdr.Name = "miek.nl."
key.Hdr.Rrtype = TypeDNSKEY key.Hdr.Rrtype = TypeDNSKEY
@ -368,7 +388,7 @@ Activate: 20110302104537`
t.Fail() t.Fail()
} }
switch priv := p.(type) { switch priv := p.(type) {
case *rsa.PrivateKey: case *RSAPrivateKey:
if 65537 != priv.PublicKey.E { if 65537 != priv.PublicKey.E {
t.Log("exponenent should be 65537") t.Log("exponenent should be 65537")
t.Fail() t.Fail()
@ -443,12 +463,19 @@ PrivateKey: WURgWHCcYIYUPWgeLmiPY2DJJk02vgrmTfitxgqcL4vwW7BOrbawVmVe0d9V94SR`
sig.SignerName = eckey.(*DNSKEY).Hdr.Name sig.SignerName = eckey.(*DNSKEY).Hdr.Name
sig.Algorithm = eckey.(*DNSKEY).Algorithm sig.Algorithm = eckey.(*DNSKEY).Algorithm
sig.Sign(privkey, []RR{a}) if sig.Sign(privkey, []RR{a}) != nil {
t.Fatal("failure to sign the record")
}
t.Logf("%s", sig.String())
if e := sig.Verify(eckey.(*DNSKEY), []RR{a}); e != nil { if e := sig.Verify(eckey.(*DNSKEY), []RR{a}); e != nil {
t.Logf("failure to validate: %s", e.Error()) t.Logf("\n%s\n%s\n%s\n\n%s\n\n",
t.Fail() eckey.(*DNSKEY).String(),
a.String(),
sig.String(),
eckey.(*DNSKEY).PrivateKeyString(privkey),
)
t.Fatalf("failure to validate: %s", e.Error())
} }
} }
@ -491,6 +518,13 @@ func TestSignVerifyECDSA2(t *testing.T) {
err = sig.Verify(key, []RR{srv}) err = sig.Verify(key, []RR{srv})
if err != nil { if err != nil {
t.Logf("\n%s\n%s\n%s\n\n%s\n\n",
key.String(),
srv.String(),
sig.String(),
key.PrivateKeyString(privkey),
)
t.Fatal("Failure to validate:", err) t.Fatal("Failure to validate:", err)
} }
} }

View File

@ -287,7 +287,7 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error {
case 1: case 1:
addr := make([]byte, 4) addr := make([]byte, 4)
for i := 0; i < int(e.SourceNetmask/8); i++ { for i := 0; i < int(e.SourceNetmask/8); i++ {
if 4+i > len(b) { if i >= len(addr) || 4+i >= len(b) {
return ErrBuf return ErrBuf
} }
addr[i] = b[4+i] addr[i] = b[4+i]
@ -296,7 +296,7 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error {
case 2: case 2:
addr := make([]byte, 16) addr := make([]byte, 16)
for i := 0; i < int(e.SourceNetmask/8); i++ { for i := 0; i < int(e.SourceNetmask/8); i++ {
if 4+i > len(b) { if i >= len(addr) || 4+i >= len(b) {
return ErrBuf return ErrBuf
} }
addr[i] = b[4+i] addr[i] = b[4+i]

View File

@ -49,7 +49,7 @@ func ExampleDS(zone string) {
} }
for _, k := range r.Answer { for _, k := range r.Answer {
if key, ok := k.(*dns.DNSKEY); ok { if key, ok := k.(*dns.DNSKEY); ok {
for _, alg := range []int{dns.SHA1, dns.SHA256, dns.SHA384} { for _, alg := range []uint8{dns.SHA1, dns.SHA256, dns.SHA384} {
fmt.Printf("%s; %d\n", key.ToDS(alg).String(), key.Flags) fmt.Printf("%s; %d\n", key.ToDS(alg).String(), key.Flags)
} }
} }

96
Godeps/_workspace/src/github.com/miekg/dns/format.go generated vendored Normal file
View File

@ -0,0 +1,96 @@
package dns
import (
"net"
"reflect"
"strconv"
)
// NumField returns the number of rdata fields r has.
func NumField(r RR) int {
return reflect.ValueOf(r).Elem().NumField() - 1 // Remove RR_Header
}
// Field returns the rdata field i as a string. Fields are indexed starting from 1.
// RR types that holds slice data, for instance the NSEC type bitmap will return a single
// string where the types are concatenated using a space.
// Accessing non existing fields will cause a panic.
func Field(r RR, i int) string {
if i == 0 {
return ""
}
d := reflect.ValueOf(r).Elem().Field(i)
switch k := d.Kind(); k {
case reflect.String:
return d.String()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(d.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(d.Uint(), 10)
case reflect.Slice:
switch reflect.ValueOf(r).Elem().Type().Field(i).Tag {
case `dns:"a"`:
// TODO(miek): Hmm store this as 16 bytes
if d.Len() < net.IPv6len {
return net.IPv4(byte(d.Index(0).Uint()),
byte(d.Index(1).Uint()),
byte(d.Index(2).Uint()),
byte(d.Index(3).Uint())).String()
}
return net.IPv4(byte(d.Index(12).Uint()),
byte(d.Index(13).Uint()),
byte(d.Index(14).Uint()),
byte(d.Index(15).Uint())).String()
case `dns:"aaaa"`:
return net.IP{
byte(d.Index(0).Uint()),
byte(d.Index(1).Uint()),
byte(d.Index(2).Uint()),
byte(d.Index(3).Uint()),
byte(d.Index(4).Uint()),
byte(d.Index(5).Uint()),
byte(d.Index(6).Uint()),
byte(d.Index(7).Uint()),
byte(d.Index(8).Uint()),
byte(d.Index(9).Uint()),
byte(d.Index(10).Uint()),
byte(d.Index(11).Uint()),
byte(d.Index(12).Uint()),
byte(d.Index(13).Uint()),
byte(d.Index(14).Uint()),
byte(d.Index(15).Uint()),
}.String()
case `dns:"nsec"`:
if d.Len() == 0 {
return ""
}
s := Type(d.Index(0).Uint()).String()
for i := 1; i < d.Len(); i++ {
s += " " + Type(d.Index(i).Uint()).String()
}
return s
case `dns:"wks"`:
if d.Len() == 0 {
return ""
}
s := strconv.Itoa(int(d.Index(0).Uint()))
for i := 0; i < d.Len(); i++ {
s += " " + strconv.Itoa(int(d.Index(i).Uint()))
}
return s
default:
// if it does not have a tag its a string slice
fallthrough
case `dns:"txt"`:
if d.Len() == 0 {
return ""
}
s := d.Index(0).String()
for i := 1; i < d.Len(); i++ {
s += " " + d.Index(i).String()
}
return s
}
}
return ""
}

View File

@ -1,149 +0,0 @@
package dns
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"math/big"
"strconv"
)
const _FORMAT = "Private-key-format: v1.3\n"
// Empty interface that is used as a wrapper around all possible
// private key implementations from the crypto package.
type PrivateKey interface{}
// Generate generates a DNSKEY of the given bit size.
// The public part is put inside the DNSKEY record.
// The Algorithm in the key must be set as this will define
// what kind of DNSKEY will be generated.
// The ECDSA algorithms imply a fixed keysize, in that case
// bits should be set to the size of the algorithm.
func (r *DNSKEY) Generate(bits int) (PrivateKey, error) {
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
if bits != 1024 {
return nil, ErrKeySize
}
case RSAMD5, RSASHA1, RSASHA256, RSASHA1NSEC3SHA1:
if bits < 512 || bits > 4096 {
return nil, ErrKeySize
}
case RSASHA512:
if bits < 1024 || bits > 4096 {
return nil, ErrKeySize
}
case ECDSAP256SHA256:
if bits != 256 {
return nil, ErrKeySize
}
case ECDSAP384SHA384:
if bits != 384 {
return nil, ErrKeySize
}
}
switch r.Algorithm {
case DSA, DSANSEC3SHA1:
params := new(dsa.Parameters)
if err := dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160); err != nil {
return nil, err
}
priv := new(dsa.PrivateKey)
priv.PublicKey.Parameters = *params
err := dsa.GenerateKey(priv, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyDSA(params.Q, params.P, params.G, priv.PublicKey.Y)
return priv, nil
case RSAMD5, RSASHA1, RSASHA256, RSASHA512, RSASHA1NSEC3SHA1:
priv, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
r.setPublicKeyRSA(priv.PublicKey.E, priv.PublicKey.N)
return priv, nil
case ECDSAP256SHA256, ECDSAP384SHA384:
var c elliptic.Curve
switch r.Algorithm {
case ECDSAP256SHA256:
c = elliptic.P256()
case ECDSAP384SHA384:
c = elliptic.P384()
}
priv, err := ecdsa.GenerateKey(c, rand.Reader)
if err != nil {
return nil, err
}
r.setPublicKeyCurve(priv.PublicKey.X, priv.PublicKey.Y)
return priv, nil
default:
return nil, ErrAlg
}
return nil, nil // Dummy return
}
// PrivateKeyString converts a PrivateKey to a string. This
// string has the same format as the private-key-file of BIND9 (Private-key-format: v1.3).
// It needs some info from the key (hashing, keytag), so its a method of the DNSKEY.
func (r *DNSKEY) PrivateKeyString(p PrivateKey) (s string) {
switch t := p.(type) {
case *rsa.PrivateKey:
algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
modulus := unpackBase64(t.PublicKey.N.Bytes())
e := big.NewInt(int64(t.PublicKey.E))
publicExponent := unpackBase64(e.Bytes())
privateExponent := unpackBase64(t.D.Bytes())
prime1 := unpackBase64(t.Primes[0].Bytes())
prime2 := unpackBase64(t.Primes[1].Bytes())
// Calculate Exponent1/2 and Coefficient as per: http://en.wikipedia.org/wiki/RSA#Using_the_Chinese_remainder_algorithm
// and from: http://code.google.com/p/go/issues/detail?id=987
one := big.NewInt(1)
minusone := big.NewInt(-1)
p_1 := big.NewInt(0).Sub(t.Primes[0], one)
q_1 := big.NewInt(0).Sub(t.Primes[1], one)
exp1 := big.NewInt(0).Mod(t.D, p_1)
exp2 := big.NewInt(0).Mod(t.D, q_1)
coeff := big.NewInt(0).Exp(t.Primes[1], minusone, t.Primes[0])
exponent1 := unpackBase64(exp1.Bytes())
exponent2 := unpackBase64(exp2.Bytes())
coefficient := unpackBase64(coeff.Bytes())
s = _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Modules: " + modulus + "\n" +
"PublicExponent: " + publicExponent + "\n" +
"PrivateExponent: " + privateExponent + "\n" +
"Prime1: " + prime1 + "\n" +
"Prime2: " + prime2 + "\n" +
"Exponent1: " + exponent1 + "\n" +
"Exponent2: " + exponent2 + "\n" +
"Coefficient: " + coefficient + "\n"
case *ecdsa.PrivateKey:
algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
private := unpackBase64(t.D.Bytes())
s = _FORMAT +
"Algorithm: " + algorithm + "\n" +
"PrivateKey: " + private + "\n"
case *dsa.PrivateKey:
algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
prime := unpackBase64(t.PublicKey.Parameters.P.Bytes())
subprime := unpackBase64(t.PublicKey.Parameters.Q.Bytes())
base := unpackBase64(t.PublicKey.Parameters.G.Bytes())
priv := unpackBase64(t.X.Bytes())
pub := unpackBase64(t.PublicKey.Y.Bytes())
s = _FORMAT +
"Algorithm: " + algorithm + "\n" +
"Prime(p): " + prime + "\n" +
"Subprime(q): " + subprime + "\n" +
"Base(g): " + base + "\n" +
"Private_value(x): " + priv + "\n" +
"Public_value(y): " + pub + "\n"
}
return
}

View File

@ -12,7 +12,7 @@ import (
"encoding/base32" "encoding/base32"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "math/big"
"math/rand" "math/rand"
"net" "net"
"reflect" "reflect"
@ -91,6 +91,7 @@ var TypeToString = map[uint16]string{
TypeATMA: "ATMA", TypeATMA: "ATMA",
TypeAXFR: "AXFR", // Meta RR TypeAXFR: "AXFR", // Meta RR
TypeCAA: "CAA", TypeCAA: "CAA",
TypeCDNSKEY: "CDNSKEY",
TypeCDS: "CDS", TypeCDS: "CDS",
TypeCERT: "CERT", TypeCERT: "CERT",
TypeCNAME: "CNAME", TypeCNAME: "CNAME",
@ -109,6 +110,7 @@ var TypeToString = map[uint16]string{
TypeIPSECKEY: "IPSECKEY", TypeIPSECKEY: "IPSECKEY",
TypeISDN: "ISDN", TypeISDN: "ISDN",
TypeIXFR: "IXFR", // Meta RR TypeIXFR: "IXFR", // Meta RR
TypeKEY: "KEY",
TypeKX: "KX", TypeKX: "KX",
TypeL32: "L32", TypeL32: "L32",
TypeL64: "L64", TypeL64: "L64",
@ -139,6 +141,7 @@ var TypeToString = map[uint16]string{
TypeRP: "RP", TypeRP: "RP",
TypeRRSIG: "RRSIG", TypeRRSIG: "RRSIG",
TypeRT: "RT", TypeRT: "RT",
TypeSIG: "SIG",
TypeSOA: "SOA", TypeSOA: "SOA",
TypeSPF: "SPF", TypeSPF: "SPF",
TypeSRV: "SRV", TypeSRV: "SRV",
@ -419,7 +422,15 @@ Loop:
s = append(s, '\\', 'r') s = append(s, '\\', 'r')
default: default:
if b < 32 || b >= 127 { // unprintable use \DDD if b < 32 || b >= 127 { // unprintable use \DDD
s = append(s, fmt.Sprintf("\\%03d", b)...) var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
} else { } else {
s = append(s, b) s = append(s, b)
} }
@ -554,7 +565,15 @@ func unpackTxtString(msg []byte, offset int) (string, int, error) {
s = append(s, `\n`...) s = append(s, `\n`...)
default: default:
if b < 32 || b > 127 { // unprintable if b < 32 || b > 127 { // unprintable
s = append(s, fmt.Sprintf("\\%03d", b)...) var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
} else { } else {
s = append(s, b) s = append(s, b)
} }
@ -633,6 +652,12 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += len(b) off += len(b)
} }
case `dns:"a"`: case `dns:"a"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 1
if val.Field(2).Uint() != 1 {
continue
}
}
// It must be a slice of 4, even if it is 16, we encode // It must be a slice of 4, even if it is 16, we encode
// only the first 4 // only the first 4
if off+net.IPv4len > lenmsg { if off+net.IPv4len > lenmsg {
@ -657,6 +682,12 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
return lenmsg, &Error{err: "overflow packing a"} return lenmsg, &Error{err: "overflow packing a"}
} }
case `dns:"aaaa"`: case `dns:"aaaa"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 2
if val.Field(2).Uint() != 2 {
continue
}
}
if fv.Len() == 0 { if fv.Len() == 0 {
break break
} }
@ -668,6 +699,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off++ off++
} }
case `dns:"wks"`: case `dns:"wks"`:
// TODO(miek): this is wrong should be lenrd
if off == lenmsg { if off == lenmsg {
break // dyn. updates break // dyn. updates
} }
@ -794,13 +826,20 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
default: default:
return lenmsg, &Error{"bad tag packing string: " + typefield.Tag.Get("dns")} return lenmsg, &Error{"bad tag packing string: " + typefield.Tag.Get("dns")}
case `dns:"base64"`: case `dns:"base64"`:
b64, e := packBase64([]byte(s)) b64, e := fromBase64([]byte(s))
if e != nil { if e != nil {
return lenmsg, e return lenmsg, e
} }
copy(msg[off:off+len(b64)], b64) copy(msg[off:off+len(b64)], b64)
off += len(b64) off += len(b64)
case `dns:"domain-name"`: case `dns:"domain-name"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, 1 and 2 or used for addresses
x := val.Field(2).Uint()
if x == 1 || x == 2 {
continue
}
}
if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil { if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil {
return lenmsg, err return lenmsg, err
} }
@ -815,7 +854,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
msg[off-1] = 20 msg[off-1] = 20
fallthrough fallthrough
case `dns:"base32"`: case `dns:"base32"`:
b32, e := packBase32([]byte(s)) b32, e := fromBase32([]byte(s))
if e != nil { if e != nil {
return lenmsg, e return lenmsg, e
} }
@ -875,9 +914,12 @@ func packStructCompress(any interface{}, msg []byte, off int, compression map[st
// Unpack a reflect.StructValue from msg. // Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue. // Same restrictions as packStructValue.
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) { func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
var rdend int var lenrd int
lenmsg := len(msg) lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ { for i := 0; i < val.NumField(); i++ {
if lenrd != 0 && lenrd == off {
break
}
if off > lenmsg { if off > lenmsg {
return lenmsg, &Error{"bad offset unpacking"} return lenmsg, &Error{"bad offset unpacking"}
} }
@ -889,7 +931,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// therefore it's expected that this interface would be PrivateRdata // therefore it's expected that this interface would be PrivateRdata
switch data := fv.Interface().(type) { switch data := fv.Interface().(type) {
case PrivateRdata: case PrivateRdata:
n, err := data.Unpack(msg[off:rdend]) n, err := data.Unpack(msg[off:lenrd])
if err != nil { if err != nil {
return lenmsg, err return lenmsg, err
} }
@ -905,7 +947,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// HIP record slice of name (or none) // HIP record slice of name (or none)
servers := make([]string, 0) servers := make([]string, 0)
var s string var s string
for off < rdend { for off < lenrd {
s, off, err = UnpackDomainName(msg, off) s, off, err = UnpackDomainName(msg, off)
if err != nil { if err != nil {
return lenmsg, err return lenmsg, err
@ -914,17 +956,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
fv.Set(reflect.ValueOf(servers)) fv.Set(reflect.ValueOf(servers))
case `dns:"txt"`: case `dns:"txt"`:
if off == lenmsg || rdend == off { if off == lenmsg || lenrd == off {
break break
} }
var txt []string var txt []string
txt, off, err = unpackTxt(msg, off, rdend) txt, off, err = unpackTxt(msg, off, lenrd)
if err != nil { if err != nil {
return lenmsg, err return lenmsg, err
} }
fv.Set(reflect.ValueOf(txt)) fv.Set(reflect.ValueOf(txt))
case `dns:"opt"`: // edns0 case `dns:"opt"`: // edns0
if off == rdend { if off == lenrd {
// This is an EDNS0 (OPT Record) with no rdata // This is an EDNS0 (OPT Record) with no rdata
// We can safely return here. // We can safely return here.
break break
@ -937,7 +979,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
code, off = unpackUint16(msg, off) code, off = unpackUint16(msg, off)
optlen, off1 := unpackUint16(msg, off) optlen, off1 := unpackUint16(msg, off)
if off1+int(optlen) > rdend { if off1+int(optlen) > lenrd {
return lenmsg, &Error{err: "overflow unpacking opt"} return lenmsg, &Error{err: "overflow unpacking opt"}
} }
switch code { switch code {
@ -997,24 +1039,36 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// do nothing? // do nothing?
off = off1 + int(optlen) off = off1 + int(optlen)
} }
if off < rdend { if off < lenrd {
goto Option goto Option
} }
fv.Set(reflect.ValueOf(edns)) fv.Set(reflect.ValueOf(edns))
case `dns:"a"`: case `dns:"a"`:
if off == lenmsg { if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 1
if val.Field(2).Uint() != 1 {
continue
}
}
if off == lenrd {
break // dyn. update break // dyn. update
} }
if off+net.IPv4len > rdend { if off+net.IPv4len > lenrd || off+net.IPv4len > lenmsg {
return lenmsg, &Error{err: "overflow unpacking a"} return lenmsg, &Error{err: "overflow unpacking a"}
} }
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]))) fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
off += net.IPv4len off += net.IPv4len
case `dns:"aaaa"`: case `dns:"aaaa"`:
if off == rdend { if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 2
if val.Field(2).Uint() != 2 {
continue
}
}
if off == lenrd {
break break
} }
if off+net.IPv6len > rdend || off+net.IPv6len > lenmsg { if off+net.IPv6len > lenrd || off+net.IPv6len > lenmsg {
return lenmsg, &Error{err: "overflow unpacking aaaa"} return lenmsg, &Error{err: "overflow unpacking aaaa"}
} }
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4], fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
@ -1025,7 +1079,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// Rest of the record is the bitmap // Rest of the record is the bitmap
serv := make([]uint16, 0) serv := make([]uint16, 0)
j := 0 j := 0
for off < rdend { for off < lenrd {
if off+1 > lenmsg { if off+1 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking wks"} return lenmsg, &Error{err: "overflow unpacking wks"}
} }
@ -1060,17 +1114,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
fv.Set(reflect.ValueOf(serv)) fv.Set(reflect.ValueOf(serv))
case `dns:"nsec"`: // NSEC/NSEC3 case `dns:"nsec"`: // NSEC/NSEC3
if off == rdend { if off == lenrd {
break break
} }
// Rest of the record is the type bitmap // Rest of the record is the type bitmap
if off+2 > rdend || off+2 > lenmsg { if off+2 > lenrd || off+2 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking nsecx"} return lenmsg, &Error{err: "overflow unpacking nsecx"}
} }
nsec := make([]uint16, 0) nsec := make([]uint16, 0)
length := 0 length := 0
window := 0 window := 0
for off+2 < rdend { for off+2 < lenrd {
window = int(msg[off]) window = int(msg[off])
length = int(msg[off+1]) length = int(msg[off+1])
//println("off, windows, length, end", off, window, length, endrr) //println("off, windows, length, end", off, window, length, endrr)
@ -1127,7 +1181,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
return lenmsg, err return lenmsg, err
} }
if val.Type().Field(i).Name == "Hdr" { if val.Type().Field(i).Name == "Hdr" {
rdend = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) lenrd = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
} }
case reflect.Uint8: case reflect.Uint8:
if off == lenmsg { if off == lenmsg {
@ -1184,29 +1238,36 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
default: default:
return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")} return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")}
case `dns:"hex"`: case `dns:"hex"`:
hexend := rdend hexend := lenrd
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
hexend = off + int(val.FieldByName("HitLength").Uint()) hexend = off + int(val.FieldByName("HitLength").Uint())
} }
if hexend > rdend || hexend > lenmsg { if hexend > lenrd || hexend > lenmsg {
return lenmsg, &Error{err: "overflow unpacking hex"} return lenmsg, &Error{err: "overflow unpacking hex"}
} }
s = hex.EncodeToString(msg[off:hexend]) s = hex.EncodeToString(msg[off:hexend])
off = hexend off = hexend
case `dns:"base64"`: case `dns:"base64"`:
// Rest of the RR is base64 encoded value // Rest of the RR is base64 encoded value
b64end := rdend b64end := lenrd
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
b64end = off + int(val.FieldByName("PublicKeyLength").Uint()) b64end = off + int(val.FieldByName("PublicKeyLength").Uint())
} }
if b64end > rdend || b64end > lenmsg { if b64end > lenrd || b64end > lenmsg {
return lenmsg, &Error{err: "overflow unpacking base64"} return lenmsg, &Error{err: "overflow unpacking base64"}
} }
s = unpackBase64(msg[off:b64end]) s = toBase64(msg[off:b64end])
off = b64end off = b64end
case `dns:"cdomain-name"`: case `dns:"cdomain-name"`:
fallthrough fallthrough
case `dns:"domain-name"`: case `dns:"domain-name"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, 1 and 2 or used for addresses
x := val.Field(2).Uint()
if x == 1 || x == 2 {
continue
}
}
if off == lenmsg { if off == lenmsg {
// zero rdata foo, OK for dyn. updates // zero rdata foo, OK for dyn. updates
break break
@ -1228,7 +1289,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
if off+size > lenmsg { if off+size > lenmsg {
return lenmsg, &Error{err: "overflow unpacking base32"} return lenmsg, &Error{err: "overflow unpacking base32"}
} }
s = unpackBase32(msg[off : off+size]) s = toBase32(msg[off : off+size])
off += size off += size
case `dns:"size-hex"`: case `dns:"size-hex"`:
// a "size" string, but it must be encoded in hex in the string // a "size" string, but it must be encoded in hex in the string
@ -1276,58 +1337,53 @@ func dddToByte(s []byte) byte {
return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0')) return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
} }
// Helper function for unpacking
func unpackUint16(msg []byte, off int) (v uint16, off1 int) {
v = uint16(msg[off])<<8 | uint16(msg[off+1])
off1 = off + 2
return
}
// UnpackStruct unpacks a binary message from offset off to the interface // UnpackStruct unpacks a binary message from offset off to the interface
// value given. // value given.
func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, err error) { func UnpackStruct(any interface{}, msg []byte, off int) (int, error) {
off, err = unpackStructValue(structValue(any), msg, off) return unpackStructValue(structValue(any), msg, off)
return off, err
} }
func unpackBase32(b []byte) string { // Helper function for packing and unpacking
b32 := make([]byte, base32.HexEncoding.EncodedLen(len(b))) func intToBytes(i *big.Int, length int) []byte {
base32.HexEncoding.Encode(b32, b) buf := i.Bytes()
return string(b32) if len(buf) < length {
b := make([]byte, length)
copy(b[length-len(buf):], buf)
return b
}
return buf
} }
func unpackBase64(b []byte) string { func unpackUint16(msg []byte, off int) (uint16, int) {
b64 := make([]byte, base64.StdEncoding.EncodedLen(len(b))) return uint16(msg[off])<<8 | uint16(msg[off+1]), off + 2
base64.StdEncoding.Encode(b64, b)
return string(b64)
} }
// Helper function for packing
func packUint16(i uint16) (byte, byte) { func packUint16(i uint16) (byte, byte) {
return byte(i >> 8), byte(i) return byte(i >> 8), byte(i)
} }
func packBase64(s []byte) ([]byte, error) { func toBase32(b []byte) string {
b64len := base64.StdEncoding.DecodedLen(len(s)) return base32.HexEncoding.EncodeToString(b)
buf := make([]byte, b64len)
n, err := base64.StdEncoding.Decode(buf, []byte(s))
if err != nil {
return nil, err
}
buf = buf[:n]
return buf, nil
} }
// Helper function for packing, mostly used in dnssec.go func fromBase32(s []byte) (buf []byte, err error) {
func packBase32(s []byte) ([]byte, error) { buflen := base32.HexEncoding.DecodedLen(len(s))
b32len := base32.HexEncoding.DecodedLen(len(s)) buf = make([]byte, buflen)
buf := make([]byte, b32len) n, err := base32.HexEncoding.Decode(buf, s)
n, err := base32.HexEncoding.Decode(buf, []byte(s))
if err != nil {
return nil, err
}
buf = buf[:n] buf = buf[:n]
return buf, nil return
}
func toBase64(b []byte) string {
return base64.StdEncoding.EncodeToString(b)
}
func fromBase64(s []byte) (buf []byte, err error) {
buflen := base64.StdEncoding.DecodedLen(len(s))
buf = make([]byte, buflen)
n, err := base64.StdEncoding.Decode(buf, s)
buf = buf[:n]
return
} }
// PackRR packs a resource record rr into msg[off:]. // PackRR packs a resource record rr into msg[off:].
@ -1856,25 +1912,34 @@ func (dns *Msg) Copy() *Msg {
copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy
} }
rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
var rri int
if len(dns.Answer) > 0 { if len(dns.Answer) > 0 {
r1.Answer = make([]RR, len(dns.Answer)) rrbegin := rri
for i := 0; i < len(dns.Answer); i++ { for i := 0; i < len(dns.Answer); i++ {
r1.Answer[i] = dns.Answer[i].copy() rrArr[rri] = dns.Answer[i].copy()
rri++
} }
r1.Answer = rrArr[rrbegin:rri:rri]
} }
if len(dns.Ns) > 0 { if len(dns.Ns) > 0 {
r1.Ns = make([]RR, len(dns.Ns)) rrbegin := rri
for i := 0; i < len(dns.Ns); i++ { for i := 0; i < len(dns.Ns); i++ {
r1.Ns[i] = dns.Ns[i].copy() rrArr[rri] = dns.Ns[i].copy()
rri++
} }
r1.Ns = rrArr[rrbegin:rri:rri]
} }
if len(dns.Extra) > 0 { if len(dns.Extra) > 0 {
r1.Extra = make([]RR, len(dns.Extra)) rrbegin := rri
for i := 0; i < len(dns.Extra); i++ { for i := 0; i < len(dns.Extra); i++ {
r1.Extra[i] = dns.Extra[i].copy() rrArr[rri] = dns.Extra[i].copy()
rri++
} }
r1.Extra = rrArr[rrbegin:rri:rri]
} }
return r1 return r1

View File

@ -47,7 +47,7 @@ func HashName(label string, ha uint8, iter uint16, salt string) string {
io.WriteString(s, string(nsec3)) io.WriteString(s, string(nsec3))
nsec3 = s.Sum(nil) nsec3 = s.Sum(nil)
} }
return unpackBase32(nsec3) return toBase32(nsec3)
} }
type Denialer interface { type Denialer interface {

View File

@ -386,8 +386,8 @@ func TestNSEC(t *testing.T) {
func TestParseLOC(t *testing.T) { func TestParseLOC(t *testing.T) {
lt := map[string]string{ lt := map[string]string{
"SW1A2AA.find.me.uk. LOC 51 30 12.748 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 30 12.748 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m", "SW1A2AA.find.me.uk. LOC 51 30 12.748 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 30 12.748 N 00 07 39.611 W 0m 0.00m 0.00m 0.00m",
"SW1A2AA.find.me.uk. LOC 51 0 0.0 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 00 0.000 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m", "SW1A2AA.find.me.uk. LOC 51 0 0.0 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 00 0.000 N 00 07 39.611 W 0m 0.00m 0.00m 0.00m",
} }
for i, o := range lt { for i, o := range lt {
rr, e := NewRR(i) rr, e := NewRR(i)
@ -652,7 +652,7 @@ moutamassey NS ns01.yahoodomains.jp.
} }
func ExampleHIP() { func ExampleHIP() {
h := `www.example.com IN HIP ( 2 200100107B1A74DF365639CC39F1D578 h := `www.example.com IN HIP ( 2 200100107B1A74DF365639CC39F1D578
AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p
9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQ 9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQ
b1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D b1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D
@ -661,7 +661,7 @@ b1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D
fmt.Printf("%s\n", hip.String()) fmt.Printf("%s\n", hip.String())
} }
// Output: // Output:
// www.example.com. 3600 IN HIP 2 200100107B1A74DF365639CC39F1D578 AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQb1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D rvs.example.com. // www.example.com. 3600 IN HIP 2 200100107B1A74DF365639CC39F1D578 AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQb1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D rvs.example.com.
} }
func TestHIP(t *testing.T) { func TestHIP(t *testing.T) {
@ -1225,35 +1225,21 @@ func TestMalformedPackets(t *testing.T) {
} }
} }
func TestDynamicUpdateParsing(t *testing.T) {
prefix := "example.com. IN "
for _, typ := range TypeToString {
if typ == "CAA" || typ == "OPT" || typ == "AXFR" || typ == "IXFR" || typ == "ANY" || typ == "TKEY" ||
typ == "TSIG" || typ == "ISDN" || typ == "UNSPEC" || typ == "NULL" || typ == "ATMA" {
continue
}
r, e := NewRR(prefix + typ)
if e != nil {
t.Log("failure to parse: " + prefix + typ)
t.Fail()
} else {
t.Logf("parsed: %s", r.String())
}
}
}
type algorithm struct { type algorithm struct {
name uint8 name uint8
bits int bits int
} }
func TestNewPrivateKeyECDSA(t *testing.T) { func TestNewPrivateKey(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
algorithms := []algorithm{ algorithms := []algorithm{
algorithm{ECDSAP256SHA256, 256}, algorithm{ECDSAP256SHA256, 256},
algorithm{ECDSAP384SHA384, 384}, algorithm{ECDSAP384SHA384, 384},
algorithm{RSASHA1, 1024}, algorithm{RSASHA1, 1024},
algorithm{RSASHA256, 2048}, algorithm{RSASHA256, 2048},
// algorithm{DSA, 1024}, // TODO: STILL BROKEN! algorithm{DSA, 1024},
} }
for _, algo := range algorithms { for _, algo := range algorithms {
@ -1272,12 +1258,15 @@ func TestNewPrivateKeyECDSA(t *testing.T) {
newPrivKey, err := key.NewPrivateKey(key.PrivateKeyString(privkey)) newPrivKey, err := key.NewPrivateKey(key.PrivateKeyString(privkey))
if err != nil { if err != nil {
t.Log(key.String())
t.Log(key.PrivateKeyString(privkey))
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
switch newPrivKey := newPrivKey.(type) { switch newPrivKey := newPrivKey.(type) {
case *rsa.PrivateKey: case *RSAPrivateKey:
newPrivKey.Precompute() (*rsa.PrivateKey)(newPrivKey).Precompute()
} }
if !reflect.DeepEqual(privkey, newPrivKey) { if !reflect.DeepEqual(privkey, newPrivKey) {
@ -1285,3 +1274,136 @@ func TestNewPrivateKeyECDSA(t *testing.T) {
} }
} }
} }
// special input test
func TestNewRRSpecial(t *testing.T) {
var (
rr RR
err error
expect string
)
rr, err = NewRR("; comment")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR("")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR("$ORIGIN foo.")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR(" ")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR("\n")
expect = ""
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr != nil {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
rr, err = NewRR("foo. A 1.1.1.1\nbar. A 2.2.2.2")
expect = "foo.\t3600\tIN\tA\t1.1.1.1"
if err != nil {
t.Errorf("unexpected err: %s", err)
}
if rr == nil || rr.String() != expect {
t.Errorf("unexpected result: [%s] != [%s]", rr, expect)
}
}
func TestPrintfVerbsRdata(t *testing.T) {
x, _ := NewRR("www.miek.nl. IN MX 20 mx.miek.nl.")
if Field(x, 1) != "20" {
t.Errorf("should be 20")
}
if Field(x, 2) != "mx.miek.nl." {
t.Errorf("should be mx.miek.nl.")
}
x, _ = NewRR("www.miek.nl. IN A 127.0.0.1")
if Field(x, 1) != "127.0.0.1" {
t.Errorf("should be 127.0.0.1")
}
x, _ = NewRR("www.miek.nl. IN AAAA ::1")
if Field(x, 1) != "::1" {
t.Errorf("should be ::1")
}
x, _ = NewRR("www.miek.nl. IN NSEC a.miek.nl. A NS SOA MX AAAA")
if Field(x, 1) != "a.miek.nl." {
t.Errorf("should be a.miek.nl.")
}
if Field(x, 2) != "A NS SOA MX AAAA" {
t.Errorf("should be A NS SOA MX AAAA")
}
x, _ = NewRR("www.miek.nl. IN TXT \"first\" \"second\"")
if Field(x, 1) != "first second" {
t.Errorf("should be first second")
}
if Field(x, 0) != "" {
t.Errorf("should be empty")
}
}
func TestParseIPSECKEY(t *testing.T) {
tests := []string{
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa.\t7200\tIN\tIPSECKEY\t10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 0 2 . AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa.\t7200\tIN\tIPSECKEY\t10 0 2 . AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.3 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.2.0.192.in-addr.arpa.\t7200\tIN\tIPSECKEY\t10 1 2 192.0.2.3 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
"38.1.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 3 2 mygateway.example.com. AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"38.1.0.192.in-addr.arpa.\t7200\tIN\tIPSECKEY\t10 3 2 mygateway.example.com. AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
"0.d.4.0.3.0.e.f.f.f.3.f.0.1.2.0 7200 IN IPSECKEY ( 10 2 2 2001:0DB8:0:8002::2000:1 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
"0.d.4.0.3.0.e.f.f.f.3.f.0.1.2.0.\t7200\tIN\tIPSECKEY\t10 2 2 2001:db8:0:8002::2000:1 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
}
for i := 0; i < len(tests)-1; i++ {
t1 := tests[i]
e1 := tests[i+1]
r, e := NewRR(t1)
if e != nil {
t.Logf("failed to parse IPSECKEY %s", e)
continue
}
if r.String() != e1 {
t.Logf("these two IPSECKEY records should match")
t.Logf("\n%s\n%s\n", r.String(), e1)
t.Fail()
}
i++
}
}

View File

@ -77,8 +77,7 @@ func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
f(w, r) f(w, r)
} }
// FailedHandler returns a HandlerFunc // FailedHandler returns a HandlerFunc that returns SERVFAIL for every request it gets.
// returns SERVFAIL for every request it gets.
func HandleFailed(w ResponseWriter, r *Msg) { func HandleFailed(w ResponseWriter, r *Msg) {
m := new(Msg) m := new(Msg)
m.SetRcode(r, RcodeServerFailure) m.SetRcode(r, RcodeServerFailure)
@ -170,10 +169,10 @@ func (mux *ServeMux) HandleRemove(pattern string) {
// is sought. // is sought.
// If no handler is found a standard SERVFAIL message is returned // If no handler is found a standard SERVFAIL message is returned
// If the request message does not have exactly one question in the // If the request message does not have exactly one question in the
// question section a SERVFAIL is returned. // question section a SERVFAIL is returned, unlesss Unsafe is true.
func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) { func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
var h Handler var h Handler
if len(request.Question) != 1 { if len(request.Question) < 1 { // allow more than one question
h = failedHandler() h = failedHandler()
} else { } else {
if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil { if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
@ -221,6 +220,11 @@ type Server struct {
IdleTimeout func() time.Duration IdleTimeout func() time.Duration
// Secret(s) for Tsig map[<zonename>]<base64 secret>. // Secret(s) for Tsig map[<zonename>]<base64 secret>.
TsigSecret map[string]string TsigSecret map[string]string
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
// the handler. It will specfically not check if the query has the QR bit not set.
Unsafe bool
// If NotifyStartedFunc is set is is called, once the server has started listening.
NotifyStartedFunc func()
// For graceful shutdown. // For graceful shutdown.
stopUDP chan bool stopUDP chan bool
@ -237,6 +241,7 @@ type Server struct {
func (srv *Server) ListenAndServe() error { func (srv *Server) ListenAndServe() error {
srv.lock.Lock() srv.lock.Lock()
if srv.started { if srv.started {
srv.lock.Unlock()
return &Error{err: "server already started"} return &Error{err: "server already started"}
} }
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool) srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
@ -282,14 +287,12 @@ func (srv *Server) ListenAndServe() error {
func (srv *Server) ActivateAndServe() error { func (srv *Server) ActivateAndServe() error {
srv.lock.Lock() srv.lock.Lock()
if srv.started { if srv.started {
srv.lock.Unlock()
return &Error{err: "server already started"} return &Error{err: "server already started"}
} }
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool) srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
srv.started = true srv.started = true
srv.lock.Unlock() srv.lock.Unlock()
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
if srv.PacketConn != nil { if srv.PacketConn != nil {
if srv.UDPSize == 0 { if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize srv.UDPSize = MinMsgSize
@ -316,6 +319,7 @@ func (srv *Server) ActivateAndServe() error {
func (srv *Server) Shutdown() error { func (srv *Server) Shutdown() error {
srv.lock.Lock() srv.lock.Lock()
if !srv.started { if !srv.started {
srv.lock.Unlock()
return &Error{err: "server not started"} return &Error{err: "server not started"}
} }
srv.started = false srv.started = false
@ -371,6 +375,11 @@ func (srv *Server) getReadTimeout() time.Duration {
// Each request is handled in a seperate goroutine. // Each request is handled in a seperate goroutine.
func (srv *Server) serveTCP(l *net.TCPListener) error { func (srv *Server) serveTCP(l *net.TCPListener) error {
defer l.Close() defer l.Close()
if srv.NotifyStartedFunc != nil {
srv.NotifyStartedFunc()
}
handler := srv.Handler handler := srv.Handler
if handler == nil { if handler == nil {
handler = DefaultServeMux handler = DefaultServeMux
@ -402,6 +411,10 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
func (srv *Server) serveUDP(l *net.UDPConn) error { func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close() defer l.Close()
if srv.NotifyStartedFunc != nil {
srv.NotifyStartedFunc()
}
handler := srv.Handler handler := srv.Handler
if handler == nil { if handler == nil {
handler = DefaultServeMux handler = DefaultServeMux
@ -445,6 +458,9 @@ Redo:
w.WriteMsg(x) w.WriteMsg(x)
goto Exit goto Exit
} }
if !srv.Unsafe && req.Response {
goto Exit
}
w.tsigStatus = nil w.tsigStatus = nil
if w.tsigSecret != nil { if w.tsigSecret != nil {

View File

@ -32,10 +32,37 @@ func RunLocalUDPServer(laddr string) (*Server, string, error) {
return nil, "", err return nil, "", err
} }
server := &Server{PacketConn: pc} server := &Server{PacketConn: pc}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
go func() { go func() {
server.ActivateAndServe() server.ActivateAndServe()
pc.Close() pc.Close()
}() }()
waitLock.Lock()
return server, pc.LocalAddr().String(), nil
}
func RunLocalUDPServerUnsafe(laddr string) (*Server, string, error) {
pc, err := net.ListenPacket("udp", laddr)
if err != nil {
return nil, "", err
}
server := &Server{PacketConn: pc, Unsafe: true}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
go func() {
server.ActivateAndServe()
pc.Close()
}()
waitLock.Lock()
return server, pc.LocalAddr().String(), nil return server, pc.LocalAddr().String(), nil
} }
@ -44,11 +71,19 @@ func RunLocalTCPServer(laddr string) (*Server, string, error) {
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
server := &Server{Listener: l} server := &Server{Listener: l}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
go func() { go func() {
server.ActivateAndServe() server.ActivateAndServe()
l.Close() l.Close()
}() }()
waitLock.Lock()
return server, l.Addr().String(), nil return server, l.Addr().String(), nil
} }
@ -68,7 +103,7 @@ func TestServing(t *testing.T) {
m := new(Msg) m := new(Msg)
m.SetQuestion("miek.nl.", TypeTXT) m.SetQuestion("miek.nl.", TypeTXT)
r, _, err := c.Exchange(m, addrstr) r, _, err := c.Exchange(m, addrstr)
if err != nil { if err != nil || len(r.Extra) == 0 {
t.Log("failed to exchange miek.nl", err) t.Log("failed to exchange miek.nl", err)
t.Fatal() t.Fatal()
} }
@ -302,14 +337,52 @@ func TestServingLargeResponses(t *testing.T) {
} }
} }
func TestServingResponse(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
HandleFunc("miek.nl.", HelloServer)
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
c := new(Client)
m := new(Msg)
m.SetQuestion("miek.nl.", TypeTXT)
m.Response = false
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Log("failed to exchange", err)
t.Fatal()
}
m.Response = true
_, _, err = c.Exchange(m, addrstr)
if err == nil {
t.Log("exchanged response message")
t.Fatal()
}
s.Shutdown()
s, addrstr, err = RunLocalUDPServerUnsafe("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %s", err)
}
defer s.Shutdown()
m.Response = true
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Log("could exchanged response message in Unsafe mode")
t.Fatal()
}
}
func TestShutdownTCP(t *testing.T) { func TestShutdownTCP(t *testing.T) {
s, _, err := RunLocalTCPServer("127.0.0.1:0") s, _, err := RunLocalTCPServer("127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("Unable to run test server: %s", err) t.Fatalf("Unable to run test server: %s", err)
} }
// it normally is too early to shutting down because server
// activates in goroutine.
runtime.Gosched()
err = s.Shutdown() err = s.Shutdown()
if err != nil { if err != nil {
t.Errorf("Could not shutdown test TCP server, %s", err) t.Errorf("Could not shutdown test TCP server, %s", err)
@ -321,9 +394,6 @@ func TestShutdownUDP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Unable to run test server: %s", err) t.Fatalf("Unable to run test server: %s", err)
} }
// it normally is too early to shutting down because server
// activates in goroutine.
runtime.Gosched()
err = s.Shutdown() err = s.Shutdown()
if err != nil { if err != nil {
t.Errorf("Could not shutdown test UDP server, %s", err) t.Errorf("Could not shutdown test UDP server, %s", err)

236
Godeps/_workspace/src/github.com/miekg/dns/sig0.go generated vendored Normal file
View File

@ -0,0 +1,236 @@
// SIG(0)
//
// From RFC 2931:
//
// SIG(0) provides protection for DNS transactions and requests ....
// ... protection for glue records, DNS requests, protection for message headers
// on requests and responses, and protection of the overall integrity of a response.
//
// It works like TSIG, except that SIG(0) uses public key cryptography, instead of the shared
// secret approach in TSIG.
// Supported algorithms: DSA, ECDSAP256SHA256, ECDSAP384SHA384, RSASHA1, RSASHA256 and
// RSASHA512.
//
// Signing subsequent messages in multi-message sessions is not implemented.
//
package dns
import (
"crypto"
"crypto/dsa"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"time"
)
// Sign signs a dns.Msg. It fills the signature with the appropriate data.
// The SIG record should have the SignerName, KeyTag, Algorithm, Inception
// and Expiration set.
func (rr *SIG) Sign(k PrivateKey, m *Msg) ([]byte, error) {
if k == nil {
return nil, ErrPrivKey
}
if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
return nil, ErrKey
}
rr.Header().Rrtype = TypeSIG
rr.Header().Class = ClassANY
rr.Header().Ttl = 0
rr.Header().Name = "."
rr.OrigTtl = 0
rr.TypeCovered = 0
rr.Labels = 0
buf := make([]byte, m.Len()+rr.len())
mbuf, err := m.PackBuffer(buf)
if err != nil {
return nil, err
}
if &buf[0] != &mbuf[0] {
return nil, ErrBuf
}
off, err := PackRR(rr, buf, len(mbuf), nil, false)
if err != nil {
return nil, err
}
buf = buf[:off:cap(buf)]
var hash crypto.Hash
switch rr.Algorithm {
case DSA, RSASHA1:
hash = crypto.SHA1
case RSASHA256, ECDSAP256SHA256:
hash = crypto.SHA256
case ECDSAP384SHA384:
hash = crypto.SHA384
case RSASHA512:
hash = crypto.SHA512
default:
return nil, ErrAlg
}
hasher := hash.New()
// Write SIG rdata
hasher.Write(buf[len(mbuf)+1+2+2+4+2:])
// Write message
hasher.Write(buf[:len(mbuf)])
hashed := hasher.Sum(nil)
sig, err := k.Sign(hashed, rr.Algorithm)
if err != nil {
return nil, err
}
rr.Signature = toBase64(sig)
buf = append(buf, sig...)
if len(buf) > int(^uint16(0)) {
return nil, ErrBuf
}
// Adjust sig data length
rdoff := len(mbuf) + 1 + 2 + 2 + 4
rdlen, _ := unpackUint16(buf, rdoff)
rdlen += uint16(len(sig))
buf[rdoff], buf[rdoff+1] = packUint16(rdlen)
// Adjust additional count
adc, _ := unpackUint16(buf, 10)
adc += 1
buf[10], buf[11] = packUint16(adc)
return buf, nil
}
// Verify validates the message buf using the key k.
// It's assumed that buf is a valid message from which rr was unpacked.
func (rr *SIG) Verify(k *KEY, buf []byte) error {
if k == nil {
return ErrKey
}
if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
return ErrKey
}
var hash crypto.Hash
switch rr.Algorithm {
case DSA, RSASHA1:
hash = crypto.SHA1
case RSASHA256, ECDSAP256SHA256:
hash = crypto.SHA256
case ECDSAP384SHA384:
hash = crypto.SHA384
case RSASHA512:
hash = crypto.SHA512
default:
return ErrAlg
}
hasher := hash.New()
buflen := len(buf)
qdc, _ := unpackUint16(buf, 4)
anc, _ := unpackUint16(buf, 6)
auc, _ := unpackUint16(buf, 8)
adc, offset := unpackUint16(buf, 10)
var err error
for i := uint16(0); i < qdc && offset < buflen; i++ {
_, offset, err = UnpackDomainName(buf, offset)
if err != nil {
return err
}
// Skip past Type and Class
offset += 2 + 2
}
for i := uint16(1); i < anc+auc+adc && offset < buflen; i++ {
_, offset, err = UnpackDomainName(buf, offset)
if err != nil {
return err
}
// Skip past Type, Class and TTL
offset += 2 + 2 + 4
if offset+1 >= buflen {
continue
}
var rdlen uint16
rdlen, offset = unpackUint16(buf, offset)
offset += int(rdlen)
}
if offset >= buflen {
return &Error{err: "overflowing unpacking signed message"}
}
// offset should be just prior to SIG
bodyend := offset
// owner name SHOULD be root
_, offset, err = UnpackDomainName(buf, offset)
if err != nil {
return err
}
// Skip Type, Class, TTL, RDLen
offset += 2 + 2 + 4 + 2
sigstart := offset
// Skip Type Covered, Algorithm, Labels, Original TTL
offset += 2 + 1 + 1 + 4
if offset+4+4 >= buflen {
return &Error{err: "overflow unpacking signed message"}
}
expire := uint32(buf[offset])<<24 | uint32(buf[offset+1])<<16 | uint32(buf[offset+2])<<8 | uint32(buf[offset+3])
offset += 4
incept := uint32(buf[offset])<<24 | uint32(buf[offset+1])<<16 | uint32(buf[offset+2])<<8 | uint32(buf[offset+3])
offset += 4
now := uint32(time.Now().Unix())
if now < incept || now > expire {
return ErrTime
}
// Skip key tag
offset += 2
var signername string
signername, offset, err = UnpackDomainName(buf, offset)
if err != nil {
return err
}
// If key has come from the DNS name compression might
// have mangled the case of the name
if strings.ToLower(signername) != strings.ToLower(k.Header().Name) {
return &Error{err: "signer name doesn't match key name"}
}
sigend := offset
hasher.Write(buf[sigstart:sigend])
hasher.Write(buf[:10])
hasher.Write([]byte{
byte((adc - 1) << 8),
byte(adc - 1),
})
hasher.Write(buf[12:bodyend])
hashed := hasher.Sum(nil)
sig := buf[sigend:]
switch k.Algorithm {
case DSA:
pk := k.publicKeyDSA()
sig = sig[1:]
r := big.NewInt(0)
r.SetBytes(sig[:len(sig)/2])
s := big.NewInt(0)
s.SetBytes(sig[len(sig)/2:])
if pk != nil {
if dsa.Verify(pk, hashed, r, s) {
return nil
}
return ErrSig
}
case RSASHA1, RSASHA256, RSASHA512:
pk := k.publicKeyRSA()
if pk != nil {
return rsa.VerifyPKCS1v15(pk, hash, hashed, sig)
}
case ECDSAP256SHA256, ECDSAP384SHA384:
pk := k.publicKeyECDSA()
r := big.NewInt(0)
r.SetBytes(sig[:len(sig)/2])
s := big.NewInt(0)
s.SetBytes(sig[len(sig)/2:])
if pk != nil {
if ecdsa.Verify(pk, hashed, r, s) {
return nil
}
return ErrSig
}
}
return ErrKeyAlg
}

View File

@ -0,0 +1,96 @@
package dns
import (
"testing"
"time"
)
func TestSIG0(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
m := new(Msg)
m.SetQuestion("example.org.", TypeSOA)
for _, alg := range []uint8{DSA, ECDSAP256SHA256, ECDSAP384SHA384, RSASHA1, RSASHA256, RSASHA512} {
algstr := AlgorithmToString[alg]
keyrr := new(KEY)
keyrr.Hdr.Name = algstr + "."
keyrr.Hdr.Rrtype = TypeKEY
keyrr.Hdr.Class = ClassINET
keyrr.Algorithm = alg
keysize := 1024
switch alg {
case ECDSAP256SHA256:
keysize = 256
case ECDSAP384SHA384:
keysize = 384
}
pk, err := keyrr.Generate(keysize)
if err != nil {
t.Logf("Failed to generate key for “%s”: %v", algstr, err)
t.Fail()
continue
}
now := uint32(time.Now().Unix())
sigrr := new(SIG)
sigrr.Hdr.Name = "."
sigrr.Hdr.Rrtype = TypeSIG
sigrr.Hdr.Class = ClassANY
sigrr.Algorithm = alg
sigrr.Expiration = now + 300
sigrr.Inception = now - 300
sigrr.KeyTag = keyrr.KeyTag()
sigrr.SignerName = keyrr.Hdr.Name
mb, err := sigrr.Sign(pk, m)
if err != nil {
t.Logf("Failed to sign message using “%s”: %v", algstr, err)
t.Fail()
continue
}
m := new(Msg)
if err := m.Unpack(mb); err != nil {
t.Logf("Failed to unpack message signed using “%s”: %v", algstr, err)
t.Fail()
continue
}
if len(m.Extra) != 1 {
t.Logf("Missing SIG for message signed using “%s”", algstr)
t.Fail()
continue
}
var sigrrwire *SIG
switch rr := m.Extra[0].(type) {
case *SIG:
sigrrwire = rr
default:
t.Logf("Expected SIG RR, instead: %v", rr)
t.Fail()
continue
}
for _, rr := range []*SIG{sigrr, sigrrwire} {
id := "sigrr"
if rr == sigrrwire {
id = "sigrrwire"
}
if err := rr.Verify(keyrr, mb); err != nil {
t.Logf("Failed to verify “%s” signed SIG(%s): %v", algstr, id, err)
t.Fail()
continue
}
}
mb[13]++
if err := sigrr.Verify(keyrr, mb); err == nil {
t.Logf("Verify succeeded on an altered message using “%s”", algstr)
t.Fail()
continue
}
sigrr.Expiration = 2
sigrr.Inception = 1
mb, _ = sigrr.Sign(pk, m)
if err := sigrr.Verify(keyrr, mb); err == nil {
t.Logf("Verify succeeded on an expired message using “%s”", algstr)
t.Fail()
continue
}
}
}

View File

@ -1,7 +1,7 @@
// TRANSACTION SIGNATURE // TRANSACTION SIGNATURE
// //
// An TSIG or transaction signature adds a HMAC TSIG record to each message sent. // An TSIG or transaction signature adds a HMAC TSIG record to each message sent.
// The supported algorithms include: HmacMD5, HmacSHA1 and HmacSHA256. // The supported algorithms include: HmacMD5, HmacSHA1, HmacSHA256 and HmacSHA512.
// //
// Basic use pattern when querying with a TSIG name "axfr." (note that these key names // Basic use pattern when querying with a TSIG name "axfr." (note that these key names
// must be fully qualified - as they are domain names) and the base64 secret // must be fully qualified - as they are domain names) and the base64 secret
@ -58,6 +58,7 @@ import (
"crypto/md5" "crypto/md5"
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
"crypto/sha512"
"encoding/hex" "encoding/hex"
"hash" "hash"
"io" "io"
@ -71,6 +72,7 @@ const (
HmacMD5 = "hmac-md5.sig-alg.reg.int." HmacMD5 = "hmac-md5.sig-alg.reg.int."
HmacSHA1 = "hmac-sha1." HmacSHA1 = "hmac-sha1."
HmacSHA256 = "hmac-sha256." HmacSHA256 = "hmac-sha256."
HmacSHA512 = "hmac-sha512."
) )
type TSIG struct { type TSIG struct {
@ -159,7 +161,7 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
panic("dns: TSIG not last RR in additional") panic("dns: TSIG not last RR in additional")
} }
// If we barf here, the caller is to blame // If we barf here, the caller is to blame
rawsecret, err := packBase64([]byte(secret)) rawsecret, err := fromBase64([]byte(secret))
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@ -181,6 +183,8 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
h = hmac.New(sha1.New, []byte(rawsecret)) h = hmac.New(sha1.New, []byte(rawsecret))
case HmacSHA256: case HmacSHA256:
h = hmac.New(sha256.New, []byte(rawsecret)) h = hmac.New(sha256.New, []byte(rawsecret))
case HmacSHA512:
h = hmac.New(sha512.New, []byte(rawsecret))
default: default:
return nil, "", ErrKeyAlg return nil, "", ErrKeyAlg
} }
@ -209,7 +213,7 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
// If the signature does not validate err contains the // If the signature does not validate err contains the
// error, otherwise it is nil. // error, otherwise it is nil.
func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
rawsecret, err := packBase64([]byte(secret)) rawsecret, err := fromBase64([]byte(secret))
if err != nil { if err != nil {
return err return err
} }
@ -225,7 +229,14 @@ func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
} }
buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly) buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly)
ti := uint64(time.Now().Unix()) - tsig.TimeSigned
// Fudge factor works both ways. A message can arrive before it was signed because
// of clock skew.
now := uint64(time.Now().Unix())
ti := now - tsig.TimeSigned
if now < tsig.TimeSigned {
ti = tsig.TimeSigned - now
}
if uint64(tsig.Fudge) < ti { if uint64(tsig.Fudge) < ti {
return ErrTime return ErrTime
} }
@ -238,6 +249,8 @@ func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
h = hmac.New(sha1.New, rawsecret) h = hmac.New(sha1.New, rawsecret)
case HmacSHA256: case HmacSHA256:
h = hmac.New(sha256.New, rawsecret) h = hmac.New(sha256.New, rawsecret)
case HmacSHA512:
h = hmac.New(sha512.New, rawsecret)
default: default:
return ErrKeyAlg return ErrKeyAlg
} }

View File

@ -75,6 +75,7 @@ const (
TypeRKEY uint16 = 57 TypeRKEY uint16 = 57
TypeTALINK uint16 = 58 TypeTALINK uint16 = 58
TypeCDS uint16 = 59 TypeCDS uint16 = 59
TypeCDNSKEY uint16 = 60
TypeOPENPGPKEY uint16 = 61 TypeOPENPGPKEY uint16 = 61
TypeSPF uint16 = 99 TypeSPF uint16 = 99
TypeUINFO uint16 = 100 TypeUINFO uint16 = 100
@ -159,10 +160,14 @@ const (
_AD = 1 << 5 // authticated data _AD = 1 << 5 // authticated data
_CD = 1 << 4 // checking disabled _CD = 1 << 4 // checking disabled
) LOC_EQUATOR = 1 << 31 // RFC 1876, Section 2.
LOC_PRIMEMERIDIAN = 1 << 31 // RFC 1876, Section 2.
// RFC 1876, Section 2 LOC_HOURS = 60 * 1000
const _LOC_EQUATOR = 1 << 31 LOC_DEGREES = 60 * LOC_HOURS
LOC_ALTITUDEBASE = 100000
)
// RFC 4398, Section 2.1 // RFC 4398, Section 2.1
const ( const (
@ -307,7 +312,7 @@ func (rr *MF) copy() RR { return &MF{*rr.Hdr.copyHeader(), rr.Mf} }
func (rr *MF) len() int { return rr.Hdr.len() + len(rr.Mf) + 1 } func (rr *MF) len() int { return rr.Hdr.len() + len(rr.Mf) + 1 }
func (rr *MF) String() string { func (rr *MF) String() string {
return rr.Hdr.String() + " " + sprintName(rr.Mf) return rr.Hdr.String() + sprintName(rr.Mf)
} }
type MD struct { type MD struct {
@ -320,7 +325,7 @@ func (rr *MD) copy() RR { return &MD{*rr.Hdr.copyHeader(), rr.Md} }
func (rr *MD) len() int { return rr.Hdr.len() + len(rr.Md) + 1 } func (rr *MD) len() int { return rr.Hdr.len() + len(rr.Md) + 1 }
func (rr *MD) String() string { func (rr *MD) String() string {
return rr.Hdr.String() + " " + sprintName(rr.Md) return rr.Hdr.String() + sprintName(rr.Md)
} }
type MX struct { type MX struct {
@ -527,7 +532,17 @@ func appendTXTStringByte(s []byte, b byte) []byte {
return append(s, '\\', b) return append(s, '\\', b)
} }
if b < ' ' || b > '~' { if b < ' ' || b > '~' {
return append(s, fmt.Sprintf("\\%03d", b)...) var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
return s
} }
return append(s, b) return append(s, b)
} }
@ -772,51 +787,77 @@ func (rr *LOC) copy() RR {
return &LOC{*rr.Hdr.copyHeader(), rr.Version, rr.Size, rr.HorizPre, rr.VertPre, rr.Latitude, rr.Longitude, rr.Altitude} return &LOC{*rr.Hdr.copyHeader(), rr.Version, rr.Size, rr.HorizPre, rr.VertPre, rr.Latitude, rr.Longitude, rr.Altitude}
} }
// cmToM takes a cm value expressed in RFC1876 SIZE mantissa/exponent
// format and returns a string in m (two decimals for the cm)
func cmToM(m, e uint8) string {
if e < 2 {
if e == 1 {
m *= 10
}
return fmt.Sprintf("0.%02d", m)
}
s := fmt.Sprintf("%d", m)
for e > 2 {
s += "0"
e -= 1
}
return s
}
// String returns a string version of a LOC
func (rr *LOC) String() string { func (rr *LOC) String() string {
s := rr.Hdr.String() s := rr.Hdr.String()
// Copied from ldns
// Latitude
lat := rr.Latitude
north := "N"
if lat > _LOC_EQUATOR {
lat = lat - _LOC_EQUATOR
} else {
north = "S"
lat = _LOC_EQUATOR - lat
}
h := lat / (1000 * 60 * 60)
lat = lat % (1000 * 60 * 60)
m := lat / (1000 * 60)
lat = lat % (1000 * 60)
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float32(lat) / 1000), north)
// Longitude
lon := rr.Longitude
east := "E"
if lon > _LOC_EQUATOR {
lon = lon - _LOC_EQUATOR
} else {
east = "W"
lon = _LOC_EQUATOR - lon
}
h = lon / (1000 * 60 * 60)
lon = lon % (1000 * 60 * 60)
m = lon / (1000 * 60)
lon = lon % (1000 * 60)
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float32(lon) / 1000), east)
s1 := rr.Altitude / 100.00 lat := rr.Latitude
s1 -= 100000 ns := "N"
if rr.Altitude%100 == 0 { if lat > LOC_EQUATOR {
s += fmt.Sprintf("%.2fm ", float32(s1)) lat = lat - LOC_EQUATOR
} else { } else {
s += fmt.Sprintf("%.0fm ", float32(s1)) ns = "S"
lat = LOC_EQUATOR - lat
} }
s += cmToString((rr.Size&0xf0)>>4, rr.Size&0x0f) + "m " h := lat / LOC_DEGREES
s += cmToString((rr.HorizPre&0xf0)>>4, rr.HorizPre&0x0f) + "m " lat = lat % LOC_DEGREES
s += cmToString((rr.VertPre&0xf0)>>4, rr.VertPre&0x0f) + "m" m := lat / LOC_HOURS
lat = lat % LOC_HOURS
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float64(lat) / 1000), ns)
lon := rr.Longitude
ew := "E"
if lon > LOC_PRIMEMERIDIAN {
lon = lon - LOC_PRIMEMERIDIAN
} else {
ew = "W"
lon = LOC_PRIMEMERIDIAN - lon
}
h = lon / LOC_DEGREES
lon = lon % LOC_DEGREES
m = lon / LOC_HOURS
lon = lon % LOC_HOURS
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float64(lon) / 1000), ew)
var alt float64 = float64(rr.Altitude) / 100
alt -= LOC_ALTITUDEBASE
if rr.Altitude%100 != 0 {
s += fmt.Sprintf("%.2fm ", alt)
} else {
s += fmt.Sprintf("%.0fm ", alt)
}
s += cmToM((rr.Size&0xf0)>>4, rr.Size&0x0f) + "m "
s += cmToM((rr.HorizPre&0xf0)>>4, rr.HorizPre&0x0f) + "m "
s += cmToM((rr.VertPre&0xf0)>>4, rr.VertPre&0x0f) + "m"
return s return s
} }
// SIG is identical to RRSIG and nowadays only used for SIG(0), RFC2931.
type SIG struct {
RRSIG
}
type RRSIG struct { type RRSIG struct {
Hdr RR_Header Hdr RR_Header
TypeCovered uint16 TypeCovered uint16
@ -888,6 +929,14 @@ func (rr *NSEC) len() int {
return l return l
} }
type DLV struct {
DS
}
type CDS struct {
DS
}
type DS struct { type DS struct {
Hdr RR_Header Hdr RR_Header
KeyTag uint16 KeyTag uint16
@ -909,48 +958,6 @@ func (rr *DS) String() string {
" " + strings.ToUpper(rr.Digest) " " + strings.ToUpper(rr.Digest)
} }
type CDS struct {
Hdr RR_Header
KeyTag uint16
Algorithm uint8
DigestType uint8
Digest string `dns:"hex"`
}
func (rr *CDS) Header() *RR_Header { return &rr.Hdr }
func (rr *CDS) len() int { return rr.Hdr.len() + 4 + len(rr.Digest)/2 }
func (rr *CDS) copy() RR {
return &CDS{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
}
func (rr *CDS) String() string {
return rr.Hdr.String() + strconv.Itoa(int(rr.KeyTag)) +
" " + strconv.Itoa(int(rr.Algorithm)) +
" " + strconv.Itoa(int(rr.DigestType)) +
" " + strings.ToUpper(rr.Digest)
}
type DLV struct {
Hdr RR_Header
KeyTag uint16
Algorithm uint8
DigestType uint8
Digest string `dns:"hex"`
}
func (rr *DLV) Header() *RR_Header { return &rr.Hdr }
func (rr *DLV) len() int { return rr.Hdr.len() + 4 + len(rr.Digest)/2 }
func (rr *DLV) copy() RR {
return &DLV{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
}
func (rr *DLV) String() string {
return rr.Hdr.String() + strconv.Itoa(int(rr.KeyTag)) +
" " + strconv.Itoa(int(rr.Algorithm)) +
" " + strconv.Itoa(int(rr.DigestType)) +
" " + strings.ToUpper(rr.Digest)
}
type KX struct { type KX struct {
Hdr RR_Header Hdr RR_Header
Preference uint16 Preference uint16
@ -999,7 +1006,7 @@ func (rr *TALINK) len() int { return rr.Hdr.len() + len(rr.PreviousNam
func (rr *TALINK) String() string { func (rr *TALINK) String() string {
return rr.Hdr.String() + return rr.Hdr.String() +
" " + sprintName(rr.PreviousName) + " " + sprintName(rr.NextName) sprintName(rr.PreviousName) + " " + sprintName(rr.NextName)
} }
type SSHFP struct { type SSHFP struct {
@ -1022,30 +1029,67 @@ func (rr *SSHFP) String() string {
} }
type IPSECKEY struct { type IPSECKEY struct {
Hdr RR_Header Hdr RR_Header
Precedence uint8 Precedence uint8
// GatewayType: 1: A record, 2: AAAA record, 3: domainname.
// 0 is use for no type and GatewayName should be "." then.
GatewayType uint8 GatewayType uint8
Algorithm uint8 Algorithm uint8
Gateway string `dns:"ipseckey"` // Gateway can be an A record, AAAA record or a domain name.
GatewayA net.IP `dns:"a"`
GatewayAAAA net.IP `dns:"aaaa"`
GatewayName string `dns:"domain-name"`
PublicKey string `dns:"base64"` PublicKey string `dns:"base64"`
} }
func (rr *IPSECKEY) Header() *RR_Header { return &rr.Hdr } func (rr *IPSECKEY) Header() *RR_Header { return &rr.Hdr }
func (rr *IPSECKEY) copy() RR { func (rr *IPSECKEY) copy() RR {
return &IPSECKEY{*rr.Hdr.copyHeader(), rr.Precedence, rr.GatewayType, rr.Algorithm, rr.Gateway, rr.PublicKey} return &IPSECKEY{*rr.Hdr.copyHeader(), rr.Precedence, rr.GatewayType, rr.Algorithm, rr.GatewayA, rr.GatewayAAAA, rr.GatewayName, rr.PublicKey}
} }
func (rr *IPSECKEY) String() string { func (rr *IPSECKEY) String() string {
return rr.Hdr.String() + strconv.Itoa(int(rr.Precedence)) + s := rr.Hdr.String() + strconv.Itoa(int(rr.Precedence)) +
" " + strconv.Itoa(int(rr.GatewayType)) + " " + strconv.Itoa(int(rr.GatewayType)) +
" " + strconv.Itoa(int(rr.Algorithm)) + " " + strconv.Itoa(int(rr.Algorithm))
" " + rr.Gateway + switch rr.GatewayType {
" " + rr.PublicKey case 0:
fallthrough
case 3:
s += " " + rr.GatewayName
case 1:
s += " " + rr.GatewayA.String()
case 2:
s += " " + rr.GatewayAAAA.String()
default:
s += " ."
}
s += " " + rr.PublicKey
return s
} }
func (rr *IPSECKEY) len() int { func (rr *IPSECKEY) len() int {
return rr.Hdr.len() + 3 + len(rr.Gateway) + 1 + l := rr.Hdr.len() + 3 + 1
base64.StdEncoding.DecodedLen(len(rr.PublicKey)) switch rr.GatewayType {
default:
fallthrough
case 0:
fallthrough
case 3:
l += len(rr.GatewayName)
case 1:
l += 4
case 2:
l += 16
}
return l + base64.StdEncoding.DecodedLen(len(rr.PublicKey))
}
type KEY struct {
DNSKEY
}
type CDNSKEY struct {
DNSKEY
} }
type DNSKEY struct { type DNSKEY struct {
@ -1221,11 +1265,23 @@ func (rr *RFC3597) copy() RR { return &RFC3597{*rr.Hdr.copyHeader(), r
func (rr *RFC3597) len() int { return rr.Hdr.len() + len(rr.Rdata)/2 + 2 } func (rr *RFC3597) len() int { return rr.Hdr.len() + len(rr.Rdata)/2 + 2 }
func (rr *RFC3597) String() string { func (rr *RFC3597) String() string {
s := rr.Hdr.String() // Let's call it a hack
s := rfc3597Header(rr.Hdr)
s += "\\# " + strconv.Itoa(len(rr.Rdata)/2) + " " + rr.Rdata s += "\\# " + strconv.Itoa(len(rr.Rdata)/2) + " " + rr.Rdata
return s return s
} }
func rfc3597Header(h RR_Header) string {
var s string
s += sprintName(h.Name) + "\t"
s += strconv.FormatInt(int64(h.Ttl), 10) + "\t"
s += "CLASS" + strconv.Itoa(int(h.Class)) + "\t"
s += "TYPE" + strconv.Itoa(int(h.Rrtype)) + "\t"
return s
}
type URI struct { type URI struct {
Hdr RR_Header Hdr RR_Header
Priority uint16 Priority uint16
@ -1280,7 +1336,7 @@ func (rr *TLSA) copy() RR {
func (rr *TLSA) String() string { func (rr *TLSA) String() string {
return rr.Hdr.String() + return rr.Hdr.String() +
" " + strconv.Itoa(int(rr.Usage)) + strconv.Itoa(int(rr.Usage)) +
" " + strconv.Itoa(int(rr.Selector)) + " " + strconv.Itoa(int(rr.Selector)) +
" " + strconv.Itoa(int(rr.MatchingType)) + " " + strconv.Itoa(int(rr.MatchingType)) +
" " + rr.Certificate " " + rr.Certificate
@ -1305,7 +1361,7 @@ func (rr *HIP) copy() RR {
func (rr *HIP) String() string { func (rr *HIP) String() string {
s := rr.Hdr.String() + s := rr.Hdr.String() +
" " + strconv.Itoa(int(rr.PublicKeyAlgorithm)) + strconv.Itoa(int(rr.PublicKeyAlgorithm)) +
" " + rr.Hit + " " + rr.Hit +
" " + rr.PublicKey " " + rr.PublicKey
for _, d := range rr.RendezvousServers { for _, d := range rr.RendezvousServers {
@ -1367,6 +1423,7 @@ func (rr *WKS) String() (s string) {
if rr.Address != nil { if rr.Address != nil {
s += rr.Address.String() s += rr.Address.String()
} }
// TODO(miek): missing protocol here, see /etc/protocols
for i := 0; i < len(rr.BitMap); i++ { for i := 0; i < len(rr.BitMap); i++ {
// should lookup the port // should lookup the port
s += " " + strconv.Itoa(int(rr.BitMap[i])) s += " " + strconv.Itoa(int(rr.BitMap[i]))
@ -1578,23 +1635,6 @@ func saltToString(s string) string {
return strings.ToUpper(s) return strings.ToUpper(s)
} }
func cmToString(mantissa, exponent uint8) string {
switch exponent {
case 0, 1:
if exponent == 1 {
mantissa *= 10
}
return fmt.Sprintf("%.02f", float32(mantissa))
default:
s := fmt.Sprintf("%d", mantissa)
for i := uint8(0); i < exponent-2; i++ {
s += "0"
}
return s
}
panic("dns: not reached")
}
func euiToString(eui uint64, bits int) (hex string) { func euiToString(eui uint64, bits int) (hex string) {
switch bits { switch bits {
case 64: case 64:
@ -1628,6 +1668,7 @@ var typeToRR = map[uint16]func() RR{
TypeDHCID: func() RR { return new(DHCID) }, TypeDHCID: func() RR { return new(DHCID) },
TypeDLV: func() RR { return new(DLV) }, TypeDLV: func() RR { return new(DLV) },
TypeDNAME: func() RR { return new(DNAME) }, TypeDNAME: func() RR { return new(DNAME) },
TypeKEY: func() RR { return new(KEY) },
TypeDNSKEY: func() RR { return new(DNSKEY) }, TypeDNSKEY: func() RR { return new(DNSKEY) },
TypeDS: func() RR { return new(DS) }, TypeDS: func() RR { return new(DS) },
TypeEUI48: func() RR { return new(EUI48) }, TypeEUI48: func() RR { return new(EUI48) },
@ -1637,6 +1678,7 @@ var typeToRR = map[uint16]func() RR{
TypeEID: func() RR { return new(EID) }, TypeEID: func() RR { return new(EID) },
TypeHINFO: func() RR { return new(HINFO) }, TypeHINFO: func() RR { return new(HINFO) },
TypeHIP: func() RR { return new(HIP) }, TypeHIP: func() RR { return new(HIP) },
TypeIPSECKEY: func() RR { return new(IPSECKEY) },
TypeKX: func() RR { return new(KX) }, TypeKX: func() RR { return new(KX) },
TypeL32: func() RR { return new(L32) }, TypeL32: func() RR { return new(L32) },
TypeL64: func() RR { return new(L64) }, TypeL64: func() RR { return new(L64) },
@ -1665,6 +1707,7 @@ var typeToRR = map[uint16]func() RR{
TypeRKEY: func() RR { return new(RKEY) }, TypeRKEY: func() RR { return new(RKEY) },
TypeRP: func() RR { return new(RP) }, TypeRP: func() RR { return new(RP) },
TypePX: func() RR { return new(PX) }, TypePX: func() RR { return new(PX) },
TypeSIG: func() RR { return new(SIG) },
TypeRRSIG: func() RR { return new(RRSIG) }, TypeRRSIG: func() RR { return new(RRSIG) },
TypeRT: func() RR { return new(RT) }, TypeRT: func() RR { return new(RT) },
TypeSOA: func() RR { return new(SOA) }, TypeSOA: func() RR { return new(SOA) },

View File

@ -0,0 +1,42 @@
package dns
import (
"testing"
)
func TestCmToM(t *testing.T) {
s := cmToM(0, 0)
if s != "0.00" {
t.Error("0, 0")
}
s = cmToM(1, 0)
if s != "0.01" {
t.Error("1, 0")
}
s = cmToM(3, 1)
if s != "0.30" {
t.Error("3, 1")
}
s = cmToM(4, 2)
if s != "4" {
t.Error("4, 2")
}
s = cmToM(5, 3)
if s != "50" {
t.Error("5, 3")
}
s = cmToM(7, 5)
if s != "7000" {
t.Error("7, 5")
}
s = cmToM(9, 9)
if s != "90000000" {
t.Error("9, 9")
}
}

View File

@ -106,10 +106,7 @@ func (u *Msg) Insert(rr []RR) {
func (u *Msg) RemoveRRset(rr []RR) { func (u *Msg) RemoveRRset(rr []RR) {
u.Ns = make([]RR, len(rr)) u.Ns = make([]RR, len(rr))
for i, r := range rr { for i, r := range rr {
u.Ns[i] = r u.Ns[i] = &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: r.Header().Rrtype, Class: ClassANY}}
u.Ns[i].Header().Class = ClassANY
u.Ns[i].Header().Rdlength = 0
u.Ns[i].Header().Ttl = 0
} }
} }

View File

@ -0,0 +1,89 @@
package dns
import (
"bytes"
"testing"
)
func TestDynamicUpdateParsing(t *testing.T) {
prefix := "example.com. IN "
for _, typ := range TypeToString {
if typ == "CAA" || typ == "OPT" || typ == "AXFR" || typ == "IXFR" || typ == "ANY" || typ == "TKEY" ||
typ == "TSIG" || typ == "ISDN" || typ == "UNSPEC" || typ == "NULL" || typ == "ATMA" {
continue
}
r, e := NewRR(prefix + typ)
if e != nil {
t.Log("failure to parse: " + prefix + typ)
t.Fail()
} else {
t.Logf("parsed: %s", r.String())
}
}
}
func TestDynamicUpdateUnpack(t *testing.T) {
// From https://github.com/miekg/dns/issues/150#issuecomment-62296803
// It should be an update message for the zone "example.",
// deleting the A RRset "example." and then adding an A record at "example.".
// class ANY, TYPE A
buf := []byte{171, 68, 40, 0, 0, 1, 0, 0, 0, 2, 0, 0, 7, 101, 120, 97, 109, 112, 108, 101, 0, 0, 6, 0, 1, 192, 12, 0, 1, 0, 255, 0, 0, 0, 0, 0, 0, 192, 12, 0, 1, 0, 1, 0, 0, 0, 0, 0, 4, 127, 0, 0, 1}
msg := new(Msg)
err := msg.Unpack(buf)
if err != nil {
t.Log("failed to unpack: " + err.Error() + "\n" + msg.String())
t.Fail()
}
}
func TestDynamicUpdateZeroRdataUnpack(t *testing.T) {
m := new(Msg)
rr := &RR_Header{Name: ".", Rrtype: 0, Class: 1, Ttl: ^uint32(0), Rdlength: 0}
m.Answer = []RR{rr, rr, rr, rr, rr}
m.Ns = m.Answer
for n, s := range TypeToString {
rr.Rrtype = n
bytes, err := m.Pack()
if err != nil {
t.Logf("failed to pack %s: %v", s, err)
t.Fail()
continue
}
if err := new(Msg).Unpack(bytes); err != nil {
t.Logf("failed to unpack %s: %v", s, err)
t.Fail()
}
}
}
func TestRemoveRRset(t *testing.T) {
// Should add a zero data RR in Class ANY with a TTL of 0
// for each set mentioned in the RRs provided to it.
rr, err := NewRR(". 100 IN A 127.0.0.1")
if err != nil {
t.Fatalf("Error constructing RR: %v", err)
}
m := new(Msg)
m.Ns = []RR{&RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY, Ttl: 0, Rdlength: 0}}
expectstr := m.String()
expect, err := m.Pack()
if err != nil {
t.Fatalf("Error packing expected msg: %v", err)
}
m.Ns = nil
m.RemoveRRset([]RR{rr})
actual, err := m.Pack()
if err != nil {
t.Fatalf("Error packing actual msg: %v", err)
}
if !bytes.Equal(actual, expect) {
tmp := new(Msg)
if err := tmp.Unpack(actual); err != nil {
t.Fatalf("Error unpacking actual msg: %v", err)
}
t.Logf("Expected msg:\n%s", expectstr)
t.Logf("Actual msg:\n%v", tmp)
t.Fail()
}
}

View File

@ -13,9 +13,9 @@ type Envelope struct {
// A Transfer defines parameters that are used during a zone transfer. // A Transfer defines parameters that are used during a zone transfer.
type Transfer struct { type Transfer struct {
*Conn *Conn
DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9 DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9 ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9 WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
tsigTimersOnly bool tsigTimersOnly bool
} }
@ -160,22 +160,18 @@ func (t *Transfer) inIxfr(id uint16, c chan *Envelope) {
// The server is responsible for sending the correct sequence of RRs through the // The server is responsible for sending the correct sequence of RRs through the
// channel ch. // channel ch.
func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error { func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
r := new(Msg) for x := range ch {
// Compress? r := new(Msg)
r.SetReply(q) // Compress?
r.Authoritative = true r.SetReply(q)
r.Authoritative = true
go func() { // assume it fits TODO(miek): fix
for x := range ch { r.Answer = append(r.Answer, x.RR...)
// assume it fits TODO(miek): fix if err := w.WriteMsg(r); err != nil {
r.Answer = append(r.Answer, x.RR...) return err
if err := w.WriteMsg(r); err != nil {
return
}
} }
w.TsigTimersOnly(true) }
r.Answer = nil w.TsigTimersOnly(true)
}()
return nil return nil
} }

99
Godeps/_workspace/src/github.com/miekg/dns/xfr_test.go generated vendored Normal file
View File

@ -0,0 +1,99 @@
package dns
import (
"net"
"testing"
"time"
)
func getIP(s string) string {
a, e := net.LookupAddr(s)
if e != nil {
return ""
}
return a[0]
}
// flaky, need to setup local server and test from
// that.
func testClientAXFR(t *testing.T) {
if testing.Short() {
return
}
m := new(Msg)
m.SetAxfr("miek.nl.")
server := getIP("linode.atoom.net")
tr := new(Transfer)
if a, err := tr.In(m, net.JoinHostPort(server, "53")); err != nil {
t.Log("failed to setup axfr: " + err.Error())
t.Fatal()
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("error %s\n", ex.Error.Error())
t.Fail()
break
}
for _, rr := range ex.RR {
t.Logf("%s\n", rr.String())
}
}
}
}
// fails.
func testClientAXFRMultipleEnvelopes(t *testing.T) {
if testing.Short() {
return
}
m := new(Msg)
m.SetAxfr("nlnetlabs.nl.")
server := getIP("open.nlnetlabs.nl.")
tr := new(Transfer)
if a, err := tr.In(m, net.JoinHostPort(server, "53")); err != nil {
t.Log("Failed to setup axfr" + err.Error() + "for server: " + server)
t.Fail()
return
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("Error %s\n", ex.Error.Error())
t.Fail()
break
}
}
}
}
func testClientTsigAXFR(t *testing.T) {
if testing.Short() {
return
}
m := new(Msg)
m.SetAxfr("example.nl.")
m.SetTsig("axfr.", HmacMD5, 300, time.Now().Unix())
tr := new(Transfer)
tr.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
if a, err := tr.In(m, "176.58.119.54:53"); err != nil {
t.Log("failed to setup axfr: " + err.Error())
t.Fatal()
} else {
for ex := range a {
if ex.Error != nil {
t.Logf("error %s\n", ex.Error.Error())
t.Fail()
break
}
for _, rr := range ex.RR {
t.Logf("%s\n", rr.String())
}
}
}
}

View File

@ -102,12 +102,13 @@ type Token struct {
Comment string // a potential comment positioned after the RR and on the same line Comment string // a potential comment positioned after the RR and on the same line
} }
// NewRR reads the RR contained in the string s. Only the first RR is returned. // NewRR reads the RR contained in the string s. Only the first RR is
// The class defaults to IN and TTL defaults to 3600. The full zone file // returned. If s contains no RR, return nil with no error. The class
// syntax like $TTL, $ORIGIN, etc. is supported. // defaults to IN and TTL defaults to 3600. The full zone file syntax
// All fields of the returned RR are set, except RR.Header().Rdlength which is set to 0. // like $TTL, $ORIGIN, etc. is supported. All fields of the returned
// RR are set, except RR.Header().Rdlength which is set to 0.
func NewRR(s string) (RR, error) { func NewRR(s string) (RR, error) {
if s[len(s)-1] != '\n' { // We need a closing newline if len(s) > 0 && s[len(s)-1] != '\n' { // We need a closing newline
return ReadRR(strings.NewReader(s+"\n"), "") return ReadRR(strings.NewReader(s+"\n"), "")
} }
return ReadRR(strings.NewReader(s), "") return ReadRR(strings.NewReader(s), "")
@ -117,6 +118,10 @@ func NewRR(s string) (RR, error) {
// See NewRR for more documentation. // See NewRR for more documentation.
func ReadRR(q io.Reader, filename string) (RR, error) { func ReadRR(q io.Reader, filename string) (RR, error) {
r := <-parseZoneHelper(q, ".", filename, 1) r := <-parseZoneHelper(q, ".", filename, 1)
if r == nil {
return nil, nil
}
if r.Error != nil { if r.Error != nil {
return nil, r.Error return nil, r.Error
} }
@ -899,9 +904,9 @@ func appendOrigin(name, origin string) string {
func locCheckNorth(token string, latitude uint32) (uint32, bool) { func locCheckNorth(token string, latitude uint32) (uint32, bool) {
switch token { switch token {
case "n", "N": case "n", "N":
return _LOC_EQUATOR + latitude, true return LOC_EQUATOR + latitude, true
case "s", "S": case "s", "S":
return _LOC_EQUATOR - latitude, true return LOC_EQUATOR - latitude, true
} }
return latitude, false return latitude, false
} }
@ -910,9 +915,9 @@ func locCheckNorth(token string, latitude uint32) (uint32, bool) {
func locCheckEast(token string, longitude uint32) (uint32, bool) { func locCheckEast(token string, longitude uint32) (uint32, bool) {
switch token { switch token {
case "e", "E": case "e", "E":
return _LOC_EQUATOR + longitude, true return LOC_EQUATOR + longitude, true
case "w", "W": case "w", "W":
return _LOC_EQUATOR - longitude, true return LOC_EQUATOR - longitude, true
} }
return longitude, false return longitude, false
} }

View File

@ -1065,6 +1065,14 @@ func setOPENPGPKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, strin
return rr, nil, c1 return rr, nil, c1
} }
func setSIG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setRRSIG(h, c, o, f)
if r != nil {
return &SIG{*r.(*RRSIG)}, e, s
}
return nil, e, s
}
func setRRSIG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { func setRRSIG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(RRSIG) rr := new(RRSIG)
rr.Hdr = h rr.Hdr = h
@ -1452,7 +1460,7 @@ func setSSHFP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, "" return rr, nil, ""
} }
func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { func setDNSKEYs(h RR_Header, c chan lex, o, f, typ string) (RR, *ParseError, string) {
rr := new(DNSKEY) rr := new(DNSKEY)
rr.Hdr = h rr.Hdr = h
@ -1461,25 +1469,25 @@ func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, l.comment return rr, nil, l.comment
} }
if i, e := strconv.Atoi(l.token); e != nil { if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DNSKEY Flags", l}, "" return nil, &ParseError{f, "bad " + typ + " Flags", l}, ""
} else { } else {
rr.Flags = uint16(i) rr.Flags = uint16(i)
} }
<-c // _BLANK <-c // _BLANK
l = <-c // _STRING l = <-c // _STRING
if i, e := strconv.Atoi(l.token); e != nil { if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DNSKEY Protocol", l}, "" return nil, &ParseError{f, "bad " + typ + " Protocol", l}, ""
} else { } else {
rr.Protocol = uint8(i) rr.Protocol = uint8(i)
} }
<-c // _BLANK <-c // _BLANK
l = <-c // _STRING l = <-c // _STRING
if i, e := strconv.Atoi(l.token); e != nil { if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DNSKEY Algorithm", l}, "" return nil, &ParseError{f, "bad " + typ + " Algorithm", l}, ""
} else { } else {
rr.Algorithm = uint8(i) rr.Algorithm = uint8(i)
} }
s, e, c1 := endingToString(c, "bad DNSKEY PublicKey", f) s, e, c1 := endingToString(c, "bad "+typ+" PublicKey", f)
if e != nil { if e != nil {
return nil, e, c1 return nil, e, c1
} }
@ -1487,6 +1495,27 @@ func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, c1 return rr, nil, c1
} }
func setKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDNSKEYs(h, c, o, f, "KEY")
if r != nil {
return &KEY{*r.(*DNSKEY)}, e, s
}
return nil, e, s
}
func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDNSKEYs(h, c, o, f, "DNSKEY")
return r, e, s
}
func setCDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDNSKEYs(h, c, o, f, "CDNSKEY")
if r != nil {
return &CDNSKEY{*r.(*DNSKEY)}, e, s
}
return nil, e, s
}
func setRKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { func setRKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(RKEY) rr := new(RKEY)
rr.Hdr = h rr.Hdr = h
@ -1522,44 +1551,6 @@ func setRKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, c1 return rr, nil, c1
} }
func setDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(DS)
rr.Hdr = h
l := <-c
if l.length == 0 {
return rr, nil, l.comment
}
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DS KeyTag", l}, ""
} else {
rr.KeyTag = uint16(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
if i, ok := StringToAlgorithm[l.tokenUpper]; !ok {
return nil, &ParseError{f, "bad DS Algorithm", l}, ""
} else {
rr.Algorithm = i
}
} else {
rr.Algorithm = uint8(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DS DigestType", l}, ""
} else {
rr.DigestType = uint8(i)
}
s, e, c1 := endingToString(c, "bad DS Digest", f)
if e != nil {
return nil, e, c1
}
rr.Digest = s
return rr, nil, c1
}
func setEID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { func setEID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(EID) rr := new(EID)
rr.Hdr = h rr.Hdr = h
@ -1632,15 +1623,15 @@ func setGPOS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, "" return rr, nil, ""
} }
func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { func setDSs(h RR_Header, c chan lex, o, f, typ string) (RR, *ParseError, string) {
rr := new(CDS) rr := new(DS)
rr.Hdr = h rr.Hdr = h
l := <-c l := <-c
if l.length == 0 { if l.length == 0 {
return rr, nil, l.comment return rr, nil, l.comment
} }
if i, e := strconv.Atoi(l.token); e != nil { if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad CDS KeyTag", l}, "" return nil, &ParseError{f, "bad " + typ + " KeyTag", l}, ""
} else { } else {
rr.KeyTag = uint16(i) rr.KeyTag = uint16(i)
} }
@ -1648,7 +1639,7 @@ func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
l = <-c l = <-c
if i, e := strconv.Atoi(l.token); e != nil { if i, e := strconv.Atoi(l.token); e != nil {
if i, ok := StringToAlgorithm[l.tokenUpper]; !ok { if i, ok := StringToAlgorithm[l.tokenUpper]; !ok {
return nil, &ParseError{f, "bad CDS Algorithm", l}, "" return nil, &ParseError{f, "bad " + typ + " Algorithm", l}, ""
} else { } else {
rr.Algorithm = i rr.Algorithm = i
} }
@ -1658,11 +1649,11 @@ func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
<-c // _BLANK <-c // _BLANK
l = <-c l = <-c
if i, e := strconv.Atoi(l.token); e != nil { if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad CDS DigestType", l}, "" return nil, &ParseError{f, "bad " + typ + " DigestType", l}, ""
} else { } else {
rr.DigestType = uint8(i) rr.DigestType = uint8(i)
} }
s, e, c1 := endingToString(c, "bad CDS Digest", f) s, e, c1 := endingToString(c, "bad "+typ+" Digest", f)
if e != nil { if e != nil {
return nil, e, c1 return nil, e, c1
} }
@ -1670,42 +1661,25 @@ func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, c1 return rr, nil, c1
} }
func setDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDSs(h, c, o, f, "DS")
return r, e, s
}
func setDLV(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { func setDLV(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(DLV) r, e, s := setDSs(h, c, o, f, "DLV")
rr.Hdr = h if r != nil {
l := <-c return &DLV{*r.(*DS)}, e, s
if l.length == 0 {
return rr, nil, l.comment
} }
if i, e := strconv.Atoi(l.token); e != nil { return nil, e, s
return nil, &ParseError{f, "bad DLV KeyTag", l}, "" }
} else {
rr.KeyTag = uint16(i) func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDSs(h, c, o, f, "DLV")
if r != nil {
return &CDS{*r.(*DS)}, e, s
} }
<-c // _BLANK return nil, e, s
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
if i, ok := StringToAlgorithm[l.tokenUpper]; !ok {
return nil, &ParseError{f, "bad DLV Algorithm", l}, ""
} else {
rr.Algorithm = i
}
} else {
rr.Algorithm = uint8(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad DLV DigestType", l}, ""
} else {
rr.DigestType = uint8(i)
}
s, e, c1 := endingToString(c, "bad DLV Digest", f)
if e != nil {
return nil, e, c1
}
rr.Digest = s
return rr, nil, c1
} }
func setTA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { func setTA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
@ -1873,44 +1847,6 @@ func setURI(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, c1 return rr, nil, c1
} }
func setIPSECKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(IPSECKEY)
rr.Hdr = h
l := <-c
if l.length == 0 {
return rr, nil, l.comment
}
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad IPSECKEY Precedence", l}, ""
} else {
rr.Precedence = uint8(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad IPSECKEY GatewayType", l}, ""
} else {
rr.GatewayType = uint8(i)
}
<-c // _BLANK
l = <-c
if i, e := strconv.Atoi(l.token); e != nil {
return nil, &ParseError{f, "bad IPSECKEY Algorithm", l}, ""
} else {
rr.Algorithm = uint8(i)
}
<-c
l = <-c
rr.Gateway = l.token
s, e, c1 := endingToString(c, "bad IPSECKEY PublicKey", f)
if e != nil {
return nil, e, c1
}
rr.PublicKey = s
return rr, nil, c1
}
func setDHCID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { func setDHCID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
// awesome record to parse! // awesome record to parse!
rr := new(DHCID) rr := new(DHCID)
@ -2113,16 +2049,85 @@ func setPX(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
return rr, nil, "" return rr, nil, ""
} }
func setIPSECKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(IPSECKEY)
rr.Hdr = h
l := <-c
if l.length == 0 {
return rr, nil, l.comment
}
if i, err := strconv.Atoi(l.token); err != nil {
return nil, &ParseError{f, "bad IPSECKEY Precedence", l}, ""
} else {
rr.Precedence = uint8(i)
}
<-c // _BLANK
l = <-c
if i, err := strconv.Atoi(l.token); err != nil {
return nil, &ParseError{f, "bad IPSECKEY GatewayType", l}, ""
} else {
rr.GatewayType = uint8(i)
}
<-c // _BLANK
l = <-c
if i, err := strconv.Atoi(l.token); err != nil {
return nil, &ParseError{f, "bad IPSECKEY Algorithm", l}, ""
} else {
rr.Algorithm = uint8(i)
}
// Now according to GatewayType we can have different elements here
<-c // _BLANK
l = <-c
switch rr.GatewayType {
case 0:
fallthrough
case 3:
rr.GatewayName = l.token
if l.token == "@" {
rr.GatewayName = o
}
_, ok := IsDomainName(l.token)
if !ok {
return nil, &ParseError{f, "bad IPSECKEY GatewayName", l}, ""
}
if rr.GatewayName[l.length-1] != '.' {
rr.GatewayName = appendOrigin(rr.GatewayName, o)
}
case 1:
rr.GatewayA = net.ParseIP(l.token)
if rr.GatewayA == nil {
return nil, &ParseError{f, "bad IPSECKEY GatewayA", l}, ""
}
case 2:
rr.GatewayAAAA = net.ParseIP(l.token)
if rr.GatewayAAAA == nil {
return nil, &ParseError{f, "bad IPSECKEY GatewayAAAA", l}, ""
}
default:
return nil, &ParseError{f, "bad IPSECKEY GatewayType", l}, ""
}
s, e, c1 := endingToString(c, "bad IPSECKEY PublicKey", f)
if e != nil {
return nil, e, c1
}
rr.PublicKey = s
return rr, nil, c1
}
var typeToparserFunc = map[uint16]parserFunc{ var typeToparserFunc = map[uint16]parserFunc{
TypeAAAA: parserFunc{setAAAA, false}, TypeAAAA: parserFunc{setAAAA, false},
TypeAFSDB: parserFunc{setAFSDB, false}, TypeAFSDB: parserFunc{setAFSDB, false},
TypeA: parserFunc{setA, false}, TypeA: parserFunc{setA, false},
TypeCDS: parserFunc{setCDS, true}, TypeCDS: parserFunc{setCDS, true},
TypeCDNSKEY: parserFunc{setCDNSKEY, true},
TypeCERT: parserFunc{setCERT, true}, TypeCERT: parserFunc{setCERT, true},
TypeCNAME: parserFunc{setCNAME, false}, TypeCNAME: parserFunc{setCNAME, false},
TypeDHCID: parserFunc{setDHCID, true}, TypeDHCID: parserFunc{setDHCID, true},
TypeDLV: parserFunc{setDLV, true}, TypeDLV: parserFunc{setDLV, true},
TypeDNAME: parserFunc{setDNAME, false}, TypeDNAME: parserFunc{setDNAME, false},
TypeKEY: parserFunc{setKEY, true},
TypeDNSKEY: parserFunc{setDNSKEY, true}, TypeDNSKEY: parserFunc{setDNSKEY, true},
TypeDS: parserFunc{setDS, true}, TypeDS: parserFunc{setDS, true},
TypeEID: parserFunc{setEID, true}, TypeEID: parserFunc{setEID, true},
@ -2158,6 +2163,7 @@ var typeToparserFunc = map[uint16]parserFunc{
TypeOPENPGPKEY: parserFunc{setOPENPGPKEY, true}, TypeOPENPGPKEY: parserFunc{setOPENPGPKEY, true},
TypePTR: parserFunc{setPTR, false}, TypePTR: parserFunc{setPTR, false},
TypePX: parserFunc{setPX, false}, TypePX: parserFunc{setPX, false},
TypeSIG: parserFunc{setSIG, true},
TypeRKEY: parserFunc{setRKEY, true}, TypeRKEY: parserFunc{setRKEY, true},
TypeRP: parserFunc{setRP, false}, TypeRP: parserFunc{setRP, false},
TypeRRSIG: parserFunc{setRRSIG, true}, TypeRRSIG: parserFunc{setRRSIG, true},

View File

@ -0,0 +1 @@
Imported at 75cd24fc2f2c from https://bitbucket.org/ww/goautoneg.

View File

@ -0,0 +1,13 @@
include $(GOROOT)/src/Make.inc
TARG=bitbucket.org/ww/goautoneg
GOFILES=autoneg.go
include $(GOROOT)/src/Make.pkg
format:
gofmt -w *.go
docs:
gomake clean
godoc ${TARG} > README.txt

View File

@ -0,0 +1,67 @@
PACKAGE
package goautoneg
import "bitbucket.org/ww/goautoneg"
HTTP Content-Type Autonegotiation.
The functions in this package implement the behaviour specified in
http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
Copyright (c) 2011, Open Knowledge Foundation Ltd.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
Neither the name of the Open Knowledge Foundation Ltd. nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
FUNCTIONS
func Negotiate(header string, alternatives []string) (content_type string)
Negotiate the most appropriate content_type given the accept header
and a list of alternatives.
func ParseAccept(header string) (accept []Accept)
Parse an Accept Header string returning a sorted list
of clauses
TYPES
type Accept struct {
Type, SubType string
Q float32
Params map[string]string
}
Structure to represent a clause in an HTTP Accept Header
SUBDIRECTORIES
.hg

View File

@ -0,0 +1,162 @@
/*
HTTP Content-Type Autonegotiation.
The functions in this package implement the behaviour specified in
http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
Copyright (c) 2011, Open Knowledge Foundation Ltd.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
Neither the name of the Open Knowledge Foundation Ltd. nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package goautoneg
import (
"sort"
"strconv"
"strings"
)
// Structure to represent a clause in an HTTP Accept Header
type Accept struct {
Type, SubType string
Q float64
Params map[string]string
}
// For internal use, so that we can use the sort interface
type accept_slice []Accept
func (accept accept_slice) Len() int {
slice := []Accept(accept)
return len(slice)
}
func (accept accept_slice) Less(i, j int) bool {
slice := []Accept(accept)
ai, aj := slice[i], slice[j]
if ai.Q > aj.Q {
return true
}
if ai.Type != "*" && aj.Type == "*" {
return true
}
if ai.SubType != "*" && aj.SubType == "*" {
return true
}
return false
}
func (accept accept_slice) Swap(i, j int) {
slice := []Accept(accept)
slice[i], slice[j] = slice[j], slice[i]
}
// Parse an Accept Header string returning a sorted list
// of clauses
func ParseAccept(header string) (accept []Accept) {
parts := strings.Split(header, ",")
accept = make([]Accept, 0, len(parts))
for _, part := range parts {
part := strings.Trim(part, " ")
a := Accept{}
a.Params = make(map[string]string)
a.Q = 1.0
mrp := strings.Split(part, ";")
media_range := mrp[0]
sp := strings.Split(media_range, "/")
a.Type = strings.Trim(sp[0], " ")
switch {
case len(sp) == 1 && a.Type == "*":
a.SubType = "*"
case len(sp) == 2:
a.SubType = strings.Trim(sp[1], " ")
default:
continue
}
if len(mrp) == 1 {
accept = append(accept, a)
continue
}
for _, param := range mrp[1:] {
sp := strings.SplitN(param, "=", 2)
if len(sp) != 2 {
continue
}
token := strings.Trim(sp[0], " ")
if token == "q" {
a.Q, _ = strconv.ParseFloat(sp[1], 32)
} else {
a.Params[token] = strings.Trim(sp[1], " ")
}
}
accept = append(accept, a)
}
slice := accept_slice(accept)
sort.Sort(slice)
return
}
// Negotiate the most appropriate content_type given the accept header
// and a list of alternatives.
func Negotiate(header string, alternatives []string) (content_type string) {
asp := make([][]string, 0, len(alternatives))
for _, ctype := range alternatives {
asp = append(asp, strings.SplitN(ctype, "/", 2))
}
for _, clause := range ParseAccept(header) {
for i, ctsp := range asp {
if clause.Type == ctsp[0] && clause.SubType == ctsp[1] {
content_type = alternatives[i]
return
}
if clause.Type == ctsp[0] && clause.SubType == "*" {
content_type = alternatives[i]
return
}
if clause.Type == "*" && clause.SubType == "*" {
content_type = alternatives[i]
return
}
}
}
return
}

View File

@ -0,0 +1,33 @@
package goautoneg
import (
"testing"
)
var chrome = "application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5"
func TestParseAccept(t *testing.T) {
alternatives := []string{"text/html", "image/png"}
content_type := Negotiate(chrome, alternatives)
if content_type != "image/png" {
t.Errorf("got %s expected image/png", content_type)
}
alternatives = []string{"text/html", "text/plain", "text/n3"}
content_type = Negotiate(chrome, alternatives)
if content_type != "text/html" {
t.Errorf("got %s expected text/html", content_type)
}
alternatives = []string{"text/n3", "text/plain"}
content_type = Negotiate(chrome, alternatives)
if content_type != "text/plain" {
t.Errorf("got %s expected text/plain", content_type)
}
alternatives = []string{"text/n3", "application/rdf+xml"}
content_type = Negotiate(chrome, alternatives)
if content_type != "text/n3" {
t.Errorf("got %s expected text/n3", content_type)
}
}

View File

@ -0,0 +1,63 @@
package quantile
import (
"testing"
)
func BenchmarkInsertTargeted(b *testing.B) {
b.ReportAllocs()
s := NewTargeted(Targets)
b.ResetTimer()
for i := float64(0); i < float64(b.N); i++ {
s.Insert(i)
}
}
func BenchmarkInsertTargetedSmallEpsilon(b *testing.B) {
s := NewTargeted(TargetsSmallEpsilon)
b.ResetTimer()
for i := float64(0); i < float64(b.N); i++ {
s.Insert(i)
}
}
func BenchmarkInsertBiased(b *testing.B) {
s := NewLowBiased(0.01)
b.ResetTimer()
for i := float64(0); i < float64(b.N); i++ {
s.Insert(i)
}
}
func BenchmarkInsertBiasedSmallEpsilon(b *testing.B) {
s := NewLowBiased(0.0001)
b.ResetTimer()
for i := float64(0); i < float64(b.N); i++ {
s.Insert(i)
}
}
func BenchmarkQuery(b *testing.B) {
s := NewTargeted(Targets)
for i := float64(0); i < 1e6; i++ {
s.Insert(i)
}
b.ResetTimer()
n := float64(b.N)
for i := float64(0); i < n; i++ {
s.Query(i / n)
}
}
func BenchmarkQuerySmallEpsilon(b *testing.B) {
s := NewTargeted(TargetsSmallEpsilon)
for i := float64(0); i < 1e6; i++ {
s.Insert(i)
}
b.ResetTimer()
n := float64(b.N)
for i := float64(0); i < n; i++ {
s.Query(i / n)
}
}

View File

@ -0,0 +1,112 @@
// +build go1.1
package quantile_test
import (
"bufio"
"fmt"
"github.com/bmizerany/perks/quantile"
"log"
"os"
"strconv"
"time"
)
func Example_simple() {
ch := make(chan float64)
go sendFloats(ch)
// Compute the 50th, 90th, and 99th percentile.
q := quantile.NewTargeted(0.50, 0.90, 0.99)
for v := range ch {
q.Insert(v)
}
fmt.Println("perc50:", q.Query(0.50))
fmt.Println("perc90:", q.Query(0.90))
fmt.Println("perc99:", q.Query(0.99))
fmt.Println("count:", q.Count())
// Output:
// perc50: 5
// perc90: 14
// perc99: 40
// count: 2388
}
func Example_mergeMultipleStreams() {
// Scenario:
// We have multiple database shards. On each shard, there is a process
// collecting query response times from the database logs and inserting
// them into a Stream (created via NewTargeted(0.90)), much like the
// Simple example. These processes expose a network interface for us to
// ask them to serialize and send us the results of their
// Stream.Samples so we may Merge and Query them.
//
// NOTES:
// * These sample sets are small, allowing us to get them
// across the network much faster than sending the entire list of data
// points.
//
// * For this to work correctly, we must supply the same quantiles
// a priori the process collecting the samples supplied to NewTargeted,
// even if we do not plan to query them all here.
ch := make(chan quantile.Samples)
getDBQuerySamples(ch)
q := quantile.NewTargeted(0.90)
for samples := range ch {
q.Merge(samples)
}
fmt.Println("perc90:", q.Query(0.90))
}
func Example_window() {
// Scenario: We want the 90th, 95th, and 99th percentiles for each
// minute.
ch := make(chan float64)
go sendStreamValues(ch)
tick := time.NewTicker(1 * time.Minute)
q := quantile.NewTargeted(0.90, 0.95, 0.99)
for {
select {
case t := <-tick.C:
flushToDB(t, q.Samples())
q.Reset()
case v := <-ch:
q.Insert(v)
}
}
}
func sendStreamValues(ch chan float64) {
// Use your imagination
}
func flushToDB(t time.Time, samples quantile.Samples) {
// Use your imagination
}
// This is a stub for the above example. In reality this would hit the remote
// servers via http or something like it.
func getDBQuerySamples(ch chan quantile.Samples) {}
func sendFloats(ch chan<- float64) {
f, err := os.Open("exampledata.txt")
if err != nil {
log.Fatal(err)
}
sc := bufio.NewScanner(f)
for sc.Scan() {
b := sc.Bytes()
v, err := strconv.ParseFloat(string(b), 64)
if err != nil {
log.Fatal(err)
}
ch <- v
}
if sc.Err() != nil {
log.Fatal(sc.Err())
}
close(ch)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,292 @@
// Package quantile computes approximate quantiles over an unbounded data
// stream within low memory and CPU bounds.
//
// A small amount of accuracy is traded to achieve the above properties.
//
// Multiple streams can be merged before calling Query to generate a single set
// of results. This is meaningful when the streams represent the same type of
// data. See Merge and Samples.
//
// For more detailed information about the algorithm used, see:
//
// Effective Computation of Biased Quantiles over Data Streams
//
// http://www.cs.rutgers.edu/~muthu/bquant.pdf
package quantile
import (
"math"
"sort"
)
// Sample holds an observed value and meta information for compression. JSON
// tags have been added for convenience.
type Sample struct {
Value float64 `json:",string"`
Width float64 `json:",string"`
Delta float64 `json:",string"`
}
// Samples represents a slice of samples. It implements sort.Interface.
type Samples []Sample
func (a Samples) Len() int { return len(a) }
func (a Samples) Less(i, j int) bool { return a[i].Value < a[j].Value }
func (a Samples) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
type invariant func(s *stream, r float64) float64
// NewLowBiased returns an initialized Stream for low-biased quantiles
// (e.g. 0.01, 0.1, 0.5) where the needed quantiles are not known a priori, but
// error guarantees can still be given even for the lower ranks of the data
// distribution.
//
// The provided epsilon is a relative error, i.e. the true quantile of a value
// returned by a query is guaranteed to be within (1±Epsilon)*Quantile.
//
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error
// properties.
func NewLowBiased(epsilon float64) *Stream {
ƒ := func(s *stream, r float64) float64 {
return 2 * epsilon * r
}
return newStream(ƒ)
}
// NewHighBiased returns an initialized Stream for high-biased quantiles
// (e.g. 0.01, 0.1, 0.5) where the needed quantiles are not known a priori, but
// error guarantees can still be given even for the higher ranks of the data
// distribution.
//
// The provided epsilon is a relative error, i.e. the true quantile of a value
// returned by a query is guaranteed to be within 1-(1±Epsilon)*(1-Quantile).
//
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error
// properties.
func NewHighBiased(epsilon float64) *Stream {
ƒ := func(s *stream, r float64) float64 {
return 2 * epsilon * (s.n - r)
}
return newStream(ƒ)
}
// NewTargeted returns an initialized Stream concerned with a particular set of
// quantile values that are supplied a priori. Knowing these a priori reduces
// space and computation time. The targets map maps the desired quantiles to
// their absolute errors, i.e. the true quantile of a value returned by a query
// is guaranteed to be within (Quantile±Epsilon).
//
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error properties.
func NewTargeted(targets map[float64]float64) *Stream {
ƒ := func(s *stream, r float64) float64 {
var m = math.MaxFloat64
var f float64
for quantile, epsilon := range targets {
if quantile*s.n <= r {
f = (2 * epsilon * r) / quantile
} else {
f = (2 * epsilon * (s.n - r)) / (1 - quantile)
}
if f < m {
m = f
}
}
return m
}
return newStream(ƒ)
}
// Stream computes quantiles for a stream of float64s. It is not thread-safe by
// design. Take care when using across multiple goroutines.
type Stream struct {
*stream
b Samples
sorted bool
}
func newStream(ƒ invariant) *Stream {
x := &stream{ƒ: ƒ}
return &Stream{x, make(Samples, 0, 500), true}
}
// Insert inserts v into the stream.
func (s *Stream) Insert(v float64) {
s.insert(Sample{Value: v, Width: 1})
}
func (s *Stream) insert(sample Sample) {
s.b = append(s.b, sample)
s.sorted = false
if len(s.b) == cap(s.b) {
s.flush()
}
}
// Query returns the computed qth percentiles value. If s was created with
// NewTargeted, and q is not in the set of quantiles provided a priori, Query
// will return an unspecified result.
func (s *Stream) Query(q float64) float64 {
if !s.flushed() {
// Fast path when there hasn't been enough data for a flush;
// this also yields better accuracy for small sets of data.
l := len(s.b)
if l == 0 {
return 0
}
i := int(float64(l) * q)
if i > 0 {
i -= 1
}
s.maybeSort()
return s.b[i].Value
}
s.flush()
return s.stream.query(q)
}
// Merge merges samples into the underlying streams samples. This is handy when
// merging multiple streams from separate threads, database shards, etc.
//
// ATTENTION: This method is broken and does not yield correct results. The
// underlying algorithm is not capable of merging streams correctly.
func (s *Stream) Merge(samples Samples) {
sort.Sort(samples)
s.stream.merge(samples)
}
// Reset reinitializes and clears the list reusing the samples buffer memory.
func (s *Stream) Reset() {
s.stream.reset()
s.b = s.b[:0]
}
// Samples returns stream samples held by s.
func (s *Stream) Samples() Samples {
if !s.flushed() {
return s.b
}
s.flush()
return s.stream.samples()
}
// Count returns the total number of samples observed in the stream
// since initialization.
func (s *Stream) Count() int {
return len(s.b) + s.stream.count()
}
func (s *Stream) flush() {
s.maybeSort()
s.stream.merge(s.b)
s.b = s.b[:0]
}
func (s *Stream) maybeSort() {
if !s.sorted {
s.sorted = true
sort.Sort(s.b)
}
}
func (s *Stream) flushed() bool {
return len(s.stream.l) > 0
}
type stream struct {
n float64
l []Sample
ƒ invariant
}
func (s *stream) reset() {
s.l = s.l[:0]
s.n = 0
}
func (s *stream) insert(v float64) {
s.merge(Samples{{v, 1, 0}})
}
func (s *stream) merge(samples Samples) {
// TODO(beorn7): This tries to merge not only individual samples, but
// whole summaries. The paper doesn't mention merging summaries at
// all. Unittests show that the merging is inaccurate. Find out how to
// do merges properly.
var r float64
i := 0
for _, sample := range samples {
for ; i < len(s.l); i++ {
c := s.l[i]
if c.Value > sample.Value {
// Insert at position i.
s.l = append(s.l, Sample{})
copy(s.l[i+1:], s.l[i:])
s.l[i] = Sample{
sample.Value,
sample.Width,
math.Max(sample.Delta, math.Floor(s.ƒ(s, r))-1),
// TODO(beorn7): How to calculate delta correctly?
}
i++
goto inserted
}
r += c.Width
}
s.l = append(s.l, Sample{sample.Value, sample.Width, 0})
i++
inserted:
s.n += sample.Width
r += sample.Width
}
s.compress()
}
func (s *stream) count() int {
return int(s.n)
}
func (s *stream) query(q float64) float64 {
t := math.Ceil(q * s.n)
t += math.Ceil(s.ƒ(s, t) / 2)
p := s.l[0]
var r float64
for _, c := range s.l[1:] {
r += p.Width
if r+c.Width+c.Delta > t {
return p.Value
}
p = c
}
return p.Value
}
func (s *stream) compress() {
if len(s.l) < 2 {
return
}
x := s.l[len(s.l)-1]
xi := len(s.l) - 1
r := s.n - 1 - x.Width
for i := len(s.l) - 2; i >= 0; i-- {
c := s.l[i]
if c.Width+x.Width+x.Delta <= s.ƒ(s, r) {
x.Width += c.Width
s.l[xi] = x
// Remove element at i.
copy(s.l[i:], s.l[i+1:])
s.l = s.l[:len(s.l)-1]
xi -= 1
} else {
x = c
xi = i
}
r -= c.Width
}
}
func (s *stream) samples() Samples {
samples := make(Samples, len(s.l))
copy(samples, s.l)
return samples
}

View File

@ -0,0 +1,185 @@
package quantile
import (
"math"
"math/rand"
"sort"
"testing"
)
var (
Targets = map[float64]float64{
0.01: 0.001,
0.10: 0.01,
0.50: 0.05,
0.90: 0.01,
0.99: 0.001,
}
TargetsSmallEpsilon = map[float64]float64{
0.01: 0.0001,
0.10: 0.001,
0.50: 0.005,
0.90: 0.001,
0.99: 0.0001,
}
LowQuantiles = []float64{0.01, 0.1, 0.5}
HighQuantiles = []float64{0.99, 0.9, 0.5}
)
const RelativeEpsilon = 0.01
func verifyPercsWithAbsoluteEpsilon(t *testing.T, a []float64, s *Stream) {
sort.Float64s(a)
for quantile, epsilon := range Targets {
n := float64(len(a))
k := int(quantile * n)
lower := int((quantile - epsilon) * n)
if lower < 1 {
lower = 1
}
upper := int(math.Ceil((quantile + epsilon) * n))
if upper > len(a) {
upper = len(a)
}
w, min, max := a[k-1], a[lower-1], a[upper-1]
if g := s.Query(quantile); g < min || g > max {
t.Errorf("q=%f: want %v [%f,%f], got %v", quantile, w, min, max, g)
}
}
}
func verifyLowPercsWithRelativeEpsilon(t *testing.T, a []float64, s *Stream) {
sort.Float64s(a)
for _, qu := range LowQuantiles {
n := float64(len(a))
k := int(qu * n)
lowerRank := int((1 - RelativeEpsilon) * qu * n)
upperRank := int(math.Ceil((1 + RelativeEpsilon) * qu * n))
w, min, max := a[k-1], a[lowerRank-1], a[upperRank-1]
if g := s.Query(qu); g < min || g > max {
t.Errorf("q=%f: want %v [%f,%f], got %v", qu, w, min, max, g)
}
}
}
func verifyHighPercsWithRelativeEpsilon(t *testing.T, a []float64, s *Stream) {
sort.Float64s(a)
for _, qu := range HighQuantiles {
n := float64(len(a))
k := int(qu * n)
lowerRank := int((1 - (1+RelativeEpsilon)*(1-qu)) * n)
upperRank := int(math.Ceil((1 - (1-RelativeEpsilon)*(1-qu)) * n))
w, min, max := a[k-1], a[lowerRank-1], a[upperRank-1]
if g := s.Query(qu); g < min || g > max {
t.Errorf("q=%f: want %v [%f,%f], got %v", qu, w, min, max, g)
}
}
}
func populateStream(s *Stream) []float64 {
a := make([]float64, 0, 1e5+100)
for i := 0; i < cap(a); i++ {
v := rand.NormFloat64()
// Add 5% asymmetric outliers.
if i%20 == 0 {
v = v*v + 1
}
s.Insert(v)
a = append(a, v)
}
return a
}
func TestTargetedQuery(t *testing.T) {
rand.Seed(42)
s := NewTargeted(Targets)
a := populateStream(s)
verifyPercsWithAbsoluteEpsilon(t, a, s)
}
func TestLowBiasedQuery(t *testing.T) {
rand.Seed(42)
s := NewLowBiased(RelativeEpsilon)
a := populateStream(s)
verifyLowPercsWithRelativeEpsilon(t, a, s)
}
func TestHighBiasedQuery(t *testing.T) {
rand.Seed(42)
s := NewHighBiased(RelativeEpsilon)
a := populateStream(s)
verifyHighPercsWithRelativeEpsilon(t, a, s)
}
func TestTargetedMerge(t *testing.T) {
rand.Seed(42)
s1 := NewTargeted(Targets)
s2 := NewTargeted(Targets)
a := populateStream(s1)
a = append(a, populateStream(s2)...)
s1.Merge(s2.Samples())
verifyPercsWithAbsoluteEpsilon(t, a, s1)
}
func TestLowBiasedMerge(t *testing.T) {
rand.Seed(42)
s1 := NewLowBiased(RelativeEpsilon)
s2 := NewLowBiased(RelativeEpsilon)
a := populateStream(s1)
a = append(a, populateStream(s2)...)
s1.Merge(s2.Samples())
verifyLowPercsWithRelativeEpsilon(t, a, s2)
}
func TestHighBiasedMerge(t *testing.T) {
rand.Seed(42)
s1 := NewHighBiased(RelativeEpsilon)
s2 := NewHighBiased(RelativeEpsilon)
a := populateStream(s1)
a = append(a, populateStream(s2)...)
s1.Merge(s2.Samples())
verifyHighPercsWithRelativeEpsilon(t, a, s2)
}
func TestUncompressed(t *testing.T) {
q := NewTargeted(Targets)
for i := 100; i > 0; i-- {
q.Insert(float64(i))
}
if g := q.Count(); g != 100 {
t.Errorf("want count 100, got %d", g)
}
// Before compression, Query should have 100% accuracy.
for quantile := range Targets {
w := quantile * 100
if g := q.Query(quantile); g != w {
t.Errorf("want %f, got %f", w, g)
}
}
}
func TestUncompressedSamples(t *testing.T) {
q := NewTargeted(map[float64]float64{0.99: 0.001})
for i := 1; i <= 100; i++ {
q.Insert(float64(i))
}
if g := q.Samples().Len(); g != 100 {
t.Errorf("want count 100, got %d", g)
}
}
func TestUncompressedOne(t *testing.T) {
q := NewTargeted(map[float64]float64{0.99: 0.01})
q.Insert(3.14)
if g := q.Query(0.90); g != 3.14 {
t.Error("want PI, got", g)
}
}
func TestDefaults(t *testing.T) {
if g := NewTargeted(map[float64]float64{0.99: 0.001}).Query(0.99); g != 0 {
t.Errorf("want 0, got %f", g)
}
}

View File

@ -0,0 +1,74 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"errors"
"fmt"
"mime"
"net/http"
)
// ProcessorForRequestHeader interprets a HTTP request header to determine
// what Processor should be used for the given input. If no acceptable
// Processor can be found, an error is returned.
func ProcessorForRequestHeader(header http.Header) (Processor, error) {
if header == nil {
return nil, errors.New("received illegal and nil header")
}
mediatype, params, err := mime.ParseMediaType(header.Get("Content-Type"))
if err != nil {
return nil, fmt.Errorf("invalid Content-Type header %q: %s", header.Get("Content-Type"), err)
}
switch mediatype {
case "application/vnd.google.protobuf":
if params["proto"] != "io.prometheus.client.MetricFamily" {
return nil, fmt.Errorf("unrecognized protocol message %s", params["proto"])
}
if params["encoding"] != "delimited" {
return nil, fmt.Errorf("unsupported encoding %s", params["encoding"])
}
return MetricFamilyProcessor, nil
case "text/plain":
switch params["version"] {
case "0.0.4":
return Processor004, nil
case "":
// Fallback: most recent version.
return Processor004, nil
default:
return nil, fmt.Errorf("unrecognized API version %s", params["version"])
}
case "application/json":
var prometheusAPIVersion string
if params["schema"] == "prometheus/telemetry" && params["version"] != "" {
prometheusAPIVersion = params["version"]
} else {
prometheusAPIVersion = header.Get("X-Prometheus-API-Version")
}
switch prometheusAPIVersion {
case "0.0.2":
return Processor002, nil
case "0.0.1":
return Processor001, nil
default:
return nil, fmt.Errorf("unrecognized API version %s", prometheusAPIVersion)
}
default:
return nil, fmt.Errorf("unsupported media type %q, expected %q", mediatype, "application/json")
}
}

View File

@ -0,0 +1,126 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"errors"
"net/http"
"testing"
)
func testDiscriminatorHTTPHeader(t testing.TB) {
var scenarios = []struct {
input map[string]string
output Processor
err error
}{
{
output: nil,
err: errors.New("received illegal and nil header"),
},
{
input: map[string]string{"Content-Type": "application/json", "X-Prometheus-API-Version": "0.0.0"},
output: nil,
err: errors.New("unrecognized API version 0.0.0"),
},
{
input: map[string]string{"Content-Type": "application/json", "X-Prometheus-API-Version": "0.0.1"},
output: Processor001,
err: nil,
},
{
input: map[string]string{"Content-Type": `application/json; schema="prometheus/telemetry"; version=0.0.0`},
output: nil,
err: errors.New("unrecognized API version 0.0.0"),
},
{
input: map[string]string{"Content-Type": `application/json; schema="prometheus/telemetry"; version=0.0.1`},
output: Processor001,
err: nil,
},
{
input: map[string]string{"Content-Type": `application/json; schema="prometheus/telemetry"; version=0.0.2`},
output: Processor002,
err: nil,
},
{
input: map[string]string{"Content-Type": `application/vnd.google.protobuf; proto="io.prometheus.client.MetricFamily"; encoding="delimited"`},
output: MetricFamilyProcessor,
err: nil,
},
{
input: map[string]string{"Content-Type": `application/vnd.google.protobuf; proto="illegal"; encoding="delimited"`},
output: nil,
err: errors.New("unrecognized protocol message illegal"),
},
{
input: map[string]string{"Content-Type": `application/vnd.google.protobuf; proto="io.prometheus.client.MetricFamily"; encoding="illegal"`},
output: nil,
err: errors.New("unsupported encoding illegal"),
},
{
input: map[string]string{"Content-Type": `text/plain; version=0.0.4`},
output: Processor004,
err: nil,
},
{
input: map[string]string{"Content-Type": `text/plain`},
output: Processor004,
err: nil,
},
{
input: map[string]string{"Content-Type": `text/plain; version=0.0.3`},
output: nil,
err: errors.New("unrecognized API version 0.0.3"),
},
}
for i, scenario := range scenarios {
var header http.Header
if len(scenario.input) > 0 {
header = http.Header{}
}
for key, value := range scenario.input {
header.Add(key, value)
}
actual, err := ProcessorForRequestHeader(header)
if scenario.err != err {
if scenario.err != nil && err != nil {
if scenario.err.Error() != err.Error() {
t.Errorf("%d. expected %s, got %s", i, scenario.err, err)
}
} else if scenario.err != nil || err != nil {
t.Errorf("%d. expected %s, got %s", i, scenario.err, err)
}
}
if scenario.output != actual {
t.Errorf("%d. expected %s, got %s", i, scenario.output, actual)
}
}
}
func TestDiscriminatorHTTPHeader(t *testing.T) {
testDiscriminatorHTTPHeader(t)
}
func BenchmarkDiscriminatorHTTPHeader(b *testing.B) {
for i := 0; i < b.N; i++ {
testDiscriminatorHTTPHeader(b)
}
}

View File

@ -0,0 +1,15 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package extraction decodes Prometheus clients' data streams for consumers.
package extraction

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,79 @@
[
{
"baseLabels": {
"__name__": "rpc_calls_total",
"job": "batch_job"
},
"docstring": "RPC calls.",
"metric": {
"type": "counter",
"value": [
{
"labels": {
"service": "zed"
},
"value": 25
},
{
"labels": {
"service": "bar"
},
"value": 25
},
{
"labels": {
"service": "foo"
},
"value": 25
}
]
}
},
{
"baseLabels": {
"__name__": "rpc_latency_microseconds"
},
"docstring": "RPC latency.",
"metric": {
"type": "histogram",
"value": [
{
"labels": {
"service": "foo"
},
"value": {
"0.010000": 15.890724674774395,
"0.050000": 15.890724674774395,
"0.500000": 84.63044031436561,
"0.900000": 160.21100853053224,
"0.990000": 172.49828748957728
}
},
{
"labels": {
"service": "zed"
},
"value": {
"0.010000": 0.0459814091918713,
"0.050000": 0.0459814091918713,
"0.500000": 0.6120456642749681,
"0.900000": 1.355915069887731,
"0.990000": 1.772733213161236
}
},
{
"labels": {
"service": "bar"
},
"value": {
"0.010000": 78.48563317257356,
"0.050000": 78.48563317257356,
"0.500000": 97.31798360385088,
"0.900000": 109.89202084295582,
"0.990000": 109.99626121011262
}
}
]
}
}
]

View File

@ -0,0 +1,295 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"fmt"
"io"
dto "github.com/prometheus/client_model/go"
"github.com/matttproud/golang_protobuf_extensions/ext"
"github.com/prometheus/client_golang/model"
)
type metricFamilyProcessor struct{}
// MetricFamilyProcessor decodes varint encoded record length-delimited streams
// of io.prometheus.client.MetricFamily.
//
// See http://godoc.org/github.com/matttproud/golang_protobuf_extensions/ext for
// more details.
var MetricFamilyProcessor = &metricFamilyProcessor{}
func (m *metricFamilyProcessor) ProcessSingle(i io.Reader, out Ingester, o *ProcessOptions) error {
family := &dto.MetricFamily{}
for {
family.Reset()
if _, err := ext.ReadDelimited(i, family); err != nil {
if err == io.EOF {
return nil
}
return err
}
if err := extractMetricFamily(out, o, family); err != nil {
return err
}
}
}
func extractMetricFamily(out Ingester, o *ProcessOptions, family *dto.MetricFamily) error {
switch family.GetType() {
case dto.MetricType_COUNTER:
if err := extractCounter(out, o, family); err != nil {
return err
}
case dto.MetricType_GAUGE:
if err := extractGauge(out, o, family); err != nil {
return err
}
case dto.MetricType_SUMMARY:
if err := extractSummary(out, o, family); err != nil {
return err
}
case dto.MetricType_UNTYPED:
if err := extractUntyped(out, o, family); err != nil {
return err
}
case dto.MetricType_HISTOGRAM:
if err := extractHistogram(out, o, family); err != nil {
return err
}
}
return nil
}
func extractCounter(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Counter == nil {
continue
}
sample := new(model.Sample)
samples = append(samples, sample)
if m.TimestampMs != nil {
sample.Timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
} else {
sample.Timestamp = o.Timestamp
}
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName())
sample.Value = model.SampleValue(m.Counter.GetValue())
}
return out.Ingest(samples)
}
func extractGauge(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Gauge == nil {
continue
}
sample := new(model.Sample)
samples = append(samples, sample)
if m.TimestampMs != nil {
sample.Timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
} else {
sample.Timestamp = o.Timestamp
}
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName())
sample.Value = model.SampleValue(m.Gauge.GetValue())
}
return out.Ingest(samples)
}
func extractSummary(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Summary == nil {
continue
}
timestamp := o.Timestamp
if m.TimestampMs != nil {
timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
}
for _, q := range m.Summary.Quantile {
sample := new(model.Sample)
samples = append(samples, sample)
sample.Timestamp = timestamp
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
// BUG(matt): Update other names to "quantile".
metric[model.LabelName("quantile")] = model.LabelValue(fmt.Sprint(q.GetQuantile()))
metric[model.MetricNameLabel] = model.LabelValue(f.GetName())
sample.Value = model.SampleValue(q.GetValue())
}
if m.Summary.SampleSum != nil {
sum := new(model.Sample)
sum.Timestamp = timestamp
metric := model.Metric{}
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_sum")
sum.Metric = metric
sum.Value = model.SampleValue(m.Summary.GetSampleSum())
samples = append(samples, sum)
}
if m.Summary.SampleCount != nil {
count := new(model.Sample)
count.Timestamp = timestamp
metric := model.Metric{}
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_count")
count.Metric = metric
count.Value = model.SampleValue(m.Summary.GetSampleCount())
samples = append(samples, count)
}
}
return out.Ingest(samples)
}
func extractUntyped(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Untyped == nil {
continue
}
sample := new(model.Sample)
samples = append(samples, sample)
if m.TimestampMs != nil {
sample.Timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
} else {
sample.Timestamp = o.Timestamp
}
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName())
sample.Value = model.SampleValue(m.Untyped.GetValue())
}
return out.Ingest(samples)
}
func extractHistogram(out Ingester, o *ProcessOptions, f *dto.MetricFamily) error {
samples := make(model.Samples, 0, len(f.Metric))
for _, m := range f.Metric {
if m.Histogram == nil {
continue
}
timestamp := o.Timestamp
if m.TimestampMs != nil {
timestamp = model.TimestampFromUnixNano(*m.TimestampMs * 1000000)
}
for _, q := range m.Histogram.Bucket {
sample := new(model.Sample)
samples = append(samples, sample)
sample.Timestamp = timestamp
sample.Metric = model.Metric{}
metric := sample.Metric
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.LabelName("le")] = model.LabelValue(fmt.Sprint(q.GetUpperBound()))
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_bucket")
sample.Value = model.SampleValue(q.GetCumulativeCount())
}
// TODO: If +Inf bucket is missing, add it.
if m.Histogram.SampleSum != nil {
sum := new(model.Sample)
sum.Timestamp = timestamp
metric := model.Metric{}
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_sum")
sum.Metric = metric
sum.Value = model.SampleValue(m.Histogram.GetSampleSum())
samples = append(samples, sum)
}
if m.Histogram.SampleCount != nil {
count := new(model.Sample)
count.Timestamp = timestamp
metric := model.Metric{}
for _, p := range m.Label {
metric[model.LabelName(p.GetName())] = model.LabelValue(p.GetValue())
}
metric[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_count")
count.Metric = metric
count.Value = model.SampleValue(m.Histogram.GetSampleCount())
samples = append(samples, count)
}
}
return out.Ingest(samples)
}

View File

@ -0,0 +1,153 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"sort"
"strings"
"testing"
"github.com/prometheus/client_golang/model"
)
var testTime = model.Now()
type metricFamilyProcessorScenario struct {
in string
expected, actual []model.Samples
}
func (s *metricFamilyProcessorScenario) Ingest(samples model.Samples) error {
s.actual = append(s.actual, samples)
return nil
}
func (s *metricFamilyProcessorScenario) test(t *testing.T, set int) {
i := strings.NewReader(s.in)
o := &ProcessOptions{
Timestamp: testTime,
}
err := MetricFamilyProcessor.ProcessSingle(i, s, o)
if err != nil {
t.Fatalf("%d. got error: %s", set, err)
}
if len(s.expected) != len(s.actual) {
t.Fatalf("%d. expected length %d, got %d", set, len(s.expected), len(s.actual))
}
for i, expected := range s.expected {
sort.Sort(s.actual[i])
sort.Sort(expected)
if !expected.Equal(s.actual[i]) {
t.Errorf("%d.%d. expected %s, got %s", set, i, expected, s.actual[i])
}
}
}
func TestMetricFamilyProcessor(t *testing.T) {
scenarios := []metricFamilyProcessorScenario{
{
in: "",
},
{
in: "\x8f\x01\n\rrequest_count\x12\x12Number of requests\x18\x00\"0\n#\n\x0fsome_label_name\x12\x10some_label_value\x1a\t\t\x00\x00\x00\x00\x00\x00E\xc0\"6\n)\n\x12another_label_name\x12\x13another_label_value\x1a\t\t\x00\x00\x00\x00\x00\x00U@",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "some_label_name": "some_label_value"},
Value: -42,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "another_label_name": "another_label_value"},
Value: 84,
Timestamp: testTime,
},
},
},
},
{
in: "\xb9\x01\n\rrequest_count\x12\x12Number of requests\x18\x02\"O\n#\n\x0fsome_label_name\x12\x10some_label_value\"(\x1a\x12\t\xaeG\xe1z\x14\xae\xef?\x11\x00\x00\x00\x00\x00\x00E\xc0\x1a\x12\t+\x87\x16\xd9\xce\xf7\xef?\x11\x00\x00\x00\x00\x00\x00U\xc0\"A\n)\n\x12another_label_name\x12\x13another_label_value\"\x14\x1a\x12\t\x00\x00\x00\x00\x00\x00\xe0?\x11\x00\x00\x00\x00\x00\x00$@",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "some_label_name": "some_label_value", "quantile": "0.99"},
Value: -42,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "some_label_name": "some_label_value", "quantile": "0.999"},
Value: -84,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_count", "another_label_name": "another_label_value", "quantile": "0.5"},
Value: 10,
Timestamp: testTime,
},
},
},
},
{
in: "\x8d\x01\n\x1drequest_duration_microseconds\x12\x15The response latency.\x18\x04\"S:Q\b\x85\x15\x11\xcd\xcc\xccL\x8f\xcb:A\x1a\v\b{\x11\x00\x00\x00\x00\x00\x00Y@\x1a\f\b\x9c\x03\x11\x00\x00\x00\x00\x00\x00^@\x1a\f\b\xd0\x04\x11\x00\x00\x00\x00\x00\x00b@\x1a\f\b\xf4\v\x11\x9a\x99\x99\x99\x99\x99e@\x1a\f\b\x85\x15\x11\x00\x00\x00\x00\x00\x00\xf0\u007f",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "100"},
Value: 123,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "120"},
Value: 412,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "144"},
Value: 592,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "172.8"},
Value: 1524,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_bucket", "le": "+Inf"},
Value: 2693,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_sum"},
Value: 1756047.3,
Timestamp: testTime,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "request_duration_microseconds_count"},
Value: 2693,
Timestamp: testTime,
},
},
},
},
}
for i, scenario := range scenarios {
scenario.test(t, i)
}
}

View File

@ -0,0 +1,84 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"io"
"time"
"github.com/prometheus/client_golang/model"
)
// ProcessOptions dictates how the interpreted stream should be rendered for
// consumption.
type ProcessOptions struct {
// Timestamp is added to each value from the stream that has no explicit
// timestamp set.
Timestamp model.Timestamp
}
// Ingester consumes result streams in whatever way is desired by the user.
type Ingester interface {
Ingest(model.Samples) error
}
// Processor is responsible for decoding the actual message responses from
// stream into a format that can be consumed with the end result written
// to the results channel.
type Processor interface {
// ProcessSingle treats the input as a single self-contained message body and
// transforms it accordingly. It has no support for streaming.
ProcessSingle(in io.Reader, out Ingester, o *ProcessOptions) error
}
// Helper function to convert map[string]string into LabelSet.
//
// NOTE: This should be deleted when support for go 1.0.3 is removed; 1.1 is
// smart enough to unmarshal JSON objects into LabelSet directly.
func labelSet(labels map[string]string) model.LabelSet {
labelset := make(model.LabelSet, len(labels))
for k, v := range labels {
labelset[model.LabelName(k)] = model.LabelValue(v)
}
return labelset
}
// A basic interface only useful in testing contexts for dispensing the time
// in a controlled manner.
type instantProvider interface {
// The current instant.
Now() time.Time
}
// Clock is a simple means for fluently wrapping around standard Go timekeeping
// mechanisms to enhance testability without compromising code readability.
//
// It is sufficient for use on bare initialization. A provider should be
// set only for test contexts. When not provided, it emits the current
// system time.
type clock struct {
// The underlying means through which time is provided, if supplied.
Provider instantProvider
}
// Emit the current instant.
func (t *clock) Now() time.Time {
if t.Provider == nil {
return time.Now()
}
return t.Provider.Now()
}

View File

@ -0,0 +1,127 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"github.com/prometheus/client_golang/model"
)
const (
baseLabels001 = "baseLabels"
counter001 = "counter"
docstring001 = "docstring"
gauge001 = "gauge"
histogram001 = "histogram"
labels001 = "labels"
metric001 = "metric"
type001 = "type"
value001 = "value"
percentile001 = "percentile"
)
// Processor001 is responsible for decoding payloads from protocol version
// 0.0.1.
var Processor001 = &processor001{}
// processor001 is responsible for handling API version 0.0.1.
type processor001 struct{}
// entity001 represents a the JSON structure that 0.0.1 uses.
type entity001 []struct {
BaseLabels map[string]string `json:"baseLabels"`
Docstring string `json:"docstring"`
Metric struct {
MetricType string `json:"type"`
Value []struct {
Labels map[string]string `json:"labels"`
Value interface{} `json:"value"`
} `json:"value"`
} `json:"metric"`
}
func (p *processor001) ProcessSingle(in io.Reader, out Ingester, o *ProcessOptions) error {
// TODO(matt): Replace with plain-jane JSON unmarshalling.
buffer, err := ioutil.ReadAll(in)
if err != nil {
return err
}
entities := entity001{}
if err = json.Unmarshal(buffer, &entities); err != nil {
return err
}
// TODO(matt): This outer loop is a great basis for parallelization.
pendingSamples := model.Samples{}
for _, entity := range entities {
for _, value := range entity.Metric.Value {
labels := labelSet(entity.BaseLabels).Merge(labelSet(value.Labels))
switch entity.Metric.MetricType {
case gauge001, counter001:
sampleValue, ok := value.Value.(float64)
if !ok {
return fmt.Errorf("could not convert value from %s %s to float64", entity, value)
}
pendingSamples = append(pendingSamples, &model.Sample{
Metric: model.Metric(labels),
Timestamp: o.Timestamp,
Value: model.SampleValue(sampleValue),
})
break
case histogram001:
sampleValue, ok := value.Value.(map[string]interface{})
if !ok {
return fmt.Errorf("could not convert value from %q to a map[string]interface{}", value.Value)
}
for percentile, percentileValue := range sampleValue {
individualValue, ok := percentileValue.(float64)
if !ok {
return fmt.Errorf("could not convert value from %q to a float64", percentileValue)
}
childMetric := make(map[model.LabelName]model.LabelValue, len(labels)+1)
for k, v := range labels {
childMetric[k] = v
}
childMetric[model.LabelName(percentile001)] = model.LabelValue(percentile)
pendingSamples = append(pendingSamples, &model.Sample{
Metric: model.Metric(childMetric),
Timestamp: o.Timestamp,
Value: model.SampleValue(individualValue),
})
}
break
}
}
}
if len(pendingSamples) > 0 {
return out.Ingest(pendingSamples)
}
return nil
}

View File

@ -0,0 +1,186 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"errors"
"os"
"path"
"sort"
"testing"
"github.com/prometheus/prometheus/utility/test"
"github.com/prometheus/client_golang/model"
)
var test001Time = model.Now()
type testProcessor001ProcessScenario struct {
in string
expected, actual []model.Samples
err error
}
func (s *testProcessor001ProcessScenario) Ingest(samples model.Samples) error {
s.actual = append(s.actual, samples)
return nil
}
func (s *testProcessor001ProcessScenario) test(t testing.TB, set int) {
reader, err := os.Open(path.Join("fixtures", s.in))
if err != nil {
t.Fatalf("%d. couldn't open scenario input file %s: %s", set, s.in, err)
}
options := &ProcessOptions{
Timestamp: test001Time,
}
if err := Processor001.ProcessSingle(reader, s, options); !test.ErrorEqual(s.err, err) {
t.Fatalf("%d. expected err of %s, got %s", set, s.err, err)
}
if len(s.actual) != len(s.expected) {
t.Fatalf("%d. expected output length of %d, got %d", set, len(s.expected), len(s.actual))
}
for i, expected := range s.expected {
sort.Sort(s.actual[i])
sort.Sort(expected)
if !expected.Equal(s.actual[i]) {
t.Errorf("%d.%d. expected %s, got %s", set, i, expected, s.actual[i])
}
}
}
func testProcessor001Process(t testing.TB) {
var scenarios = []testProcessor001ProcessScenario{
{
in: "empty.json",
err: errors.New("unexpected end of JSON input"),
},
{
in: "test0_0_1-0_0_2.json",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{"service": "zed", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"service": "bar", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"service": "foo", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.0459814091918713,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 78.48563317257356,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 15.890724674774395,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.0459814091918713,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 78.48563317257356,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 15.890724674774395,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.6120456642749681,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 97.31798360385088,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 84.63044031436561,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 1.355915069887731,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 109.89202084295582,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 160.21100853053224,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 1.772733213161236,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 109.99626121011262,
Timestamp: test001Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 172.49828748957728,
Timestamp: test001Time,
},
},
},
},
}
for i, scenario := range scenarios {
scenario.test(t, i)
}
}
func TestProcessor001Process(t *testing.T) {
testProcessor001Process(t)
}
func BenchmarkProcessor001Process(b *testing.B) {
for i := 0; i < b.N; i++ {
testProcessor001Process(b)
}
}

View File

@ -0,0 +1,106 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"encoding/json"
"fmt"
"io"
"github.com/prometheus/client_golang/model"
)
// Processor002 is responsible for decoding payloads from protocol version
// 0.0.2.
var Processor002 = &processor002{}
type histogram002 struct {
Labels map[string]string `json:"labels"`
Values map[string]model.SampleValue `json:"value"`
}
type counter002 struct {
Labels map[string]string `json:"labels"`
Value model.SampleValue `json:"value"`
}
type processor002 struct{}
func (p *processor002) ProcessSingle(in io.Reader, out Ingester, o *ProcessOptions) error {
// Processor for telemetry schema version 0.0.2.
// container for telemetry data
var entities []struct {
BaseLabels map[string]string `json:"baseLabels"`
Docstring string `json:"docstring"`
Metric struct {
Type string `json:"type"`
Values json.RawMessage `json:"value"`
} `json:"metric"`
}
if err := json.NewDecoder(in).Decode(&entities); err != nil {
return err
}
pendingSamples := model.Samples{}
for _, entity := range entities {
switch entity.Metric.Type {
case "counter", "gauge":
var values []counter002
if err := json.Unmarshal(entity.Metric.Values, &values); err != nil {
return fmt.Errorf("could not extract %s value: %s", entity.Metric.Type, err)
}
for _, counter := range values {
labels := labelSet(entity.BaseLabels).Merge(labelSet(counter.Labels))
pendingSamples = append(pendingSamples, &model.Sample{
Metric: model.Metric(labels),
Timestamp: o.Timestamp,
Value: counter.Value,
})
}
case "histogram":
var values []histogram002
if err := json.Unmarshal(entity.Metric.Values, &values); err != nil {
return fmt.Errorf("could not extract %s value: %s", entity.Metric.Type, err)
}
for _, histogram := range values {
for percentile, value := range histogram.Values {
labels := labelSet(entity.BaseLabels).Merge(labelSet(histogram.Labels))
labels[model.LabelName("percentile")] = model.LabelValue(percentile)
pendingSamples = append(pendingSamples, &model.Sample{
Metric: model.Metric(labels),
Timestamp: o.Timestamp,
Value: value,
})
}
}
default:
return fmt.Errorf("unknown metric type %q", entity.Metric.Type)
}
}
if len(pendingSamples) > 0 {
return out.Ingest(pendingSamples)
}
return nil
}

View File

@ -0,0 +1,226 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"bytes"
"errors"
"io/ioutil"
"os"
"path"
"runtime"
"sort"
"testing"
"github.com/prometheus/prometheus/utility/test"
"github.com/prometheus/client_golang/model"
)
var test002Time = model.Now()
type testProcessor002ProcessScenario struct {
in string
expected, actual []model.Samples
err error
}
func (s *testProcessor002ProcessScenario) Ingest(samples model.Samples) error {
s.actual = append(s.actual, samples)
return nil
}
func (s *testProcessor002ProcessScenario) test(t testing.TB, set int) {
reader, err := os.Open(path.Join("fixtures", s.in))
if err != nil {
t.Fatalf("%d. couldn't open scenario input file %s: %s", set, s.in, err)
}
options := &ProcessOptions{
Timestamp: test002Time,
}
if err := Processor002.ProcessSingle(reader, s, options); !test.ErrorEqual(s.err, err) {
t.Fatalf("%d. expected err of %s, got %s", set, s.err, err)
}
if len(s.actual) != len(s.expected) {
t.Fatalf("%d. expected output length of %d, got %d", set, len(s.expected), len(s.actual))
}
for i, expected := range s.expected {
sort.Sort(s.actual[i])
sort.Sort(expected)
if !expected.Equal(s.actual[i]) {
t.Fatalf("%d.%d. expected %s, got %s", set, i, expected, s.actual[i])
}
}
}
func testProcessor002Process(t testing.TB) {
var scenarios = []testProcessor002ProcessScenario{
{
in: "empty.json",
err: errors.New("EOF"),
},
{
in: "test0_0_1-0_0_2.json",
expected: []model.Samples{
model.Samples{
&model.Sample{
Metric: model.Metric{"service": "zed", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"service": "bar", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"service": "foo", model.MetricNameLabel: "rpc_calls_total", "job": "batch_job"},
Value: 25,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.0459814091918713,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 78.48563317257356,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.010000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 15.890724674774395,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.0459814091918713,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 78.48563317257356,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.050000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 15.890724674774395,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 0.6120456642749681,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 97.31798360385088,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.500000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 84.63044031436561,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 1.355915069887731,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 109.89202084295582,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.900000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 160.21100853053224,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "zed"},
Value: 1.772733213161236,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "bar"},
Value: 109.99626121011262,
Timestamp: test002Time,
},
&model.Sample{
Metric: model.Metric{"percentile": "0.990000", model.MetricNameLabel: "rpc_latency_microseconds", "service": "foo"},
Value: 172.49828748957728,
Timestamp: test002Time,
},
},
},
},
}
for i, scenario := range scenarios {
scenario.test(t, i)
}
}
func TestProcessor002Process(t *testing.T) {
testProcessor002Process(t)
}
func BenchmarkProcessor002Process(b *testing.B) {
b.StopTimer()
pre := runtime.MemStats{}
runtime.ReadMemStats(&pre)
b.StartTimer()
for i := 0; i < b.N; i++ {
testProcessor002Process(b)
}
post := runtime.MemStats{}
runtime.ReadMemStats(&post)
allocated := post.TotalAlloc - pre.TotalAlloc
b.Logf("Allocated %d at %f per cycle with %d cycles.", allocated, float64(allocated)/float64(b.N), b.N)
}
func BenchmarkProcessor002ParseOnly(b *testing.B) {
b.StopTimer()
data, err := ioutil.ReadFile("fixtures/test0_0_1-0_0_2-large.json")
if err != nil {
b.Fatal(err)
}
ing := fakeIngester{}
b.StartTimer()
for i := 0; i < b.N; i++ {
if err := Processor002.ProcessSingle(bytes.NewReader(data), ing, &ProcessOptions{}); err != nil {
b.Fatal(err)
}
}
}
type fakeIngester struct{}
func (i fakeIngester) Ingest(model.Samples) error {
return nil
}

View File

@ -0,0 +1,40 @@
// Copyright 2014 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"io"
"github.com/prometheus/client_golang/text"
)
type processor004 struct{}
// Processor004 s responsible for decoding payloads from the text based variety
// of protocol version 0.0.4.
var Processor004 = &processor004{}
func (t *processor004) ProcessSingle(i io.Reader, out Ingester, o *ProcessOptions) error {
var parser text.Parser
metricFamilies, err := parser.TextToMetricFamilies(i)
if err != nil {
return err
}
for _, metricFamily := range metricFamilies {
if err := extractMetricFamily(out, o, metricFamily); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,100 @@
// Copyright 2014 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extraction
import (
"sort"
"strings"
"testing"
"github.com/prometheus/client_golang/model"
)
var (
ts = model.Now()
in = `
# Only a quite simple scenario with two metric families.
# More complicated tests of the parser itself can be found in the text package.
# TYPE mf2 counter
mf2 3
mf1{label="value1"} -3.14 123456
mf1{label="value2"} 42
mf2 4
`
out = map[model.LabelValue]model.Samples{
"mf1": model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "mf1", "label": "value1"},
Value: -3.14,
Timestamp: 123456,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "mf1", "label": "value2"},
Value: 42,
Timestamp: ts,
},
},
"mf2": model.Samples{
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "mf2"},
Value: 3,
Timestamp: ts,
},
&model.Sample{
Metric: model.Metric{model.MetricNameLabel: "mf2"},
Value: 4,
Timestamp: ts,
},
},
}
)
type testIngester struct {
results []model.Samples
}
func (i *testIngester) Ingest(s model.Samples) error {
i.results = append(i.results, s)
return nil
}
func TestTextProcessor(t *testing.T) {
var ingester testIngester
i := strings.NewReader(in)
o := &ProcessOptions{
Timestamp: ts,
}
err := Processor004.ProcessSingle(i, &ingester, o)
if err != nil {
t.Fatal(err)
}
if expected, got := len(out), len(ingester.results); expected != got {
t.Fatalf("Expected length %d, got %d", expected, got)
}
for _, r := range ingester.results {
expected, ok := out[r[0].Metric[model.MetricNameLabel]]
if !ok {
t.Fatalf(
"Unexpected metric name %q",
r[0].Metric[model.MetricNameLabel],
)
}
sort.Sort(expected)
sort.Sort(r)
if !expected.Equal(r) {
t.Errorf("expected %s, got %s", expected, r)
}
}
}

View File

@ -0,0 +1,110 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"fmt"
"strconv"
)
// Fingerprint provides a hash-capable representation of a Metric.
// For our purposes, FNV-1A 64-bit is used.
type Fingerprint uint64
func (f Fingerprint) String() string {
return fmt.Sprintf("%016x", uint64(f))
}
// Less implements sort.Interface.
func (f Fingerprint) Less(o Fingerprint) bool {
return f < o
}
// Equal implements sort.Interface.
func (f Fingerprint) Equal(o Fingerprint) bool {
return f == o
}
// LoadFromString transforms a string representation into a Fingerprint.
func (f *Fingerprint) LoadFromString(s string) error {
num, err := strconv.ParseUint(s, 16, 64)
if err != nil {
return err
}
*f = Fingerprint(num)
return nil
}
// Fingerprints represents a collection of Fingerprint subject to a given
// natural sorting scheme. It implements sort.Interface.
type Fingerprints []Fingerprint
// Len implements sort.Interface.
func (f Fingerprints) Len() int {
return len(f)
}
// Less implements sort.Interface.
func (f Fingerprints) Less(i, j int) bool {
return f[i] < f[j]
}
// Swap implements sort.Interface.
func (f Fingerprints) Swap(i, j int) {
f[i], f[j] = f[j], f[i]
}
// FingerprintSet is a set of Fingerprints.
type FingerprintSet map[Fingerprint]struct{}
// Equal returns true if both sets contain the same elements (and not more).
func (s FingerprintSet) Equal(o FingerprintSet) bool {
if len(s) != len(o) {
return false
}
for k := range s {
if _, ok := o[k]; !ok {
return false
}
}
return true
}
// Intersection returns the elements contained in both sets.
func (s FingerprintSet) Intersection(o FingerprintSet) FingerprintSet {
myLength, otherLength := len(s), len(o)
if myLength == 0 || otherLength == 0 {
return FingerprintSet{}
}
subSet := s
superSet := o
if otherLength < myLength {
subSet = o
superSet = s
}
out := FingerprintSet{}
for k := range subSet {
if _, ok := superSet[k]; ok {
out[k] = struct{}{}
}
}
return out
}

View File

@ -0,0 +1,63 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"strings"
)
const (
// ExporterLabelPrefix is the label name prefix to prepend if a
// synthetic label is already present in the exported metrics.
ExporterLabelPrefix LabelName = "exporter_"
// MetricNameLabel is the label name indicating the metric name of a
// timeseries.
MetricNameLabel LabelName = "__name__"
// ReservedLabelPrefix is a prefix which is not legal in user-supplied
// label names.
ReservedLabelPrefix = "__"
// JobLabel is the label name indicating the job from which a timeseries
// was scraped.
JobLabel LabelName = "job"
)
// A LabelName is a key for a LabelSet or Metric. It has a value associated
// therewith.
type LabelName string
// LabelNames is a sortable LabelName slice. In implements sort.Interface.
type LabelNames []LabelName
func (l LabelNames) Len() int {
return len(l)
}
func (l LabelNames) Less(i, j int) bool {
return l[i] < l[j]
}
func (l LabelNames) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}
func (l LabelNames) String() string {
labelStrings := make([]string, 0, len(l))
for _, label := range l {
labelStrings = append(labelStrings, string(label))
}
return strings.Join(labelStrings, ", ")
}

View File

@ -0,0 +1,55 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"sort"
"testing"
)
func testLabelNames(t testing.TB) {
var scenarios = []struct {
in LabelNames
out LabelNames
}{
{
in: LabelNames{"ZZZ", "zzz"},
out: LabelNames{"ZZZ", "zzz"},
},
{
in: LabelNames{"aaa", "AAA"},
out: LabelNames{"AAA", "aaa"},
},
}
for i, scenario := range scenarios {
sort.Sort(scenario.in)
for j, expected := range scenario.out {
if expected != scenario.in[j] {
t.Errorf("%d.%d expected %s, got %s", i, j, expected, scenario.in[j])
}
}
}
}
func TestLabelNames(t *testing.T) {
testLabelNames(t)
}
func BenchmarkLabelNames(b *testing.B) {
for i := 0; i < b.N; i++ {
testLabelNames(b)
}
}

View File

@ -0,0 +1,64 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"fmt"
"sort"
"strings"
)
// A LabelSet is a collection of LabelName and LabelValue pairs. The LabelSet
// may be fully-qualified down to the point where it may resolve to a single
// Metric in the data store or not. All operations that occur within the realm
// of a LabelSet can emit a vector of Metric entities to which the LabelSet may
// match.
type LabelSet map[LabelName]LabelValue
// Merge is a helper function to non-destructively merge two label sets.
func (l LabelSet) Merge(other LabelSet) LabelSet {
result := make(LabelSet, len(l))
for k, v := range l {
result[k] = v
}
for k, v := range other {
result[k] = v
}
return result
}
func (l LabelSet) String() string {
labelStrings := make([]string, 0, len(l))
for label, value := range l {
labelStrings = append(labelStrings, fmt.Sprintf("%s=%q", label, value))
}
switch len(labelStrings) {
case 0:
return ""
default:
sort.Strings(labelStrings)
return fmt.Sprintf("{%s}", strings.Join(labelStrings, ", "))
}
}
// MergeFromMetric merges Metric into this LabelSet.
func (l LabelSet) MergeFromMetric(m Metric) {
for k, v := range m {
l[k] = v
}
}

View File

@ -0,0 +1,36 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"sort"
)
// A LabelValue is an associated value for a LabelName.
type LabelValue string
// LabelValues is a sortable LabelValue slice. It implements sort.Interface.
type LabelValues []LabelValue
func (l LabelValues) Len() int {
return len(l)
}
func (l LabelValues) Less(i, j int) bool {
return sort.StringsAreSorted([]string{string(l[i]), string(l[j])})
}
func (l LabelValues) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}

View File

@ -0,0 +1,55 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"sort"
"testing"
)
func testLabelValues(t testing.TB) {
var scenarios = []struct {
in LabelValues
out LabelValues
}{
{
in: LabelValues{"ZZZ", "zzz"},
out: LabelValues{"ZZZ", "zzz"},
},
{
in: LabelValues{"aaa", "AAA"},
out: LabelValues{"AAA", "aaa"},
},
}
for i, scenario := range scenarios {
sort.Sort(scenario.in)
for j, expected := range scenario.out {
if expected != scenario.in[j] {
t.Errorf("%d.%d expected %s, got %s", i, j, expected, scenario.in[j])
}
}
}
}
func TestLabelValues(t *testing.T) {
testLabelValues(t)
}
func BenchmarkLabelValues(b *testing.B) {
for i := 0; i < b.N; i++ {
testLabelValues(b)
}
}

Some files were not shown because too many files have changed in this diff Show More