diff --git a/connect/proxy/listener.go b/connect/proxy/listener.go index 33f1f5292b..a6b9a75e85 100644 --- a/connect/proxy/listener.go +++ b/connect/proxy/listener.go @@ -7,12 +7,19 @@ import ( "fmt" "log" "net" + "sync" "sync/atomic" "time" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/connect" ) +const ( + publicListenerMetricPrefix = "inbound" + upstreamMetricPrefix = "upstream" +) + // Listener is the implementation of a specific proxy listener. It has pluggable // Listen and Dial methods to suit public mTLS vs upstream semantics. It handles // the lifecycle of the listener and all connections opened through it @@ -38,6 +45,12 @@ type Listener struct { listeningChan chan struct{} logger *log.Logger + + // Gauge to track current open connections + activeConns int64 + connWG sync.WaitGroup + metricPrefix string + metricLabels []metrics.Label } // NewPublicListener returns a Listener setup to listen for public mTLS @@ -58,6 +71,16 @@ func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig, stopChan: make(chan struct{}), listeningChan: make(chan struct{}), logger: logger, + metricPrefix: publicListenerMetricPrefix, + // For now we only label ourselves as source - we could fetch the src + // service from cert on each connection and label metrics differently but it + // significaly complicates the active connection tracking here and it's not + // clear that it's very valuable - on aggregate looking at all _outbound_ + // connections across all proxies gets you a full picture of src->dst + // traffic. We might expand this later for better debugging of which clients + // are abusing a particular service instance but we'll see how valuable that + // seems for the extra complication of tracking many gauges here. + metricLabels: []metrics.Label{{Name: "dst", Value: svc.Name()}}, } } @@ -84,6 +107,13 @@ func NewUpstreamListener(svc *connect.Service, cfg UpstreamConfig, stopChan: make(chan struct{}), listeningChan: make(chan struct{}), logger: logger, + metricPrefix: upstreamMetricPrefix, + metricLabels: []metrics.Label{ + {Name: "src", Value: svc.Name()}, + // TODO(banks): namespace support + {Name: "dst_type", Value: cfg.DestinationType}, + {Name: "dst", Value: cfg.DestinationName}, + }, } } @@ -125,19 +155,87 @@ func (l *Listener) handleConn(src net.Conn) { l.logger.Printf("[ERR] failed to dial: %s", err) return } + + // Track active conn now (first function call) and defer un-counting it when + // it closes. + defer l.trackConn()() + // Note no need to defer dst.Close() since conn handles that for us. conn := NewConn(src, dst) defer conn.Close() - err = conn.CopyBytes() - if err != nil { - l.logger.Printf("[ERR] connection failed: %s", err) - return + connStop := make(chan struct{}) + + // Run another goroutine to copy the bytes. + go func() { + err = conn.CopyBytes() + if err != nil { + l.logger.Printf("[ERR] connection failed: %s", err) + return + } + close(connStop) + }() + + // Periodically copy stats from conn to metrics (to keep metrics calls out of + // the path of every single packet copy). 5 seconds is probably good enough + // resolution - statsd and most others tend to summarize with lower resolution + // anyway and this amortizes the cost more. + var tx, rx uint64 + statsT := time.NewTicker(5 * time.Second) + defer statsT.Stop() + + reportStats := func() { + newTx, newRx := conn.Stats() + if delta := newTx - tx; delta > 0 { + metrics.IncrCounterWithLabels([]string{l.metricPrefix, "tx_bytes"}, + float32(newTx-tx), l.metricLabels) + } + if delta := newRx - rx; delta > 0 { + metrics.IncrCounterWithLabels([]string{l.metricPrefix, "rx_bytes"}, + float32(newRx-rx), l.metricLabels) + } + tx, rx = newTx, newRx + } + // Always report final stats for the conn. + defer reportStats() + + // Wait for conn to close + for { + select { + case <-connStop: + return + case <-l.stopChan: + return + case <-statsT.C: + reportStats() + } + } +} + +// trackConn increments the count of active conns and returns a func() that can +// be deferred on to decrement the counter again on connection close. +func (l *Listener) trackConn() func() { + l.connWG.Add(1) + c := atomic.AddInt64(&l.activeConns, 1) + metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c), + l.metricLabels) + + return func() { + l.connWG.Done() + c := atomic.AddInt64(&l.activeConns, -1) + metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c), + l.metricLabels) } } // Close terminates the listener and all active connections. func (l *Listener) Close() error { + oldFlag := atomic.SwapInt32(&l.stopFlag, 1) + if oldFlag == 0 { + close(l.stopChan) + // Wait for all conns to close + l.connWG.Wait() + } return nil } diff --git a/connect/proxy/listener_test.go b/connect/proxy/listener_test.go index d63f5818bb..773e3a2db2 100644 --- a/connect/proxy/listener_test.go +++ b/connect/proxy/listener_test.go @@ -1,19 +1,92 @@ package proxy import ( + "bytes" "context" "fmt" "log" "net" "os" "testing" + "time" + + metrics "github.com/armon/go-metrics" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" agConnect "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/connect" "github.com/hashicorp/consul/lib/freeport" - "github.com/stretchr/testify/require" ) +func testSetupMetrics(t *testing.T) *metrics.InmemSink { + // Record for ages so we can be confident that our assertions won't fail on + // silly long test runs due to dropped data. + s := metrics.NewInmemSink(10*time.Second, 300*time.Second) + cfg := metrics.DefaultConfig("consul.proxy.test") + cfg.EnableHostname = false + metrics.NewGlobal(cfg, s) + return s +} + +func assertCurrentGaugeValue(t *testing.T, sink *metrics.InmemSink, + name string, value float32) { + t.Helper() + + data := sink.Data() + + // Current interval is the last one + currentInterval := data[len(data)-1] + currentInterval.RLock() + defer currentInterval.RUnlock() + + assert.Equalf(t, value, currentInterval.Gauges[name].Value, + "gauge value mismatch. Current Interval:\n%v", currentInterval) +} + +func assertAllTimeCounterValue(t *testing.T, sink *metrics.InmemSink, + name string, value float64) { + t.Helper() + + data := sink.Data() + + var got float64 + for _, intv := range data { + intv.RLock() + // Note that InMemSink uses SampledValue and treats the _Sum_ not the Count + // as the entire value. + if sample, ok := intv.Counters[name]; ok { + got += sample.Sum + } + intv.RUnlock() + } + + if !assert.Equal(t, value, got) { + // no nice way to dump this - this is copied from private method in + // InMemSink used for dumping to stdout on SIGUSR1. + buf := bytes.NewBuffer(nil) + for _, intv := range data { + intv.RLock() + for _, val := range intv.Gauges { + fmt.Fprintf(buf, "[%v][G] '%s': %0.3f\n", intv.Interval, name, val.Value) + } + for name, vals := range intv.Points { + for _, val := range vals { + fmt.Fprintf(buf, "[%v][P] '%s': %0.3f\n", intv.Interval, name, val) + } + } + for _, agg := range intv.Counters { + fmt.Fprintf(buf, "[%v][C] '%s': %s\n", intv.Interval, name, agg.AggregateSample) + } + for _, agg := range intv.Samples { + fmt.Fprintf(buf, "[%v][S] '%s': %s\n", intv.Interval, name, agg.AggregateSample) + } + intv.RUnlock() + } + t.Log(buf.String()) + } +} + func TestPublicListener(t *testing.T) { t.Parallel() @@ -31,6 +104,9 @@ func TestPublicListener(t *testing.T) { LocalConnectTimeoutMs: 100, } + // Setup metrics to test they are recorded + sink := testSetupMetrics(t) + svc := connect.TestService(t, "db", ca) l := NewPublicListener(svc, cfg, log.New(os.Stderr, "", log.LstdFlags)) @@ -49,7 +125,19 @@ func TestPublicListener(t *testing.T) { CertURI: agConnect.TestSpiffeIDService(t, "db"), }) require.NoError(t, err) + TestEchoConn(t, conn, "") + + // Check active conn is tracked in gauges + assertCurrentGaugeValue(t, sink, "consul.proxy.test.inbound.conns;dst=db", 1) + + // Close listener to ensure all conns are closed and have reported their + // metrics + l.Close() + + // Check all the tx/rx counters got added + assertAllTimeCounterValue(t, sink, "consul.proxy.test.inbound.tx_bytes;dst=db", 11) + assertAllTimeCounterValue(t, sink, "consul.proxy.test.inbound.rx_bytes;dst=db", 11) } func TestUpstreamListener(t *testing.T) { @@ -80,6 +168,9 @@ func TestUpstreamListener(t *testing.T) { }, } + // Setup metrics to test they are recorded + sink := testSetupMetrics(t) + svc := connect.TestService(t, "web", ca) l := NewUpstreamListener(svc, cfg, log.New(os.Stderr, "", log.LstdFlags)) @@ -98,4 +189,15 @@ func TestUpstreamListener(t *testing.T) { fmt.Sprintf("%s:%d", cfg.LocalBindAddress, cfg.LocalBindPort)) require.NoError(t, err) TestEchoConn(t, conn, "") + + // Check active conn is tracked in gauges + assertCurrentGaugeValue(t, sink, "consul.proxy.test.upstream.conns;src=web;dst_type=service;dst=db", 1) + + // Close listener to ensure all conns are closed and have reported their + // metrics + l.Close() + + // Check all the tx/rx counters got added + assertAllTimeCounterValue(t, sink, "consul.proxy.test.upstream.tx_bytes;src=web;dst_type=service;dst=db", 11) + assertAllTimeCounterValue(t, sink, "consul.proxy.test.upstream.rx_bytes;src=web;dst_type=service;dst=db", 11) } diff --git a/connect/proxy/proxy.go b/connect/proxy/proxy.go index 0430528cd0..65a929b4a9 100644 --- a/connect/proxy/proxy.go +++ b/connect/proxy/proxy.go @@ -67,7 +67,8 @@ func (p *Proxy) Serve() error { tcfg := service.ServerTLSConfig() cert, _ := tcfg.GetCertificate(nil) leaf, _ := x509.ParseCertificate(cert.Certificate[0]) - p.logger.Printf("[DEBUG] leaf: %s roots: %s", leaf.URIs[0], bytes.Join(tcfg.RootCAs.Subjects(), []byte(","))) + p.logger.Printf("[DEBUG] leaf: %s roots: %s", leaf.URIs[0], + bytes.Join(tcfg.RootCAs.Subjects(), []byte(","))) }() // Only start a listener if we have a port set. This allows