// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package middleware
import (
"fmt"
"net"
"google.golang.org/grpc/credentials"
)
// Compile-time checks that the wrapper types declared in this file satisfy the
// net interfaces they stand in for.
var (
	_ net.Listener = (*LabelledListener)(nil)
	_ net.Conn     = (*LabelledConn)(nil)
)
// Protocol is the label a LabelledListener attaches to every connection it
// accepts.
type Protocol int

var (
	// ProtocolPlaintext labels connections for which ServerHandshake skips
	// the TLS handshake.
	ProtocolPlaintext = Protocol(0)
	// ProtocolTLS labels connections for which ServerHandshake performs the
	// full handshake.
	ProtocolTLS = Protocol(1)
)
// LabelledListener decorates a net.Listener so that every connection it
// accepts is tagged with a fixed, pre-configured Protocol.
type LabelledListener struct {
	// Listener is the wrapped listener that connections are accepted from.
	net.Listener

	// Protocol is the label copied onto each accepted connection.
	Protocol Protocol
}
// Accept takes the next connection from the wrapped listener. A non-nil
// connection is wrapped in a LabelledConn carrying the listener's Protocol;
// the wrapped listener's error is passed through unchanged in every case.
func (l LabelledListener) Accept() (net.Conn, error) {
	raw, err := l.Listener.Accept()
	if raw == nil {
		// Nothing to label; hand back the listener's result as-is.
		return nil, err
	}
	return LabelledConn{Conn: raw, protocol: l.Protocol}, err
}
// LabelledConn is a net.Conn decorated with the Protocol label assigned by
// the LabelledListener that accepted it.
type LabelledConn struct {
	// Conn is the wrapped connection.
	net.Conn

	// protocol is the label that ServerHandshake switches on.
	protocol Protocol
}
// Compile-time check that optionalTransportCredentials implements the gRPC
// transport credentials interface.
var _ credentials.TransportCredentials = (*optionalTransportCredentials)(nil)

// optionalTransportCredentials wraps a set of transport credentials and
// decides per connection, from metadata carried by the connection object,
// whether the TLS handshake is performed.
type optionalTransportCredentials struct {
	// TransportCredentials is the wrapped implementation that performs the
	// handshake when one is required.
	credentials.TransportCredentials

	// logger is supplied at construction time.
	// NOTE(review): no use of it is visible in this part of the file.
	logger Logger
}
// NewOptionalTransportCredentials wraps creds in transport credentials whose
// ServerHandshake chooses, per connection, whether to delegate to creds.
// The supplied logger is stored on the returned value.
func NewOptionalTransportCredentials(creds credentials.TransportCredentials, logger Logger) credentials.TransportCredentials {
	wrapped := &optionalTransportCredentials{
		TransportCredentials: creds,
		logger:               logger,
	}
	return wrapped
}
// ServerHandshake selects the handshake behaviour from the protocol label
// attached to the connection.
//
//   - ProtocolPlaintext: no handshake is performed. The labelled connection is
//     returned with a nil AuthInfo, since there is no TLS info to attach.
//   - ProtocolTLS: the wrapped TransportCredentials perform the handshake and
//     the resulting connection and AuthInfo are returned. A handshake that
//     reports success but yields a nil AuthInfo is turned into an error.
//
// A connection that is not a LabelledConn, or whose protocol is neither of the
// values above, is rejected with an error instead of panicking on the type
// assertion, which still prevents the gRPC connection from progressing.
func (tc *optionalTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	// Connections are expected to come from a LabelledListener, but verify
	// that rather than letting an unchecked assertion panic.
	nc, ok := conn.(LabelledConn)
	if !ok {
		return nil, nil, fmt.Errorf("unexpected connection type %T for grpc connection", conn)
	}

	switch nc.protocol {
	case ProtocolPlaintext:
		// This originated from a plaintext listener, so do not use TLS auth.
		return nc, nil, nil
	case ProtocolTLS:
		// This originated from a TLS listener, so it should have a full handshake performed.
		c, ai, err := tc.TransportCredentials.ServerHandshake(conn)
		if err == nil && ai == nil {
			// This should not be possible, but ensure that it's non-nil for safety.
			return nil, nil, fmt.Errorf("missing auth info after handshake")
		}
		return c, ai, err
	default:
		return nil, nil, fmt.Errorf("invalid protocol for grpc connection")
	}
}