// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package middleware
import (
"net"
"reflect"
"strconv"
"strings"
"time"
"github.com/armon/go-metrics"
"github.com/armon/go-metrics/prometheus"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul-net-rpc/net/rpc"
rpcRate "github.com/hashicorp/consul/agent/consul/rate"
)
const (
	// RPCTypeInternal identifies an "RPC" request that actually originates from
	// an internal operation running on the cluster leader. Strictly speaking
	// these are not RPC requests, but their raft.Apply operations affect
	// blocking queries and streaming subscriptions just as RPCs do, so they are
	// tracked by the same metric and logs.
	//
	// What is really being measured is a "cluster operation"; "RPC" is the term
	// used for it historically, so that name is kept here.
	RPCTypeInternal = "internal"

	// RPCTypeNetRPC is the rpc_type label value that GetNetRPCInterceptor
	// reports for calls handled through the net/rpc server interceptor.
	RPCTypeNetRPC = "net/rpc"
)
var (
	// metricRPCRequest is the metric key under which Record samples the
	// duration of every request.
	metricRPCRequest = []string{"rpc", "server", "call"}

	// requestLogName is metricRPCRequest joined with underscores
	// ("rpc_server_call"); Record uses it as the trace log message.
	requestLogName = strings.Join(metricRPCRequest, "_")
)
// OneTwelveRPCSummary holds the Prometheus summary definition (name and help
// text) for the metricRPCRequest metric that Record samples.
var OneTwelveRPCSummary = []prometheus.SummaryDefinition{{
	Name: metricRPCRequest,
	Help: "Measures the time an RPC service call takes to make in milliseconds. Labels mark which RPC method was called and metadata about the call.",
}}
// RequestRecorder reports a latency sample and a trace log line for every
// request passed to Record. NewRequestRecorder fills in all fields, including
// the default RecorderFunc.
type RequestRecorder struct {
	// Logger receives one Trace line per recorded request.
	Logger hclog.Logger
	// RecorderFunc receives the metric key, the elapsed time in milliseconds,
	// and the request's labels. NewRequestRecorder sets it to
	// metrics.AddSampleWithLabels.
	RecorderFunc func(key []string, val float32, labels []metrics.Label)
	// serverIsLeader reports whether this server is currently the leader. When
	// nil, the "leader" label is reported as "unreported".
	serverIsLeader func() bool
	// localDC is compared against a request's target datacenter to set the
	// "locality" label to "local" or "forwarded".
	localDC string
}
// NewRequestRecorder builds a RequestRecorder that logs through logger and
// samples metrics via metrics.AddSampleWithLabels.
//
// isLeader supplies the "leader" label and may be nil, in which case the label
// is "unreported". localDC is the datacenter that request target datacenters
// are compared against to set the "locality" label.
func NewRequestRecorder(logger hclog.Logger, isLeader func() bool, localDC string) *RequestRecorder {
	rec := new(RequestRecorder)
	rec.Logger = logger
	rec.RecorderFunc = metrics.AddSampleWithLabels
	rec.serverIsLeader = isLeader
	rec.localDC = localDC
	return rec
}
// Record reports one completed request. It passes the time elapsed since start,
// in fractional milliseconds, to RecorderFunc under metricRPCRequest, and then
// emits the same labels as a Trace log line named requestLogName.
//
// requestName becomes the "method" label, rpcType the "rpc_type" label, and
// respErrored the "errored" label. The request value is inspected for optional
// interfaces (IsRead, readQuery, targetDC) to derive the remaining labels.
func (r *RequestRecorder) Record(requestName string, rpcType string, start time.Time, request interface{}, respErrored bool) {
	// The duration is taken at microsecond resolution and scaled to milliseconds.
	// Converting the int64 count to float32 cannot overflow because
	// math.MaxInt64 < math.MaxFloat32, although very large values lose precision.
	elapsedMs := float32(time.Since(start).Microseconds()) / 1000

	// Composite-literal elements are evaluated in order, so requestType still
	// runs before getServerLeadership.
	labels := r.addOptionalLabels(request, []metrics.Label{
		{Name: "method", Value: requestName},
		{Name: "errored", Value: strconv.FormatBool(respErrored)},
		{Name: "request_type", Value: requestType(request)},
		{Name: "rpc_type", Value: rpcType},
		{Name: "leader", Value: r.getServerLeadership()},
	})

	r.RecorderFunc(metricRPCRequest, elapsedMs, labels)
	r.Logger.Trace(requestLogName, flattenLabels(labels)...)
}
// flattenLabels turns labels into an alternating name/value list
// (name1, value1, name2, value2, ...). Record passes this list as the variadic
// arguments of its Trace log call.
func flattenLabels(labels []metrics.Label) []interface{} {
	// Each label contributes exactly two entries. Sizing the slice up front
	// avoids repeated growth and copying on every recorded request.
	labelArr := make([]interface{}, 0, 2*len(labels))
	for _, label := range labels {
		labelArr = append(labelArr, label.Name, label.Value)
	}
	return labelArr
}
// addOptionalLabels appends labels derived from whichever optional interfaces
// request implements and returns the extended slice:
//   - readQuery: "allow_stale" and "blocking" (blocking means a non-zero
//     minimum query index).
//   - targetDC: "target_datacenter", plus "locality", which is "local" when the
//     target matches r.localDC and "forwarded" otherwise.
func (r *RequestRecorder) addOptionalLabels(request interface{}, labels []metrics.Label) []metrics.Label {
	if query, isQuery := request.(readQuery); isQuery {
		labels = append(labels,
			metrics.Label{Name: "allow_stale", Value: strconv.FormatBool(query.AllowStaleRead())},
			metrics.Label{Name: "blocking", Value: strconv.FormatBool(query.GetMinQueryIndex() > 0)},
		)
	}

	target, hasTarget := request.(targetDC)
	if !hasTarget {
		return labels
	}

	dc := target.RequestDatacenter()
	locality := "forwarded"
	if dc == r.localDC {
		locality = "local"
	}
	return append(labels,
		metrics.Label{Name: "target_datacenter", Value: dc},
		metrics.Label{Name: "locality", Value: locality},
	)
}
// requestType classifies req as "read" or "write" using its IsRead method.
// Values without an IsRead() bool method are reported as "unreported".
func requestType(req interface{}) string {
	classifier, ok := req.(interface{ IsRead() bool })
	if !ok {
		// This should not happen. If it does, the underlying request does not
		// implement the interface. Report it explicitly rather than hiding it
		// under "read" or "write".
		return "unreported"
	}
	if classifier.IsRead() {
		return "read"
	}
	return "write"
}
// getServerLeadership returns the value of the "leader" label: "true" or
// "false" according to r.serverIsLeader, or "unreported" when no leadership
// callback was provided.
func (r *RequestRecorder) getServerLeadership() string {
	switch {
	case r.serverIsLeader == nil:
		// This should not happen. It means no way of checking whether the
		// handling server was the leader was plumbed down to the recorder.
		return "unreported"
	case r.serverIsLeader():
		return "true"
	default:
		return "false"
	}
}
type (
	// readQuery is implemented by read requests. addOptionalLabels uses it to
	// derive the "allow_stale" and "blocking" labels.
	readQuery interface {
		AllowStaleRead() bool
		GetMinQueryIndex() uint64
	}

	// targetDC is implemented by requests that name a target datacenter.
	// addOptionalLabels uses it for the "target_datacenter" and "locality"
	// labels.
	targetDC interface {
		RequestDatacenter() string
	}
)
// GetNetRPCInterceptor returns a net/rpc service-call interceptor that times
// each call to handler and reports it to recorder with rpc_type RPCTypeNetRPC.
// The call is labeled as errored when handler returns a non-nil error. Only the
// request argument (argv) is passed on; the reply value is not inspected.
func GetNetRPCInterceptor(recorder *RequestRecorder) rpc.ServerServiceCallInterceptor {
	return func(reqServiceMethod string, argv, replyv reflect.Value, handler func() error) {
		started := time.Now()
		failed := handler() != nil
		recorder.Record(reqServiceMethod, RPCTypeNetRPC, started, argv.Interface(), failed)
	}
}
// GetNetRPCRateLimitingInterceptor returns a pre-body interceptor that asks
// requestLimitsHandler whether the call to reqServiceMethod from sourceAddr may
// proceed. The operation's Type and Category come from the rpcRateLimitSpecs
// entry for the method.
//
// A panic raised while building or checking the operation is recovered, and
// panicHandler turns it into the interceptor's returned error.
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler, panicHandler RecoveryHandlerFunc) rpc.PreBodyInterceptor {
	return func(reqServiceMethod string, sourceAddr net.Addr) (retErr error) {
		defer func() {
			if r := recover(); r != nil {
				retErr = panicHandler(r)
			}
		}()

		// Look up the method's spec once and reuse it for both fields. This
		// runs for every incoming RPC.
		spec := rpcRateLimitSpecs[reqServiceMethod]
		op := rpcRate.Operation{
			Name:       reqServiceMethod,
			SourceAddr: sourceAddr,
			Type:       spec.Type,
			Category:   spec.Category,
		}

		// net/rpc cannot encode the nuances of the error response (retry, or
		// retry elsewhere), so the rate limiter's error string is all that
		// reaches the client.
		return requestLimitsHandler.Allow(op)
	}
}
// ChainedRPCPreBodyInterceptor combines interceptors into one that runs them in
// order and returns the first non-nil error, skipping the rest. A single
// interceptor is returned unchanged. Calling it with no interceptors panics.
func ChainedRPCPreBodyInterceptor(chain ...rpc.PreBodyInterceptor) rpc.PreBodyInterceptor {
	switch len(chain) {
	case 0:
		panic("don't call this with zero interceptors")
	case 1:
		return chain[0]
	}

	return func(reqServiceMethod string, sourceAddr net.Addr) error {
		for i := range chain {
			if err := chain[i](reqServiceMethod, sourceAddr); err != nil {
				return err
			}
		}
		return nil
	}
}