mirror of https://github.com/hashicorp/consul
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
249 lines
6.0 KiB
249 lines
6.0 KiB
package private |
|
|
|
import ( |
|
"context" |
|
"crypto/tls" |
|
"fmt" |
|
"io" |
|
"net" |
|
"sync/atomic" |
|
"testing" |
|
"time" |
|
|
|
"github.com/hashicorp/go-hclog" |
|
"github.com/stretchr/testify/require" |
|
"golang.org/x/sync/errgroup" |
|
"google.golang.org/grpc" |
|
|
|
"github.com/hashicorp/consul/agent/grpc/private/internal/testservice" |
|
"github.com/hashicorp/consul/agent/metadata" |
|
"github.com/hashicorp/consul/agent/pool" |
|
"github.com/hashicorp/consul/tlsutil" |
|
) |
|
|
|
type testServer struct { |
|
addr net.Addr |
|
name string |
|
dc string |
|
shutdown func() |
|
rpc *fakeRPCListener |
|
} |
|
|
|
func (s testServer) Metadata() *metadata.Server { |
|
return &metadata.Server{ |
|
ID: s.name, |
|
Name: s.name + "." + s.dc, |
|
ShortName: s.name, |
|
Datacenter: s.dc, |
|
Addr: s.addr, |
|
UseTLS: s.rpc.tlsConf != nil, |
|
} |
|
} |
|
|
|
func newSimpleTestServer(t *testing.T, name, dc string, tlsConf *tlsutil.Configurator) testServer { |
|
return newTestServer(t, hclog.Default(), name, dc, tlsConf, func(server *grpc.Server) { |
|
testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc}) |
|
}) |
|
} |
|
|
|
// newPanicTestServer sets up a simple server with handlers that panic. |
|
func newPanicTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator) testServer { |
|
return newTestServer(t, logger, name, dc, tlsConf, func(server *grpc.Server) { |
|
testservice.RegisterSimpleServer(server, &simplePanic{name: name, dc: dc}) |
|
}) |
|
} |
|
|
|
func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer { |
|
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} |
|
handler := NewHandler(logger, addr, register) |
|
|
|
lis, err := net.Listen("tcp", "127.0.0.1:0") |
|
require.NoError(t, err) |
|
|
|
rpc := &fakeRPCListener{t: t, handler: handler, tlsConf: tlsConf} |
|
|
|
g := errgroup.Group{} |
|
g.Go(func() error { |
|
if err := rpc.listen(lis); err != nil { |
|
return fmt.Errorf("fake rpc listen error: %w", err) |
|
} |
|
return nil |
|
}) |
|
g.Go(func() error { |
|
if err := handler.Run(); err != nil { |
|
return fmt.Errorf("grpc server error: %w", err) |
|
} |
|
return nil |
|
}) |
|
return testServer{ |
|
addr: lis.Addr(), |
|
name: name, |
|
dc: dc, |
|
rpc: rpc, |
|
shutdown: func() { |
|
rpc.shutdown = true |
|
if err := lis.Close(); err != nil { |
|
t.Logf("listener closed with error: %v", err) |
|
} |
|
if err := handler.Shutdown(); err != nil { |
|
t.Logf("grpc server shutdown: %v", err) |
|
} |
|
if err := g.Wait(); err != nil { |
|
t.Log(err) |
|
} |
|
}, |
|
} |
|
} |
|
|
|
// simple implements the Simple test service, answering with the configured
// server name and datacenter.
type simple struct {
	name, dc string
}
|
|
|
func (s *simple) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error { |
|
for flow.Context().Err() == nil { |
|
resp := &testservice.Resp{ServerName: "one", Datacenter: s.dc} |
|
if err := flow.Send(resp); err != nil { |
|
return err |
|
} |
|
time.Sleep(time.Millisecond) |
|
} |
|
return nil |
|
} |
|
|
|
func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) { |
|
return &testservice.Resp{ServerName: s.name, Datacenter: s.dc}, nil |
|
} |
|
|
|
// simplePanic implements the Simple test service with handlers that panic
// instead of responding. Neither handler defined in this file reads name or dc.
type simplePanic struct {
	name string
	dc   string
}
|
|
|
func (s *simplePanic) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error { |
|
for flow.Context().Err() == nil { |
|
time.Sleep(time.Millisecond) |
|
panic("panic from Flow") |
|
} |
|
return nil |
|
} |
|
|
|
func (s *simplePanic) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) { |
|
time.Sleep(time.Millisecond) |
|
panic("panic from Something") |
|
} |
|
|
|
// fakeRPCListener mimics agent/consul.Server.listen to handle the RPCType byte. |
|
// In the future we should be able to refactor Server and extract this RPC |
|
// handling logic so that we don't need to use a fake. |
|
// For now, since this logic is in agent/consul, we can't easily use Server.listen |
|
// so we fake it. |
|
type fakeRPCListener struct { |
|
t *testing.T |
|
handler *Handler |
|
shutdown bool |
|
tlsConf *tlsutil.Configurator |
|
tlsConnEstablished int32 |
|
alpnConnEstablished int32 |
|
} |
|
|
|
func (f *fakeRPCListener) listen(listener net.Listener) error { |
|
for { |
|
conn, err := listener.Accept() |
|
if err != nil { |
|
if f.shutdown { |
|
return nil |
|
} |
|
return err |
|
} |
|
|
|
go f.handleConn(conn) |
|
} |
|
} |
|
|
|
func (f *fakeRPCListener) handleConn(conn net.Conn) { |
|
if f.tlsConf != nil && f.tlsConf.MutualTLSCapable() { |
|
// See if actually this is native TLS multiplexed onto the old |
|
// "type-byte" system. |
|
|
|
peekedConn, nativeTLS, err := pool.PeekForTLS(conn) |
|
if err != nil { |
|
if err != io.EOF { |
|
fmt.Printf("ERROR: failed to read first byte: %v\n", err) |
|
} |
|
conn.Close() |
|
return |
|
} |
|
|
|
if nativeTLS { |
|
f.handleNativeTLSConn(peekedConn) |
|
return |
|
} |
|
conn = peekedConn |
|
} |
|
|
|
buf := make([]byte, 1) |
|
|
|
if _, err := conn.Read(buf); err != nil { |
|
if err != io.EOF { |
|
fmt.Println("ERROR", err.Error()) |
|
} |
|
conn.Close() |
|
return |
|
} |
|
typ := pool.RPCType(buf[0]) |
|
|
|
switch typ { |
|
|
|
case pool.RPCGRPC: |
|
f.handler.Handle(conn) |
|
return |
|
|
|
case pool.RPCTLS: |
|
// occasionally we see a test client connecting to an rpc listener that |
|
// was created as part of another test, despite none of the tests running |
|
// in parallel. |
|
// Maybe some strange grpc behaviour? I'm not sure. |
|
if f.tlsConf == nil { |
|
fmt.Println("ERROR: tls is not configured") |
|
conn.Close() |
|
return |
|
} |
|
|
|
atomic.AddInt32(&f.tlsConnEstablished, 1) |
|
conn = tls.Server(conn, f.tlsConf.IncomingRPCConfig()) |
|
f.handleConn(conn) |
|
|
|
default: |
|
fmt.Println("ERROR: unexpected byte", typ) |
|
conn.Close() |
|
} |
|
} |
|
|
|
func (f *fakeRPCListener) handleNativeTLSConn(conn net.Conn) { |
|
tlscfg := f.tlsConf.IncomingALPNRPCConfig(pool.RPCNextProtos) |
|
tlsConn := tls.Server(conn, tlscfg) |
|
|
|
// Force the handshake to conclude. |
|
if err := tlsConn.Handshake(); err != nil { |
|
fmt.Printf("ERROR: TLS handshake failed: %v", err) |
|
conn.Close() |
|
return |
|
} |
|
|
|
conn.SetReadDeadline(time.Time{}) |
|
|
|
var ( |
|
cs = tlsConn.ConnectionState() |
|
nextProto = cs.NegotiatedProtocol |
|
) |
|
|
|
switch nextProto { |
|
case pool.ALPN_RPCGRPC: |
|
atomic.AddInt32(&f.alpnConnEstablished, 1) |
|
f.handler.Handle(tlsConn) |
|
|
|
default: |
|
fmt.Printf("ERROR: discarding RPC for unknown negotiated protocol %q\n", nextProto) |
|
conn.Close() |
|
} |
|
}
|
|
|