diff --git a/agent/acl_test.go b/agent/acl_test.go index ef2a74f193..dec79fefa5 100644 --- a/agent/acl_test.go +++ b/agent/acl_test.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/consul/types" "github.com/hashicorp/go-hclog" "github.com/hashicorp/serf/serf" + "google.golang.org/grpc" "github.com/stretchr/testify/require" ) @@ -183,6 +184,9 @@ func (a *TestACLAgent) Stats() map[string]map[string]string { func (a *TestACLAgent) ReloadConfig(config *consul.Config) error { return fmt.Errorf("Unimplemented") } +func (a *TestACLAgent) GRPCConn() (*grpc.ClientConn, error) { + return nil, fmt.Errorf("Unimplemented") +} func TestACL_Version8(t *testing.T) { t.Parallel() diff --git a/agent/agent.go b/agent/agent.go index 2567857744..6197d9bd3f 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -139,6 +139,7 @@ type delegate interface { ResolveTokenAndDefaultMeta(secretID string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (acl.Authorizer, error) ResolveIdentityFromToken(secretID string) (bool, structs.ACLIdentity, error) RPC(method string, args interface{}, reply interface{}) error + GRPCConn() (*grpc.ClientConn, error) ACLsEnabled() bool UseLegacyACLs() bool SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error @@ -1407,6 +1408,8 @@ func (a *Agent) consulConfig() (*consul.Config, error) { base.ConfigEntryBootstrap = a.config.ConfigEntryBootstrap + base.GRPCEnabled = a.config.EnableBackendStreaming + return base, nil } diff --git a/agent/agentpb/test.pb.binary.go b/agent/agentpb/test.pb.binary.go new file mode 100644 index 0000000000..20560e13b1 --- /dev/null +++ b/agent/agentpb/test.pb.binary.go @@ -0,0 +1,28 @@ +// Code generated by protoc-gen-go-binary. DO NOT EDIT. +// source: test.proto + +package agentpb + +import ( + "github.com/golang/protobuf/proto" +) + +// MarshalBinary implements encoding.BinaryMarshaler +func (msg *TestRequest) MarshalBinary() ([]byte, error) { + return proto.Marshal(msg) +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (msg *TestRequest) UnmarshalBinary(b []byte) error { + return proto.Unmarshal(b, msg) +} + +// MarshalBinary implements encoding.BinaryMarshaler +func (msg *TestResponse) MarshalBinary() ([]byte, error) { + return proto.Marshal(msg) +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (msg *TestResponse) UnmarshalBinary(b []byte) error { + return proto.Unmarshal(b, msg) +} diff --git a/agent/agentpb/test.pb.go b/agent/agentpb/test.pb.go new file mode 100644 index 0000000000..9a89a22b36 --- /dev/null +++ b/agent/agentpb/test.pb.go @@ -0,0 +1,604 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: test.proto + +package agentpb + +import ( + context "context" + fmt "fmt" + proto "github.com/golang/protobuf/proto" + grpc "google.golang.org/grpc" + io "io" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type TestRequest struct { + Datacenter string `protobuf:"bytes,1,opt,name=Datacenter,proto3" json:"Datacenter,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *TestRequest) Reset() { *m = TestRequest{} } +func (m *TestRequest) String() string { return proto.CompactTextString(m) } +func (*TestRequest) ProtoMessage() {} +func (*TestRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_c161fcfdc0c3ff1e, []int{0} +} +func (m *TestRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *TestRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_TestRequest.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalTo(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *TestRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_TestRequest.Merge(m, src) +} +func (m *TestRequest) XXX_Size() int { + return m.Size() +} +func (m *TestRequest) XXX_DiscardUnknown() { + xxx_messageInfo_TestRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_TestRequest proto.InternalMessageInfo + +func (m *TestRequest) GetDatacenter() string { + if m != nil { + return m.Datacenter + } + return "" +} + +type TestResponse struct { + ServerName string `protobuf:"bytes,1,opt,name=ServerName,proto3" json:"ServerName,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *TestResponse) Reset() { *m = TestResponse{} } +func (m *TestResponse) String() string { return proto.CompactTextString(m) } +func (*TestResponse) ProtoMessage() {} +func (*TestResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_c161fcfdc0c3ff1e, []int{1} +} +func (m *TestResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *TestResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_TestResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalTo(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *TestResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_TestResponse.Merge(m, src) +} +func (m *TestResponse) XXX_Size() int { + return m.Size() +} +func (m *TestResponse) XXX_DiscardUnknown() { + xxx_messageInfo_TestResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_TestResponse proto.InternalMessageInfo + +func (m *TestResponse) GetServerName() string { + if m != nil { + return m.ServerName + } + return "" +} + +func init() { + proto.RegisterType((*TestRequest)(nil), "agentpb.TestRequest") + proto.RegisterType((*TestResponse)(nil), "agentpb.TestResponse") +} + +func init() { proto.RegisterFile("test.proto", fileDescriptor_c161fcfdc0c3ff1e) } + +var fileDescriptor_c161fcfdc0c3ff1e = []byte{ + // 160 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e, + 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x4f, 0x4c, 0x4f, 0xcd, 0x2b, 0x29, 0x48, 0x52, + 0xd2, 0xe5, 0xe2, 0x0e, 0x49, 0x2d, 0x2e, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0x92, + 0xe3, 0xe2, 0x72, 0x49, 0x2c, 0x49, 0x4c, 0x4e, 0xcd, 0x2b, 0x49, 0x2d, 0x92, 0x60, 0x54, 0x60, + 0xd4, 0xe0, 0x0c, 0x42, 0x12, 0x51, 0xd2, 0xe3, 0xe2, 0x81, 0x28, 0x2f, 0x2e, 0xc8, 0xcf, 
0x2b, + 0x4e, 0x05, 0xa9, 0x0f, 0x4e, 0x2d, 0x2a, 0x4b, 0x2d, 0xf2, 0x4b, 0xcc, 0x4d, 0x85, 0xa9, 0x47, + 0x88, 0x18, 0xd9, 0x72, 0xb1, 0x80, 0xd4, 0x0b, 0x99, 0x42, 0x69, 0x11, 0x3d, 0xa8, 0xc5, 0x7a, + 0x48, 0xb6, 0x4a, 0x89, 0xa2, 0x89, 0x42, 0x0c, 0x57, 0x62, 0x70, 0x12, 0x38, 0xf1, 0x48, 0x8e, + 0xf1, 0xc2, 0x23, 0x39, 0xc6, 0x07, 0x8f, 0xe4, 0x18, 0x67, 0x3c, 0x96, 0x63, 0x48, 0x62, 0x03, + 0xbb, 0xdf, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0x94, 0xb6, 0xef, 0xa2, 0xcd, 0x00, 0x00, 0x00, +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// TestClient is the client API for Test service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type TestClient interface { + // Test is only used internally for testing connectivity/balancing logic. + Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) +} + +type testClient struct { + cc *grpc.ClientConn +} + +func NewTestClient(cc *grpc.ClientConn) TestClient { + return &testClient{cc} +} + +func (c *testClient) Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) { + out := new(TestResponse) + err := c.cc.Invoke(ctx, "/agentpb.Test/Test", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// TestServer is the server API for Test service. +type TestServer interface { + // Test is only used internally for testing connectivity/balancing logic. + Test(context.Context, *TestRequest) (*TestResponse, error) +} + +func RegisterTestServer(s *grpc.Server, srv TestServer) { + s.RegisterService(&_Test_serviceDesc, srv) +} + +func _Test_Test_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TestRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestServer).Test(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/agentpb.Test/Test", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestServer).Test(ctx, req.(*TestRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Test_serviceDesc = grpc.ServiceDesc{ + ServiceName: "agentpb.Test", + HandlerType: (*TestServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Test", + Handler: _Test_Test_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "test.proto", +} + +func (m *TestRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *TestRequest) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Datacenter) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintTest(dAtA, i, uint64(len(m.Datacenter))) + i += copy(dAtA[i:], m.Datacenter) + } + if m.XXX_unrecognized != nil { + i += copy(dAtA[i:], m.XXX_unrecognized) + } + return i, nil +} + +func (m *TestResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return 
dAtA[:n], nil +} + +func (m *TestResponse) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ServerName) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintTest(dAtA, i, uint64(len(m.ServerName))) + i += copy(dAtA[i:], m.ServerName) + } + if m.XXX_unrecognized != nil { + i += copy(dAtA[i:], m.XXX_unrecognized) + } + return i, nil +} + +func encodeVarintTest(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} +func (m *TestRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Datacenter) + if l > 0 { + n += 1 + l + sovTest(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *TestResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.ServerName) + if l > 0 { + n += 1 + l + sovTest(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func sovTest(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozTest(x uint64) (n int) { + return sovTest(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *TestRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: TestRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: TestRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Datacenter", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthTest + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthTest + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Datacenter = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipTest(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthTest + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthTest + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *TestResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: TestResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: TestResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ServerName", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthTest + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthTest + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ServerName = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipTest(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthTest + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthTest + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipTest(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthTest + } + iNdEx += length + if iNdEx < 0 { + return 0, ErrInvalidLengthTest + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipTest(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + if iNdEx < 0 { + return 0, ErrInvalidLengthTest + } + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthTest = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowTest = fmt.Errorf("proto: integer overflow") +) diff --git a/agent/agentpb/test.proto b/agent/agentpb/test.proto new file mode 100644 index 0000000000..4ae196c17f --- /dev/null +++ b/agent/agentpb/test.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package agentpb; + +// Test service is used internally for testing gRPC plumbing and is not exposed +// on a production server. +service Test { + // Test is only used internally for testing connectivity/balancing logic. 
+ rpc Test(TestRequest) returns (TestResponse) {} +} + +message TestRequest { + string Datacenter = 1; +} + +message TestResponse { + string ServerName = 1; +} \ No newline at end of file diff --git a/agent/config/builder.go b/agent/config/builder.go index 48d18b2724..fed8c8aba8 100644 --- a/agent/config/builder.go +++ b/agent/config/builder.go @@ -900,6 +900,7 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) { DiscardCheckOutput: b.boolVal(c.DiscardCheckOutput), DiscoveryMaxStale: b.durationVal("discovery_max_stale", c.DiscoveryMaxStale), EnableAgentTLSForChecks: b.boolVal(c.EnableAgentTLSForChecks), + EnableBackendStreaming: b.boolVal(c.EnableBackendStreaming), EnableCentralServiceConfig: b.boolVal(c.EnableCentralServiceConfig), EnableDebug: b.boolVal(c.EnableDebug), EnableRemoteScriptChecks: enableRemoteScriptChecks, diff --git a/agent/config/config.go b/agent/config/config.go index 4de4b313a0..c1898c4ec8 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -221,6 +221,7 @@ type Config struct { DiscoveryMaxStale *string `json:"discovery_max_stale" hcl:"discovery_max_stale" mapstructure:"discovery_max_stale"` EnableACLReplication *bool `json:"enable_acl_replication,omitempty" hcl:"enable_acl_replication" mapstructure:"enable_acl_replication"` EnableAgentTLSForChecks *bool `json:"enable_agent_tls_for_checks,omitempty" hcl:"enable_agent_tls_for_checks" mapstructure:"enable_agent_tls_for_checks"` + EnableBackendStreaming *bool `json:"enable_backend_streaming,omitempty" hcl:"enable_backend_streaming" mapstructure:"enable_backend_streaming"` EnableCentralServiceConfig *bool `json:"enable_central_service_config,omitempty" hcl:"enable_central_service_config" mapstructure:"enable_central_service_config"` EnableDebug *bool `json:"enable_debug,omitempty" hcl:"enable_debug" mapstructure:"enable_debug"` EnableScriptChecks *bool `json:"enable_script_checks,omitempty" hcl:"enable_script_checks" mapstructure:"enable_script_checks"` diff --git a/agent/config/runtime.go b/agent/config/runtime.go index f80d28de79..60fa7782c6 100644 --- a/agent/config/runtime.go +++ b/agent/config/runtime.go @@ -689,6 +689,13 @@ type RuntimeConfig struct { // and key). EnableAgentTLSForChecks bool + // EnableBackendStreaming is used to enable the new backend streaming interface when + // making blocking queries to the HTTP API. This greatly reduces bandwidth and server + // CPU load in large clusters with lots of activity. + // + // hcl: enable_backend_streaming = (true|false) + EnableBackendStreaming bool + // EnableCentralServiceConfig controls whether the agent should incorporate // centralized config such as service-defaults into local service registrations. 
// diff --git a/agent/config/runtime_test.go b/agent/config/runtime_test.go index fd8ab9c3dd..5c0119e99a 100644 --- a/agent/config/runtime_test.go +++ b/agent/config/runtime_test.go @@ -4053,10 +4053,11 @@ func TestFullConfig(t *testing.T) { }, "udp_answer_limit": 29909, "use_cache": true, - "cache_max_age": "5m"` + entFullDNSJSONConfig + ` + "cache_max_age": "5m" }, "enable_acl_replication": true, "enable_agent_tls_for_checks": true, + "enable_backend_streaming": true, "enable_central_service_config": true, "enable_debug": true, "enable_script_checks": true, @@ -4690,6 +4691,7 @@ func TestFullConfig(t *testing.T) { } enable_acl_replication = true enable_agent_tls_for_checks = true + enable_backend_streaming = true enable_central_service_config = true enable_debug = true enable_script_checks = true @@ -5399,6 +5401,7 @@ func TestFullConfig(t *testing.T) { DiscardCheckOutput: true, DiscoveryMaxStale: 5 * time.Second, EnableAgentTLSForChecks: true, + EnableBackendStreaming: true, EnableCentralServiceConfig: true, EnableDebug: true, EnableRemoteScriptChecks: true, @@ -6279,6 +6282,7 @@ func TestSanitize(t *testing.T) { "DiscardCheckOutput": false, "DiscoveryMaxStale": "0s", "EnableAgentTLSForChecks": false, + "EnableBackendStreaming": false, "EnableDebug": false, "EnableCentralServiceConfig": false, "EnableLocalScriptChecks": false, diff --git a/agent/consul/client.go b/agent/consul/client.go index d46df2fb1a..0f611a41d5 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/serf/serf" "golang.org/x/time/rate" + "google.golang.org/grpc" ) const ( @@ -58,7 +59,9 @@ type Client struct { useNewACLs int32 // Connection pool to consul servers - connPool *pool.ConnPool + connPool *pool.ConnPool + grpcClient *GRPCClient + grpcResolverBuilder *ServerResolverBuilder // routers is responsible for the selection and maintenance of // Consul servers this agent uses for RPC requests @@ -185,10 +188,26 @@ func NewClientLogger(config *Config, logger hclog.InterceptLogger, tlsConfigurat go c.monitorACLMode() } + var tracker router.ServerTracker + + if c.config.GRPCEnabled { + // Register the gRPC resolver used for connection balancing. + c.grpcResolverBuilder = registerResolverBuilder(config.GRPCResolverScheme, config.Datacenter, c.shutdownCh) + tracker = c.grpcResolverBuilder + go c.grpcResolverBuilder.periodicServerRebalance(c.serf) + } else { + tracker = &router.NoOpServerTracker{} + } + // Start maintenance task for servers - c.routers = router.New(c.logger, c.shutdownCh, c.serf, c.connPool) + c.routers = router.New(c.logger, c.shutdownCh, c.serf, c.connPool, tracker) go c.routers.Start() + // Start the GRPC client. + if c.config.GRPCEnabled { + c.grpcClient = NewGRPCClient(logger, c.routers, tlsConfigurator, config.GRPCResolverScheme) + } + // Start LAN event handlers after the router is complete since the event // handlers depend on the router and the router depends on Serf. go c.lanEventHandler() @@ -383,6 +402,14 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io return nil } +// GRPCConn returns a gRPC connection to a server. 
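+// It returns an error if gRPC is not enabled in the client's configuration.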
+func (c *Client) GRPCConn() (*grpc.ClientConn, error) { + if !c.config.GRPCEnabled { + return nil, fmt.Errorf("GRPC is not enabled on this client") + } + return c.grpcClient.GRPCConn(c.config.Datacenter) +} + // Stats is used to return statistics for debugging and insight // for various sub-systems func (c *Client) Stats() map[string]map[string]string { diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index 5a218e247f..37c05d5c91 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -2,6 +2,7 @@ package consul import ( "bytes" + "fmt" "net" "os" "sync" @@ -46,6 +47,7 @@ func testClientConfig(t *testing.T) (string, *Config) { config.SerfLANConfig.MemberlistConfig.ProbeTimeout = 200 * time.Millisecond config.SerfLANConfig.MemberlistConfig.ProbeInterval = time.Second config.SerfLANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond + config.GRPCResolverScheme = fmt.Sprintf("consul-%s", config.NodeName) return dir, config } diff --git a/agent/consul/config.go b/agent/consul/config.go index e7df308d21..f66c21e4c6 100644 --- a/agent/consul/config.go +++ b/agent/consul/config.go @@ -470,6 +470,18 @@ type Config struct { // AutoEncrypt.Sign requests. AutoEncryptAllowTLS bool + // GRPCEnabled controls whether servers will listen for gRPC streams or RPC + // calls and whether clients will start gRPC clients. + GRPCEnabled bool + + // GRPCResolverScheme is the gRPC resolver scheme to use. This is only used for + // tests running in parallel to avoid overwriting each other. + GRPCResolverScheme string + + // GRPCTestServerEnabled causes the Test grpc service to be registered and + // served. This is only intended for use in internal testing. + GRPCTestServerEnabled bool + // Embedded Consul Enterprise specific configuration *EnterpriseConfig } @@ -598,6 +610,7 @@ func DefaultConfig() *Config { DefaultQueryTime: 300 * time.Second, MaxQueryTime: 600 * time.Second, EnterpriseConfig: DefaultEnterpriseConfig(), + GRPCResolverScheme: "consul", } // Increase our reap interval to 3 days instead of 24h. diff --git a/agent/consul/grpc_client.go b/agent/consul/grpc_client.go new file mode 100644 index 0000000000..7e59193fd2 --- /dev/null +++ b/agent/consul/grpc_client.go @@ -0,0 +1,122 @@ +package consul + +import ( + "context" + "fmt" + "net" + "sync" + + "google.golang.org/grpc" + + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/agent/pool" + "github.com/hashicorp/go-hclog" + + "github.com/hashicorp/consul/tlsutil" +) + +const ( + grpcBasePath = "/consul" +) + +type ServerProvider interface { + Servers() []*metadata.Server +} + +type GRPCClient struct { + scheme string + serverProvider ServerProvider + tlsConfigurator *tlsutil.Configurator + grpcConns map[string]*grpc.ClientConn + grpcConnLock sync.Mutex +} + +func NewGRPCClient(logger hclog.Logger, serverProvider ServerProvider, tlsConfigurator *tlsutil.Configurator, scheme string) *GRPCClient { + // Note we don't actually use the logger anywhere yet but I guess it was added + // for future compatibility... + return &GRPCClient{ + scheme: scheme, + serverProvider: serverProvider, + tlsConfigurator: tlsConfigurator, + grpcConns: make(map[string]*grpc.ClientConn), + } +} + +func (c *GRPCClient) GRPCConn(datacenter string) (*grpc.ClientConn, error) { + c.grpcConnLock.Lock() + defer c.grpcConnLock.Unlock() + + // If there's an existing ClientConn for the given DC, return it. 
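+	// Connections are cached per datacenter and reused by subsequent callers.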
+ if conn, ok := c.grpcConns[datacenter]; ok { + return conn, nil + } + + dialer := newDialer(c.serverProvider, c.tlsConfigurator.OutgoingRPCWrapper()) + conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", c.scheme, datacenter), + // use WithInsecure mode here because we handle the TLS wrapping in the + // custom dialer based on logic around whether the server has TLS enabled. + grpc.WithInsecure(), + grpc.WithContextDialer(dialer), + grpc.WithDisableRetry(), + grpc.WithStatsHandler(grpcStatsHandler), + grpc.WithBalancerName("pick_first")) + if err != nil { + return nil, err + } + + c.grpcConns[datacenter] = conn + + return conn, nil +} + +// newDialer returns a gRPC dialer function that conditionally wraps the connection +// with TLS depending on the given useTLS value. +func newDialer(serverProvider ServerProvider, wrapper tlsutil.DCWrapper) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + d := net.Dialer{} + conn, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + + // Check if TLS is enabled for the server. + var found bool + var server *metadata.Server + for _, s := range serverProvider.Servers() { + if s.Addr.String() == addr { + found = true + server = s + } + } + if !found { + return nil, fmt.Errorf("could not find Consul server for address %q", addr) + } + + if server.UseTLS { + if wrapper == nil { + return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper") + } + + // Switch the connection into TLS mode + if _, err := conn.Write([]byte{byte(pool.RPCTLS)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + tlsConn, err := wrapper(server.Datacenter, conn) + if err != nil { + conn.Close() + return nil, err + } + conn = tlsConn + } + + _, err = conn.Write([]byte{pool.RPCGRPC}) + if err != nil { + return nil, err + } + + return conn, nil + } +} diff --git a/agent/consul/grpc_resolver.go b/agent/consul/grpc_resolver.go new file mode 100644 index 0000000000..54836bb11b --- /dev/null +++ b/agent/consul/grpc_resolver.go @@ -0,0 +1,234 @@ +package consul + +import ( + "math/rand" + "strings" + "sync" + "time" + + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/agent/router" + "github.com/hashicorp/serf/serf" + "google.golang.org/grpc/resolver" +) + +var registerLock sync.Mutex + +// registerResolverBuilder registers our custom grpc resolver with the given scheme. +func registerResolverBuilder(scheme, datacenter string, shutdownCh <-chan struct{}) *ServerResolverBuilder { + registerLock.Lock() + defer registerLock.Unlock() + grpcResolverBuilder := NewServerResolverBuilder(scheme, datacenter, shutdownCh) + resolver.Register(grpcResolverBuilder) + return grpcResolverBuilder +} + +// ServerResolverBuilder tracks the current server list and keeps any +// ServerResolvers updated when changes occur. +type ServerResolverBuilder struct { + // Allow overriding the scheme to support parallel tests, since + // the resolver builder is registered globally. 
+ scheme string + datacenter string + servers map[string]*metadata.Server + resolvers map[resolver.ClientConn]*ServerResolver + shutdownCh <-chan struct{} + lock sync.Mutex +} + +func NewServerResolverBuilder(scheme, datacenter string, shutdownCh <-chan struct{}) *ServerResolverBuilder { + return &ServerResolverBuilder{ + scheme: scheme, + datacenter: datacenter, + servers: make(map[string]*metadata.Server), + resolvers: make(map[resolver.ClientConn]*ServerResolver), + } +} + +// periodicServerRebalance periodically reshuffles the order of server addresses +// within the resolvers to ensure the load is balanced across servers. +func (s *ServerResolverBuilder) periodicServerRebalance(serf *serf.Serf) { + // Compute the rebalance timer based on the number of local servers and nodes. + rebalanceDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), serf.NumNodes()) + timer := time.NewTimer(rebalanceDuration) + + for { + select { + case <-timer.C: + s.rebalanceResolvers() + + // Re-compute the wait duration. + newTimerDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), serf.NumNodes()) + timer.Reset(newTimerDuration) + case <-s.shutdownCh: + timer.Stop() + return + } + } +} + +// rebalanceResolvers shuffles the server list for resolvers in all datacenters. +func (s *ServerResolverBuilder) rebalanceResolvers() { + s.lock.Lock() + defer s.lock.Unlock() + + for _, resolver := range s.resolvers { + // Shuffle the list of addresses using the last list given to the resolver. + resolver.addrLock.Lock() + addrs := resolver.lastAddrs + rand.Shuffle(len(addrs), func(i, j int) { + addrs[i], addrs[j] = addrs[j], addrs[i] + }) + resolver.addrLock.Unlock() + + // Pass the shuffled list to the resolver. + resolver.updateAddrs(addrs) + } +} + +// serversInDC returns the number of servers in the given datacenter. +func (s *ServerResolverBuilder) serversInDC(dc string) int { + s.lock.Lock() + defer s.lock.Unlock() + + var serverCount int + for _, server := range s.servers { + if server.Datacenter == dc { + serverCount++ + } + } + + return serverCount +} + +// Servers returns metadata for all currently known servers. This is used +// by grpc.ClientConn through our custom dialer. +func (s *ServerResolverBuilder) Servers() []*metadata.Server { + s.lock.Lock() + defer s.lock.Unlock() + + servers := make([]*metadata.Server, 0, len(s.servers)) + for _, server := range s.servers { + servers = append(servers, server) + } + return servers +} + +// Build returns a new ServerResolver for the given ClientConn. The resolver +// will keep the ClientConn's state updated based on updates from Serf. +func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) { + s.lock.Lock() + defer s.lock.Unlock() + + // If there's already a resolver for this datacenter, return it. + datacenter := strings.TrimPrefix(target.Endpoint, "server.") + if resolver, ok := s.resolvers[cc]; ok { + return resolver, nil + } + + // Make a new resolver for the dc and add it to the list of active ones. + resolver := &ServerResolver{ + datacenter: datacenter, + clientConn: cc, + } + resolver.updateAddrs(s.getDCAddrs(datacenter)) + resolver.closeCallback = func() { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.resolvers, cc) + } + + s.resolvers[cc] = resolver + + return resolver, nil +} + +func (s *ServerResolverBuilder) Scheme() string { return s.scheme } + +// AddServer updates the resolvers' states to include the new server's address. 
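+// Only resolvers for the server's datacenter receive the updated address list.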
+func (s *ServerResolverBuilder) AddServer(server *metadata.Server) { + s.lock.Lock() + defer s.lock.Unlock() + + s.servers[server.ID] = server + + addrs := s.getDCAddrs(server.Datacenter) + for _, resolver := range s.resolvers { + if resolver.datacenter == server.Datacenter { + resolver.updateAddrs(addrs) + } + } +} + +// RemoveServer updates the resolvers' states with the given server removed. +func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) { + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.servers, server.ID) + + addrs := s.getDCAddrs(server.Datacenter) + for _, resolver := range s.resolvers { + if resolver.datacenter == server.Datacenter { + resolver.updateAddrs(addrs) + } + } +} + +// getDCAddrs returns a list of the server addresses for the given datacenter. +// This method assumes the lock is held. +func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address { + var addrs []resolver.Address + for _, server := range s.servers { + if server.Datacenter != dc { + continue + } + + addrs = append(addrs, resolver.Address{ + Addr: server.Addr.String(), + Type: resolver.Backend, + ServerName: server.Name, + }) + } + return addrs +} + +// ServerResolver is a grpc Resolver that will keep a grpc.ClientConn up to date +// on the list of server addresses to use. +type ServerResolver struct { + datacenter string + clientConn resolver.ClientConn + closeCallback func() + + lastAddrs []resolver.Address + addrLock sync.Mutex +} + +// updateAddrs updates this ServerResolver's ClientConn to use the given set of addrs. +func (r *ServerResolver) updateAddrs(addrs []resolver.Address) { + // Only pass the first address initially, which will cause the + // balancer to spin down the connection for its previous first address + // if it is different. If we don't do this, it will keep using the old + // first address as long as it is still in the list, making it impossible to + // rebalance until that address is removed. + var firstAddr []resolver.Address + if len(addrs) > 0 { + firstAddr = []resolver.Address{addrs[0]} + } + r.clientConn.UpdateState(resolver.State{Addresses: firstAddr}) + + // Call UpdateState again with the entire list of addrs in case we need them + // for failover. + r.clientConn.UpdateState(resolver.State{Addresses: addrs}) + + r.addrLock.Lock() + defer r.addrLock.Unlock() + r.lastAddrs = addrs +} + +func (s *ServerResolver) Close() { + s.closeCallback() +} + +// Unneeded since we only update the ClientConn when our server list changes. 
+func (*ServerResolver) ResolveNow(o resolver.ResolveNowOption) {} diff --git a/agent/consul/grpc_resolver_test.go b/agent/consul/grpc_resolver_test.go new file mode 100644 index 0000000000..7dd043492a --- /dev/null +++ b/agent/consul/grpc_resolver_test.go @@ -0,0 +1,277 @@ +package consul + +import ( + "context" + "os" + "testing" + "time" + + "github.com/hashicorp/consul/agent/agentpb" + "github.com/hashicorp/consul/sdk/testutil/retry" + "github.com/hashicorp/consul/testrpc" + "github.com/stretchr/testify/require" +) + +func TestGRPCResolver_Rebalance(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir1, server1 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = true + c.GRPCEnabled = true + c.GRPCTestServerEnabled = true + }) + defer os.RemoveAll(dir1) + defer server1.Shutdown() + + dir2, server2 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = false + c.GRPCEnabled = true + c.GRPCTestServerEnabled = true + }) + defer os.RemoveAll(dir2) + defer server2.Shutdown() + + dir3, server3 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = false + c.GRPCEnabled = true + c.GRPCTestServerEnabled = true + }) + defer os.RemoveAll(dir3) + defer server3.Shutdown() + + dir4, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + c.GRPCTestServerEnabled = true + }) + defer os.RemoveAll(dir4) + defer client.Shutdown() + + // Try to join + joinLAN(t, server2, server1) + joinLAN(t, server3, server1) + testrpc.WaitForLeader(t, server1.RPC, "dc1") + joinLAN(t, client, server2) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Make a call to our test endpoint. + conn, err := client.GRPCConn() + require.NoError(err) + + grpcClient := agentpb.NewTestClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + response1, err := grpcClient.Test(ctx, &agentpb.TestRequest{}) + require.NoError(err) + + // Rebalance a few times to hit a different server. + for { + select { + case <-ctx.Done(): + t.Fatal("could not get a response from a different server") + default: + } + + // Force a shuffle and wait for the connection to be rebalanced. 
+		client.grpcResolverBuilder.rebalanceResolvers()
+		time.Sleep(100 * time.Millisecond)
+
+		response2, err := grpcClient.Test(ctx, &agentpb.TestRequest{})
+		require.NoError(err)
+		if response1.ServerName != response2.ServerName {
+			break
+		}
+	}
+}
+
+func TestGRPCResolver_Failover_LocalDC(t *testing.T) {
+	t.Parallel()
+
+	require := require.New(t)
+	dir1, server1 := testServerWithConfig(t, func(c *Config) {
+		c.Datacenter = "dc1"
+		c.Bootstrap = true
+		c.GRPCEnabled = true
+		c.GRPCTestServerEnabled = true
+	})
+	defer os.RemoveAll(dir1)
+	defer server1.Shutdown()
+
+	dir2, server2 := testServerWithConfig(t, func(c *Config) {
+		c.Datacenter = "dc1"
+		c.GRPCEnabled = true
+		c.GRPCTestServerEnabled = true
+	})
+	defer os.RemoveAll(dir2)
+	defer server2.Shutdown()
+
+	dir3, server3 := testServerWithConfig(t, func(c *Config) {
+		c.Datacenter = "dc1"
+		c.GRPCEnabled = true
+		c.GRPCTestServerEnabled = true
+	})
+	defer os.RemoveAll(dir3)
+	defer server3.Shutdown()
+
+	dir4, client := testClientWithConfig(t, func(c *Config) {
+		c.Datacenter = "dc1"
+		c.NodeName = uniqueNodeName(t.Name())
+		c.GRPCEnabled = true
+		c.GRPCTestServerEnabled = true
+	})
+	defer os.RemoveAll(dir4)
+	defer client.Shutdown()
+
+	// Try to join
+	joinLAN(t, server2, server1)
+	joinLAN(t, server3, server1)
+	testrpc.WaitForLeader(t, server1.RPC, "dc1")
+	joinLAN(t, client, server2)
+	testrpc.WaitForTestAgent(t, client.RPC, "dc1")
+
+	// Make a call to our test endpoint.
+	conn, err := client.GRPCConn()
+	require.NoError(err)
+
+	grpcClient := agentpb.NewTestClient(conn)
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	response1, err := grpcClient.Test(ctx, &agentpb.TestRequest{})
+	require.NoError(err)
+
+	// Shutdown the server that answered the request.
+	var shutdown *Server
+	for _, s := range []*Server{server1, server2, server3} {
+		if s.config.NodeName == response1.ServerName {
+			shutdown = s
+			break
+		}
+	}
+	require.NotNil(shutdown)
+	require.NoError(shutdown.Shutdown())
+
+	// Wait for the balancer to switch over to another server so we get a different response.
+	retry.Run(t, func(r *retry.R) {
+		response2, err := grpcClient.Test(ctx, &agentpb.TestRequest{})
+		r.Check(err)
+		if response1.ServerName == response2.ServerName {
+			r.Fatal("responses should be from different servers")
+		}
+	})
+}
+
+func TestGRPCResolver_Failover_MultiDC(t *testing.T) {
+	t.Parallel()
+
+	// Create a single server in dc1.
+	require := require.New(t)
+	dir1, server1 := testServerWithConfig(t, func(c *Config) {
+		c.Datacenter = "dc1"
+		c.Bootstrap = true
+		c.GRPCEnabled = true
+		c.GRPCTestServerEnabled = true
+	})
+	defer os.RemoveAll(dir1)
+	defer server1.Shutdown()
+
+	// Create a client in dc1.
+	cDir, client := testClientWithConfig(t, func(c *Config) {
+		c.Datacenter = "dc1"
+		c.NodeName = uniqueNodeName(t.Name())
+		c.GRPCEnabled = true
+		c.GRPCTestServerEnabled = true
+	})
+	defer os.RemoveAll(cDir)
+	defer client.Shutdown()
+
+	// Create 3 servers in dc2.
+ dir2, server2 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc2" + c.Bootstrap = false + c.BootstrapExpect = 3 + c.GRPCEnabled = true + c.GRPCTestServerEnabled = true + }) + defer os.RemoveAll(dir2) + defer server2.Shutdown() + + dir3, server3 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc2" + c.Bootstrap = false + c.BootstrapExpect = 3 + c.GRPCEnabled = true + c.GRPCTestServerEnabled = true + }) + defer os.RemoveAll(dir3) + defer server3.Shutdown() + + dir4, server4 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc2" + c.Bootstrap = false + c.BootstrapExpect = 3 + c.GRPCEnabled = true + c.GRPCTestServerEnabled = true + c.GRPCTestServerEnabled = true + }) + defer os.RemoveAll(dir4) + defer server4.Shutdown() + + // Try to join + joinLAN(t, server3, server2) + joinLAN(t, server4, server2) + testrpc.WaitForLeader(t, server1.RPC, "dc1") + testrpc.WaitForLeader(t, server2.RPC, "dc2") + + joinWAN(t, server1, server2) + joinWAN(t, server3, server2) + joinWAN(t, server4, server2) + joinLAN(t, client, server1) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Make a call to our test endpoint on the client in dc1. + conn, err := client.GRPCConn() + require.NoError(err) + + grpcClient := agentpb.NewTestClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + response1, err := grpcClient.Test(ctx, &agentpb.TestRequest{Datacenter: "dc2"}) + require.NoError(err) + // Make sure the response didn't come from dc1. + require.Contains([]string{ + server2.config.NodeName, + server3.config.NodeName, + server4.config.NodeName, + }, response1.ServerName) + + // Shutdown the server that answered the request. + var shutdown *Server + for _, s := range []*Server{server2, server3, server4} { + if s.config.NodeName == response1.ServerName { + shutdown = s + break + } + } + require.NotNil(shutdown) + require.NoError(shutdown.Leave()) + require.NoError(shutdown.Shutdown()) + + // Wait for the balancer to switch over to another server so we get a different response. + retry.Run(t, func(r *retry.R) { + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + response2, err := grpcClient.Test(ctx, &agentpb.TestRequest{Datacenter: "dc2"}) + r.Check(err) + if response1.ServerName == response2.ServerName { + r.Fatal("responses should be from different servers") + } + }) +} diff --git a/agent/consul/grpc_stats.go b/agent/consul/grpc_stats.go new file mode 100644 index 0000000000..3db3718d65 --- /dev/null +++ b/agent/consul/grpc_stats.go @@ -0,0 +1,90 @@ +package consul + +import ( + "context" + "sync/atomic" + + metrics "github.com/armon/go-metrics" + "google.golang.org/grpc" + grpcStats "google.golang.org/grpc/stats" +) + +var ( + // grpcStatsHandler is the global stats handler instance. Yes I know global is + // horrible but go-metrics started it. Now we need to be global to make + // connection count gauge useful. + grpcStatsHandler *GRPCStatsHandler + + // grpcActiveStreams is used to keep track of the number of open streaming + // RPCs on a server. It is accessed atomically, See notes above on global + // sadness. + grpcActiveStreams uint64 +) + +func init() { + grpcStatsHandler = &GRPCStatsHandler{} +} + +// GRPCStatsHandler is a grpc/stats.StatsHandler which emits stats to +// go-metrics. 
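+// It emits request counters and an active connection gauge for both the
+// client and server sides.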
+type GRPCStatsHandler struct { + activeConns uint64 // must be 8-byte aligned for atomic access +} + +// TagRPC implements grpcStats.StatsHandler +func (c *GRPCStatsHandler) TagRPC(ctx context.Context, i *grpcStats.RPCTagInfo) context.Context { + // No-op + return ctx +} + +// HandleRPC implements grpcStats.StatsHandler +func (c *GRPCStatsHandler) HandleRPC(ctx context.Context, s grpcStats.RPCStats) { + label := "server" + if s.IsClient() { + label = "client" + } + switch s.(type) { + case *grpcStats.InHeader: + metrics.IncrCounter([]string{"grpc", label, "request"}, 1) + } +} + +// TagConn implements grpcStats.StatsHandler +func (c *GRPCStatsHandler) TagConn(ctx context.Context, i *grpcStats.ConnTagInfo) context.Context { + // No-op + return ctx +} + +// HandleConn implements grpcStats.StatsHandler +func (c *GRPCStatsHandler) HandleConn(ctx context.Context, s grpcStats.ConnStats) { + label := "server" + if s.IsClient() { + label = "client" + } + var new uint64 + switch s.(type) { + case *grpcStats.ConnBegin: + new = atomic.AddUint64(&c.activeConns, 1) + case *grpcStats.ConnEnd: + // Decrement! + new = atomic.AddUint64(&c.activeConns, ^uint64(0)) + } + metrics.SetGauge([]string{"grpc", label, "active_conns"}, float32(new)) +} + +// GRPCCountingStreamInterceptor is a grpc.ServerStreamInterceptor that just +// tracks open streaming calls to the server and emits metrics on how many are +// open. +func GRPCCountingStreamInterceptor(srv interface{}, ss grpc.ServerStream, + info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + + // Count the stream + new := atomic.AddUint64(&grpcActiveStreams, 1) + metrics.SetGauge([]string{"grpc", "server", "active_streams"}, float32(new)) + defer func() { + new := atomic.AddUint64(&grpcActiveStreams, ^uint64(0)) + metrics.SetGauge([]string{"grpc", "server", "active_streams"}, float32(new)) + }() + + return handler(srv, ss) +} diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go index a7a40442e5..c3cad76f98 100644 --- a/agent/consul/rpc.go +++ b/agent/consul/rpc.go @@ -56,6 +56,27 @@ func (s *Server) rpcLogger() hclog.Logger { return s.loggers.Named(logging.RPC) } +type grpcListener struct { + conns chan net.Conn + addr net.Addr +} + +func (l *grpcListener) Handle(conn net.Conn) { + l.conns <- conn +} + +func (l *grpcListener) Accept() (net.Conn, error) { + return <-l.conns, nil +} + +func (l *grpcListener) Addr() net.Addr { + return l.addr +} + +func (l *grpcListener) Close() error { + return nil +} + // listen is used to listen for incoming RPC connections func (s *Server) listen(listener net.Listener) { for { @@ -187,6 +208,15 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { conn = tls.Server(conn, s.tlsConfigurator.IncomingInsecureRPCConfig()) s.handleInsecureConn(conn) + case pool.RPCGRPC: + if !s.config.GRPCEnabled { + s.rpcLogger().Error("GRPC conn opened but GRPC is not enabled, closing", + "conn", logConn(conn), + ) + conn.Close() + } else { + s.handleGRPCConn(conn) + } default: if !s.handleEnterpriseRPCConn(typ, conn, isTLS) { s.rpcLogger().Error("unrecognized RPC byte", @@ -253,6 +283,16 @@ func (s *Server) handleNativeTLS(conn net.Conn) { case pool.ALPN_RPCSnapshot: s.handleSnapshotConn(tlsConn) + case pool.ALPN_RPCGRPC: + if !s.config.GRPCEnabled { + s.rpcLogger().Error("GRPC conn opened but GRPC is not enabled, closing", + "conn", logConn(conn), + ) + conn.Close() + } else { + s.handleGRPCConn(tlsConn) + } + case pool.ALPN_WANGossipPacket: if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF 
{ s.rpcLogger().Error( @@ -439,6 +479,11 @@ func (c *limitedConn) Read(b []byte) (n int, err error) { return c.lr.Read(b) } +// HandleGRPCConn is used to dispatch connections to the built in gRPC server +func (s *Server) handleGRPCConn(conn net.Conn) { + s.GRPCListener.Handle(conn) +} + // canRetry returns true if the given situation is safe for a retry. func canRetry(args interface{}, err error) bool { // No leader errors are always safe to retry since no state could have diff --git a/agent/consul/server.go b/agent/consul/server.go index 071c224eed..c541b853f4 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -18,6 +18,7 @@ import ( metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/agentpb" ca "github.com/hashicorp/consul/agent/connect/ca" "github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/consul/autopilot" @@ -40,6 +41,7 @@ import ( raftboltdb "github.com/hashicorp/raft-boltdb" "github.com/hashicorp/serf/serf" "golang.org/x/time/rate" + "google.golang.org/grpc" ) // These are the protocol versions that Consul can _understand_. These are @@ -168,6 +170,7 @@ type Server struct { // Connection pool to other consul servers connPool *pool.ConnPool + grpcConn *grpc.ClientConn // eventChLAN is used to receive events from the // serf cluster in the datacenter @@ -225,8 +228,12 @@ type Server struct { rpcConnLimiter connlimit.Limiter // Listener is used to listen for incoming connections - Listener net.Listener - rpcServer *rpc.Server + Listener net.Listener + GRPCListener *grpcListener + rpcServer *rpc.Server + + // grpcClient is used for gRPC calls to remote datacenters. + grpcClient *GRPCClient // insecureRPCServer is a RPC server that is configure with // IncomingInsecureRPCConfig to allow clients to call AutoEncrypt.Sign @@ -380,6 +387,18 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token serverLogger := logger.NamedIntercept(logging.ConsulServer) loggers := newLoggerStore(serverLogger) + + var tracker router.ServerTracker + var grpcResolverBuilder *ServerResolverBuilder + + if config.GRPCEnabled { + // Register the gRPC resolver used for connection balancing. + grpcResolverBuilder = registerResolverBuilder(config.GRPCResolverScheme, config.Datacenter, shutdownCh) + tracker = grpcResolverBuilder + } else { + tracker = &router.NoOpServerTracker{} + } + // Create server. s := &Server{ config: config, @@ -391,7 +410,7 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token loggers: loggers, leaveCh: make(chan struct{}), reconcileCh: make(chan serf.Member, reconcileChSize), - router: router.NewRouter(serverLogger, config.Datacenter), + router: router.NewRouter(serverLogger, config.Datacenter, tracker), rpcServer: rpc.NewServer(), insecureRPCServer: rpc.NewServer(), tlsConfigurator: tlsConfigurator, @@ -546,6 +565,20 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token } go s.lanEventHandler() + if s.config.GRPCEnabled { + // Start the gRPC server shuffling using the LAN serf for node count. + go grpcResolverBuilder.periodicServerRebalance(s.serfLAN) + + // Initialize the GRPC listener. + if err := s.setupGRPC(); err != nil { + s.Shutdown() + return nil, fmt.Errorf("Failed to start GRPC layer: %v", err) + } + + // Start the GRPC client. 
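+		// The client is used to forward gRPC requests to servers in other datacenters.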
+ s.grpcClient = NewGRPCClient(s.logger, grpcResolverBuilder, tlsConfigurator, config.GRPCResolverScheme) + } + // Start the flooders after the LAN event handler is wired up. s.floodSegments(config) @@ -861,6 +894,44 @@ func (s *Server) setupRPC() error { return nil } +// setupGRPC initializes the built in gRPC server components +func (s *Server) setupGRPC() error { + lis := &grpcListener{ + addr: s.Listener.Addr(), + conns: make(chan net.Conn), + } + + // We don't need to pass tls.Config to the server since it's multiplexed + // behind the RPC listener, which already has TLS configured. + srv := grpc.NewServer( + grpc.StatsHandler(grpcStatsHandler), + grpc.StreamInterceptor(GRPCCountingStreamInterceptor), + ) + //stream.RegisterConsulServer(srv, &ConsulGRPCAdapter{Health{s}}) + if s.config.GRPCTestServerEnabled { + agentpb.RegisterTestServer(srv, &GRPCTest{srv: s}) + } + + go srv.Serve(lis) + s.GRPCListener = lis + + // Set up a gRPC client connection to the above listener. + dialer := newDialer(s.serverLookup, s.tlsConfigurator.OutgoingRPCWrapper()) + conn, err := grpc.Dial(lis.Addr().String(), + grpc.WithInsecure(), + grpc.WithContextDialer(dialer), + grpc.WithDisableRetry(), + grpc.WithStatsHandler(grpcStatsHandler), + grpc.WithBalancerName("pick_first")) + if err != nil { + return err + } + + s.grpcConn = conn + + return nil +} + // Shutdown is used to shutdown the server func (s *Server) Shutdown() error { s.logger.Info("shutting down server") @@ -907,6 +978,10 @@ func (s *Server) Shutdown() error { s.Listener.Close() } + if s.GRPCListener != nil { + s.GRPCListener.Close() + } + // Close the connection pool s.connPool.Shutdown() @@ -1277,6 +1352,14 @@ func (s *Server) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io return nil } +// GRPCConn returns a gRPC connection to a server. +func (s *Server) GRPCConn() (*grpc.ClientConn, error) { + if !s.config.GRPCEnabled { + return nil, fmt.Errorf("GRPC not enabled") + } + return s.grpcConn, nil +} + // RegisterEndpoint is used to substitute an endpoint for testing. func (s *Server) RegisterEndpoint(name string, handler interface{}) error { s.logger.Warn("endpoint injected; this should only be used for testing") diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index c897607c88..02616e99e9 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -168,6 +168,7 @@ func testServerConfig(t *testing.T) (string, *Config) { } config.NotifyShutdown = returnPortsFn + config.GRPCResolverScheme = fmt.Sprintf("consul-%s", config.NodeName) return dir, config } diff --git a/agent/consul/test_endpoint.go b/agent/consul/test_endpoint.go new file mode 100644 index 0000000000..a62749c6a3 --- /dev/null +++ b/agent/consul/test_endpoint.go @@ -0,0 +1,32 @@ +package consul + +import ( + "context" + + "github.com/hashicorp/consul/agent/agentpb" +) + +// GRPCTest is a gRPC handler object for the agentpb.Test service. It's only +// used for testing gRPC plumbing details and never exposed in a running Consul +// server. +type GRPCTest struct { + srv *Server +} + +// Test is the gRPC Test.Test endpoint handler. +func (t *GRPCTest) Test(ctx context.Context, req *agentpb.TestRequest) (*agentpb.TestResponse, error) { + if req.Datacenter != "" && req.Datacenter != t.srv.config.Datacenter { + conn, err := t.srv.grpcClient.GRPCConn(req.Datacenter) + if err != nil { + return nil, err + } + + t.srv.logger.Debug("GRPC test server conn state %s", conn.GetState()) + + // Open a Test call to the remote DC. 
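+		// The remote server answers with its own node name.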
+ client := agentpb.NewTestClient(conn) + return client.Test(ctx, req) + } + + return &agentpb.TestResponse{ServerName: t.srv.config.NodeName}, nil +} diff --git a/agent/pool/conn.go b/agent/pool/conn.go index 8a046fa4ca..79731953b4 100644 --- a/agent/pool/conn.go +++ b/agent/pool/conn.go @@ -40,23 +40,24 @@ const ( // that is supported and it might be the only one there // ever is. RPCTLSInsecure = 7 + RPCGRPC = 8 - // RPCMaxTypeValue is the maximum rpc type byte value currently used for - // the various protocols riding over our "rpc" port. + // RPCMaxTypeValue is the maximum rpc type byte value currently used for the + // various protocols riding over our "rpc" port. // - // Currently our 0-7 values are mutually exclusive with any valid first - // byte of a TLS header. The first TLS header byte will begin with a TLS - // content type and the values 0-19 are all explicitly unassigned and - // marked as requiring coordination. RFC 7983 does the marking and goes - // into some details about multiplexing connections and identifying TLS. + // Currently our 0-8 values are mutually exclusive with any valid first byte + // of a TLS header. The first TLS header byte will begin with a TLS content + // type and the values 0-19 are all explicitly unassigned and marked as + // requiring coordination. RFC 7983 does the marking and goes into some + // details about multiplexing connections and identifying TLS. // // We use this value to determine if the incoming request is actual real - // native TLS (where we can demultiplex based on ALPN protocol) or our - // older type-byte system when new connections are established. + // native TLS (where we can de-multiplex based on ALPN protocol) or our older + // type-byte system when new connections are established. // // NOTE: if you add new RPCTypes beyond this value, you must similarly bump // this value. - RPCMaxTypeValue = 7 + RPCMaxTypeValue = 8 ) const ( @@ -66,6 +67,7 @@ const ( ALPN_RPCMultiplexV2 = "consul/rpc-multi" // RPCMultiplexV2 ALPN_RPCSnapshot = "consul/rpc-snapshot" // RPCSnapshot ALPN_RPCGossip = "consul/rpc-gossip" // RPCGossip + ALPN_RPCGRPC = "consul/rpc-grpc" // RPCGRPC // wan federation additions ALPN_WANGossipPacket = "consul/wan-gossip/packet" ALPN_WANGossipStream = "consul/wan-gossip/stream" diff --git a/agent/pool/pool.go b/agent/pool/pool.go index a2e4a4ea17..bb08e97278 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -389,7 +389,11 @@ func DialTimeoutWithRPCTypeDirectly( } // Check if TLS is enabled - if (useTLS) && wrapper != nil { + if useTLS { + if wrapper == nil { + return nil, nil, fmt.Errorf("TLS enabled but got nil TLS wrapper") + } + // Switch the connection into TLS mode if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil { conn.Close() diff --git a/agent/router/manager.go b/agent/router/manager.go index a02392ae36..1320704ffa 100644 --- a/agent/router/manager.go +++ b/agent/router/manager.go @@ -64,6 +64,23 @@ type Pinger interface { Ping(dc, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error) } +// ServerTracker is a wrapper around consul.ServerResolverBuilder to prevent a +// cyclic import dependency. +type ServerTracker interface { + AddServer(*metadata.Server) + RemoveServer(*metadata.Server) +} + +// NoOpServerTracker is a ServerTracker that does nothing. Used when gRPC is not +// enabled. 
+type NoOpServerTracker struct{} + +// AddServer implements ServerTracker +func (t *NoOpServerTracker) AddServer(*metadata.Server) {} + +// RemoveServer implements ServerTracker +func (t *NoOpServerTracker) RemoveServer(*metadata.Server) {} + // serverList is a local copy of the struct used to maintain the list of // Consul servers used by Manager. // @@ -98,6 +115,10 @@ type Manager struct { // client.ConnPool. connPoolPinger Pinger + // grpcServerTracker is used to balance grpc connections across servers, + // and has callbacks for adding or removing a server. + grpcServerTracker ServerTracker + // notifyFailedBarrier is acts as a barrier to prevent queuing behind // serverListLog and acts as a TryLock(). notifyFailedBarrier int32 @@ -115,6 +136,7 @@ type Manager struct { func (m *Manager) AddServer(s *metadata.Server) { m.listLock.Lock() defer m.listLock.Unlock() + m.grpcServerTracker.AddServer(s) l := m.getServerList() // Check if this server is known @@ -243,6 +265,11 @@ func (m *Manager) CheckServers(fn func(srv *metadata.Server) bool) { _ = m.checkServers(fn) } +// Servers returns the current list of servers. +func (m *Manager) Servers() []*metadata.Server { + return m.getServerList().servers +} + // getServerList is a convenience method which hides the locking semantics // of atomic.Value from the caller. func (m *Manager) getServerList() serverList { @@ -256,7 +283,7 @@ func (m *Manager) saveServerList(l serverList) { } // New is the only way to safely create a new Manager struct. -func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger) (m *Manager) { +func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, tracker ServerTracker) (m *Manager) { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } @@ -265,6 +292,7 @@ func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfC m.logger = logger.Named(logging.Manager) m.clusterInfo = clusterInfo // can't pass *consul.Client: import cycle m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle + m.grpcServerTracker = tracker // can't pass *consul.ServerResolverBuilder: import cycle m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration) m.shutdownCh = shutdownCh atomic.StoreInt32(&m.offline, 1) @@ -453,6 +481,7 @@ func (m *Manager) reconcileServerList(l *serverList) bool { func (m *Manager) RemoveServer(s *metadata.Server) { m.listLock.Lock() defer m.listLock.Unlock() + m.grpcServerTracker.RemoveServer(s) l := m.getServerList() // Remove the server if known @@ -473,17 +502,22 @@ func (m *Manager) RemoveServer(s *metadata.Server) { func (m *Manager) refreshServerRebalanceTimer() time.Duration { l := m.getServerList() numServers := len(l.servers) + connRebalanceTimeout := ComputeRebalanceTimer(numServers, m.clusterInfo.NumNodes()) + + m.rebalanceTimer.Reset(connRebalanceTimeout) + return connRebalanceTimeout +} + +// ComputeRebalanceTimer returns a time to wait before rebalancing connections given +// a number of servers and LAN nodes. +func ComputeRebalanceTimer(numServers, numLANMembers int) time.Duration { // Limit this connection's life based on the size (and health) of the // cluster. Never rebalance a connection more frequently than // connReuseLowWatermarkDuration, and make sure we never exceed // clusterWideRebalanceConnsPerSec operations/s across numLANMembers. 
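Manager (and Router below) now notify a ServerTracker whenever servers are added or removed. A minimal sketch of what an implementation looks like; trackedServers is hypothetical, and per the interface comment above the real implementation is consul.ServerResolverBuilder, which feeds server addresses to the gRPC resolver:

package router

import (
	"sync"

	"github.com/hashicorp/consul/agent/metadata"
)

// trackedServers is a hypothetical ServerTracker implementation, shown only
// to illustrate the shape of the interface. Keeping the interface in this
// package is what breaks the import cycle with the consul package.
type trackedServers struct {
	mu      sync.Mutex
	servers map[string]*metadata.Server
}

func (t *trackedServers) AddServer(s *metadata.Server) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.servers == nil {
		t.servers = make(map[string]*metadata.Server)
	}
	t.servers[s.Name] = s
}

func (t *trackedServers) RemoveServer(s *metadata.Server) {
	t.mu.Lock()
	defer t.mu.Unlock()
	delete(t.servers, s.Name)
}
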
clusterWideRebalanceConnsPerSec := float64(numServers * newRebalanceConnsPerSecPerServer) connReuseLowWatermarkDuration := clientRPCMinReuseDuration + lib.RandomStagger(clientRPCMinReuseDuration/clientRPCJitterFraction) - numLANMembers := m.clusterInfo.NumNodes() - connRebalanceTimeout := lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers) - - m.rebalanceTimer.Reset(connRebalanceTimeout) - return connRebalanceTimeout + return lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers) } // ResetRebalanceTimer resets the rebalance timer. This method exists for diff --git a/agent/router/manager_internal_test.go b/agent/router/manager_internal_test.go index 9b58abd9c9..1def663559 100644 --- a/agent/router/manager_internal_test.go +++ b/agent/router/manager_internal_test.go @@ -53,14 +53,15 @@ func (s *fauxSerf) NumNodes() int { func testManager() (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}) + m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, &mockTracker{}) return m } func testManagerFailProb(failPct float64) (m *Manager) { logger := GetBufferedLogger() + logger = hclog.NewNullLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}) + m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, &mockTracker{}) return m } @@ -299,7 +300,7 @@ func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) { shutdownCh := make(chan struct{}) for _, s := range clusters { - m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}) + m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, &mockTracker{}) for i := 0; i < s.numServers; i++ { nodeName := fmt.Sprintf("s%02d", i) m.AddServer(&metadata.Server{Name: nodeName}) diff --git a/agent/router/manager_test.go b/agent/router/manager_test.go index 676afd016c..50ee07c161 100644 --- a/agent/router/manager_test.go +++ b/agent/router/manager_test.go @@ -54,24 +54,29 @@ func (s *fauxSerf) NumNodes() int { return 16384 } +type fauxTracker struct{} + +func (m *fauxTracker) AddServer(s *metadata.Server) {} +func (m *fauxTracker) RemoveServer(s *metadata.Server) {} + func testManager(t testing.TB) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}) + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, &fauxTracker{}) return m } func testManagerFailProb(t testing.TB, failPct float64) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}) + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, &fauxTracker{}) return m } func testManagerFailAddr(t testing.TB, failAddr net.Addr) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}) + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, &fauxTracker{}) return m } @@ -195,7 +200,7 @@ func TestServers_FindServer(t *testing.T) { func TestServers_New(t *testing.T) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m := router.New(logger, shutdownCh, &fauxSerf{}, 
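The refactor above exposes the rebalance math as ComputeRebalanceTimer so it can be called, and unit tested, without constructing a Manager. A small usage sketch with arbitrarily chosen cluster sizes:

package main

import (
	"fmt"
	"time"

	"github.com/hashicorp/consul/agent/router"
)

// 5 servers and 10,000 LAN members are arbitrary inputs. The result includes
// a random stagger, so repeated calls return slightly different durations.
func main() {
	wait := router.ComputeRebalanceTimer(5, 10000)
	fmt.Printf("next rebalance in %s\n", wait.Round(time.Second))
}
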
&fauxConnPool{}) + m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, &fauxTracker{}) if m == nil { t.Fatalf("Manager nil") } diff --git a/agent/router/router.go b/agent/router/router.go index 4cdc864b06..33ea2b43dc 100644 --- a/agent/router/router.go +++ b/agent/router/router.go @@ -37,6 +37,10 @@ type Router struct { // routeFn is a hook to actually do the routing. routeFn func(datacenter string) (*Manager, *metadata.Server, bool) + // grpcServerTracker is used to balance grpc connections across servers, + // and has callbacks for adding or removing a server. + grpcServerTracker ServerTracker + // isShutdown prevents adding new routes to a router after it is shut // down. isShutdown bool @@ -83,16 +87,17 @@ type areaInfo struct { } // NewRouter returns a new Router with the given configuration. -func NewRouter(logger hclog.Logger, localDatacenter string) *Router { +func NewRouter(logger hclog.Logger, localDatacenter string, tracker ServerTracker) *Router { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } - + router := &Router{ - logger: logger.Named(logging.Router), - localDatacenter: localDatacenter, - areas: make(map[types.AreaID]*areaInfo), - managers: make(map[string][]*Manager), + logger: logger.Named(logging.Router), + localDatacenter: localDatacenter, + areas: make(map[types.AreaID]*areaInfo), + managers: make(map[string][]*Manager), + grpcServerTracker: tracker, } // Hook the direct route lookup by default. @@ -219,7 +224,7 @@ func (r *Router) addServer(area *areaInfo, s *metadata.Server) error { info, ok := area.managers[s.Datacenter] if !ok { shutdownCh := make(chan struct{}) - manager := New(r.logger, shutdownCh, area.cluster, area.pinger) + manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.grpcServerTracker) info = &managerInfo{ manager: manager, shutdownCh: shutdownCh, diff --git a/agent/router/router_test.go b/agent/router/router_test.go index 18d01236f1..40b9bdd1b4 100644 --- a/agent/router/router_test.go +++ b/agent/router/router_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/sdk/testutil" @@ -69,6 +70,11 @@ func (m *mockCluster) AddMember(dc string, name string, coord *coordinate.Coordi m.addr++ } +type mockTracker struct{} + +func (m *mockTracker) AddServer(s *metadata.Server) {} +func (m *mockTracker) RemoveServer(s *metadata.Server) {} + // testCluster is used to generate a single WAN-like area with a known set of // member and RTT topology. // @@ -95,7 +101,7 @@ func testCluster(self string) *mockCluster { func testRouter(t testing.TB, dc string) *Router { logger := testutil.Logger(t) - return NewRouter(logger, dc) + return NewRouter(logger, dc, &mockTracker{}) } func TestRouter_Shutdown(t *testing.T) {
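With the added ServerTracker parameter, existing NewRouter callers must now supply a tracker. A minimal sketch of constructing a Router when the gRPC backend is disabled, using the new NoOpServerTracker; the logger options are illustrative:

package main

import (
	"github.com/hashicorp/consul/agent/router"
	hclog "github.com/hashicorp/go-hclog"
)

// Constructing a Router under the new signature. When the gRPC backend is
// not enabled, the no-op tracker satisfies the ServerTracker argument.
func main() {
	logger := hclog.New(&hclog.LoggerOptions{Name: "router"})
	r := router.NewRouter(logger, "dc1", &router.NoOpServerTracker{})
	_ = r // areas and servers are then added exactly as before
}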