package proxy

import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"

	metrics "github.com/armon/go-metrics"
	"github.com/hashicorp/go-hclog"

	"github.com/hashicorp/consul/api"
	"github.com/hashicorp/consul/connect"
	"github.com/hashicorp/consul/ipaddr"
)

const (
	publicListenerPrefix   = "inbound"
	upstreamListenerPrefix = "upstream"
)

// Listener is the implementation of a specific proxy listener. It has pluggable
// Listen and Dial methods to suit public mTLS vs upstream semantics. It handles
// the lifecycle of the listener and all connections opened through it.
type Listener struct {
	// Service is the connect service instance to use.
	Service *connect.Service

	// listenFunc, dialFunc, and bindAddr are set by type-specific constructors.
	listenFunc func() (net.Listener, error)
	dialFunc   func() (net.Conn, error)
	bindAddr   string

	stopFlag int32
	stopChan chan struct{}

	// listeningChan is closed when the listener is opened successfully. It's
	// really only for use in tests where we need to wait for the Serve
	// goroutine to be running before we proceed with trying to connect. On my
	// laptop this always works out anyway, but on constrained VMs and
	// especially docker containers (e.g. in CI) we often see the Dial routine
	// win the race and get `connection refused`. Retry loops and sleeps are
	// unpleasant workarounds and this is cheap and correct.
	listeningChan chan struct{}

	// listenerLock guards access to the listener field.
	listenerLock sync.Mutex
	listener     net.Listener

	logger hclog.Logger

	// activeConns is a gauge tracking the current number of open connections.
	activeConns int32
	connWG      sync.WaitGroup

	metricPrefix string
	metricLabels []metrics.Label
}

// NewPublicListener returns a Listener set up to listen for public mTLS
// connections and proxy them to the configured local application over TCP.
func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig,
	logger hclog.Logger) *Listener {
	bindAddr := ipaddr.FormatAddressPort(cfg.BindAddress, cfg.BindPort)
	return &Listener{
		Service: svc,
		listenFunc: func() (net.Listener, error) {
			return tls.Listen("tcp", bindAddr, svc.ServerTLSConfig())
		},
		dialFunc: func() (net.Conn, error) {
			return net.DialTimeout("tcp", cfg.LocalServiceAddress,
				time.Duration(cfg.LocalConnectTimeoutMs)*time.Millisecond)
		},
		bindAddr:      bindAddr,
		stopChan:      make(chan struct{}),
		listeningChan: make(chan struct{}),
		logger:        logger.Named(publicListenerPrefix),
		metricPrefix:  publicListenerPrefix,
		// For now we only label the destination (ourselves) - we could fetch
		// the src service from the cert on each connection and label metrics
		// by source too, but it significantly complicates the active
		// connection tracking here and it's not clear that it's very valuable
		// - in aggregate, looking at all _outbound_ connections across all
		// proxies gets you a full picture of src->dst traffic. We might expand
		// this later for better debugging of which clients are abusing a
		// particular service instance, but we'll see how valuable that seems
		// against the extra complication of tracking many gauges here.
		metricLabels: []metrics.Label{{Name: "dst", Value: svc.Name()}},
	}
}
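
// The sketch below shows how a caller might wire up a public listener end to
// end. It is illustrative only: connect.NewService is assumed to be the
// connect package's service constructor, and the config values are made up.
//
//	client, _ := api.NewClient(api.DefaultConfig())
//	svc, _ := connect.NewService("web", client)
//	l := NewPublicListener(svc, PublicListenerConfig{
//		BindAddress:           "0.0.0.0",
//		BindPort:              8443,
//		LocalServiceAddress:   "127.0.0.1:8080",
//		LocalConnectTimeoutMs: 1000,
//	}, hclog.Default())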

// NewUpstreamListener returns a Listener set up to listen locally for TCP
// connections that are proxied to a discovered Connect service instance.
func NewUpstreamListener(svc *connect.Service, client *api.Client,
	cfg UpstreamConfig, logger hclog.Logger) *Listener {
	return newUpstreamListenerWithResolver(svc, cfg,
		UpstreamResolverFuncFromClient(client), logger)
}

func newUpstreamListenerWithResolver(svc *connect.Service, cfg UpstreamConfig,
	resolverFunc func(UpstreamConfig) (connect.Resolver, error),
	logger hclog.Logger) *Listener {
	bindAddr := ipaddr.FormatAddressPort(cfg.LocalBindAddress, cfg.LocalBindPort)
	return &Listener{
		Service: svc,
		listenFunc: func() (net.Listener, error) {
			return net.Listen("tcp", bindAddr)
		},
		dialFunc: func() (net.Conn, error) {
			rf, err := resolverFunc(cfg)
			if err != nil {
				return nil, err
			}
			ctx, cancel := context.WithTimeout(context.Background(),
				cfg.ConnectTimeout())
			defer cancel()
			return svc.Dial(ctx, rf)
		},
		bindAddr:      bindAddr,
		stopChan:      make(chan struct{}),
		listeningChan: make(chan struct{}),
		logger:        logger.Named(upstreamListenerPrefix),
		metricPrefix:  upstreamListenerPrefix,
		metricLabels: []metrics.Label{
			{Name: "src", Value: svc.Name()},
			// TODO(banks): namespace support
			{Name: "dst_type", Value: string(cfg.DestinationType)},
			{Name: "dst", Value: cfg.DestinationName},
		},
	}
}
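
// A minimal sketch of constructing an upstream listener. The UpstreamConfig
// values are assumptions for illustration (a "service" destination named
// "db", bound locally on port 9191); svc and client come from the caller.
//
//	u := NewUpstreamListener(svc, client, UpstreamConfig{
//		DestinationType:  "service",
//		DestinationName:  "db",
//		LocalBindAddress: "127.0.0.1",
//		LocalBindPort:    9191,
//	}, hclog.Default())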

// Serve runs the listener until it is stopped. It is an error to call Serve
// more than once for any given Listener instance.
func (l *Listener) Serve() error {
	// Ensure we mark state closed if we fail before Close is called externally.
	defer l.Close()

	if atomic.LoadInt32(&l.stopFlag) != 0 {
		return errors.New("serve called on a closed listener")
	}

	listener, err := l.listenFunc()
	if err != nil {
		return err
	}

	l.setListener(listener)

	close(l.listeningChan)

	for {
		conn, err := listener.Accept()
		if err != nil {
			if atomic.LoadInt32(&l.stopFlag) == 1 {
				return nil
			}
			return err
		}

		l.connWG.Add(1)
		go l.handleConn(conn)
	}
}
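
// Typical lifecycle (a sketch, not prescribed by this package): run Serve on
// its own goroutine, then use Wait to block until the bind has completed.
//
//	go func() {
//		if err := l.Serve(); err != nil {
//			// log or surface the error
//		}
//	}()
//	l.Wait()        // returns once the listener is bound
//	defer l.Close() // stops accepting and waits for conns to finish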

// handleConn is the internal connection handler goroutine.
func (l *Listener) handleConn(src net.Conn) {
	defer func() {
		// Make sure Listener.Close waits for this conn to be cleaned up.
		src.Close()
		l.connWG.Done()
	}()

	dst, err := l.dialFunc()
	if err != nil {
		l.logger.Error("failed to dial", "error", err)
		return
	}

	// Track the active conn now (the first call) and defer the returned func
	// to un-count it when the connection closes.
	defer l.trackConn()()

	// Note no need to defer dst.Close() since conn handles that for us.
	conn := NewConn(src, dst)
	defer conn.Close()

	connStop := make(chan struct{})

	// Run another goroutine to copy the bytes.
	go func() {
		if err := conn.CopyBytes(); err != nil {
			l.logger.Error("connection failed", "error", err)
		}
		close(connStop)
	}()

	// Periodically copy stats from conn to metrics (to keep metrics calls out
	// of the path of every single packet copy). 5 seconds is probably good
	// enough resolution - statsd and most others tend to summarize with lower
	// resolution anyway and this amortizes the cost more.
	var tx, rx uint64
	statsT := time.NewTicker(5 * time.Second)
	defer statsT.Stop()

	reportStats := func() {
		newTx, newRx := conn.Stats()
		if delta := newTx - tx; delta > 0 {
			metrics.IncrCounterWithLabels([]string{l.metricPrefix, "tx_bytes"},
				float32(delta), l.metricLabels)
		}
		if delta := newRx - rx; delta > 0 {
			metrics.IncrCounterWithLabels([]string{l.metricPrefix, "rx_bytes"},
				float32(delta), l.metricLabels)
		}
		tx, rx = newTx, newRx
	}
	// Always report final stats for the conn.
	defer reportStats()

	// Wait for conn to close.
	for {
		select {
		case <-connStop:
			return
		case <-l.stopChan:
			return
		case <-statsT.C:
			reportStats()
		}
	}
}

// trackConn increments the count of active conns and returns a func() that
// can be deferred to decrement the counter again when the connection closes.
func (l *Listener) trackConn() func() {
	c := atomic.AddInt32(&l.activeConns, 1)
	metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c),
		l.metricLabels)

	return func() {
		c := atomic.AddInt32(&l.activeConns, -1)
		metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c),
			l.metricLabels)
	}
}
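
// The call site in handleConn uses the double-call form `defer
// l.trackConn()()`: l.trackConn() runs immediately (incrementing the gauge)
// and only the returned closure is deferred (decrementing it). A generic
// sketch of the idiom, with hypothetical enter/work helpers:
//
//	func enter() func() {
//		fmt.Println("begin")                 // runs at the call site
//		return func() { fmt.Println("end") } // runs when the defer fires
//	}
//
//	func work() {
//		defer enter()() // prints "begin" now, "end" on return
//	}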

// Close terminates the listener and all active connections.
func (l *Listener) Close() error {
	// Prevent the listener from being started again.
	oldFlag := atomic.SwapInt32(&l.stopFlag, 1)
	if oldFlag != 0 {
		return nil
	}

	// Stop the current listener and stop accepting new requests.
	if listener := l.getListener(); listener != nil {
		listener.Close()
	}

	// Stop outstanding requests.
	close(l.stopChan)

	// Wait for all conns to close.
	l.connWG.Wait()

	return nil
}

// Wait waits for the listener to be ready to accept connections.
func (l *Listener) Wait() {
	<-l.listeningChan
}
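
// Wait exists mainly for tests (see listeningChan above): it removes the race
// between binding and dialing without retry loops. Sketch:
//
//	go l.Serve()
//	l.Wait()
//	conn, err := net.Dial("tcp", l.BindAddr())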

// BindAddr returns the address the listener is bound to.
func (l *Listener) BindAddr() string {
	return l.bindAddr
}

func (l *Listener) setListener(listener net.Listener) {
	l.listenerLock.Lock()
	l.listener = listener
	l.listenerLock.Unlock()
}

func (l *Listener) getListener() net.Listener {
	l.listenerLock.Lock()
	defer l.listenerLock.Unlock()
	return l.listener
}