// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package agent
import (
"context"
"strings"
"testing"
"github.com/hashicorp/go-uuid"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/testutil/retry"
)
// TestValidateUserEventParams walks a single UserEvent through a series of
// invalid configurations and checks that validateUserEventParams rejects each
// one with the expected error text.
func TestValidateUserEventParams(t *testing.T) {
	t.Parallel()

	params := &UserEvent{}

	// expectRejection validates params in its current state and aborts the
	// test unless an error comes back whose message satisfies matches.
	expectRejection := func(matches func(msg string) bool) {
		t.Helper()
		err := validateUserEventParams(params)
		if err != nil && matches(err.Error()) {
			return
		}
		t.Fatalf("err: %v", err)
	}
	is := func(want string) func(string) bool {
		return func(msg string) bool { return msg == want }
	}
	mentions := func(fragment string) func(string) bool {
		return func(msg string) bool { return strings.Contains(msg, fragment) }
	}

	// No name set at all.
	expectRejection(is("User event missing name"))

	// "(" is expected to be refused as a node filter.
	params.Name = "foo"
	params.NodeFilter = "("
	expectRejection(mentions("Invalid node filter"))

	// Same malformed value, this time as the service filter.
	params.NodeFilter = ""
	params.ServiceFilter = "("
	expectRejection(mentions("Invalid service filter"))

	// Acceptable-looking service filter combined with a malformed tag filter.
	params.ServiceFilter = "foo"
	params.TagFilter = "("
	expectRejection(mentions("Invalid tag filter"))

	// A tag filter on its own, with no service filter to go with it.
	params.ServiceFilter = ""
	params.TagFilter = "foo"
	expectRejection(mentions("tag filter without service"))
}
// TestShouldProcessUserEvent registers a "mysql" service on a test agent and
// then checks which node, service and tag filters make the agent accept or
// ignore a user event.
func TestShouldProcessUserEvent(t *testing.T) {
	if testing.Short() {
		t.Skip("too slow for testing.Short")
	}

	t.Parallel()
	a := NewTestAgent(t, "")
	defer a.Shutdown()

	mysql := &structs.NodeService{
		ID:      "mysql",
		Service: "mysql",
		Tags:    []string{"test", "foo", "bar", "primary"},
		Port:    5000,
	}
	a.State.AddServiceWithChecks(mysql, nil, "", false)

	// verify fails the test when the agent's decision for ev differs from
	// want; t.Helper attributes the failure to the calling line.
	verify := func(ev *UserEvent, want bool) {
		t.Helper()
		if a.shouldProcessUserEvent(ev) != want {
			t.Fatalf("bad")
		}
	}

	// No filters at all
	verify(&UserEvent{}, true)

	// Node filter that is expected not to match
	verify(&UserEvent{NodeFilter: "foobar"}, false)

	// Node filter that is expected to match the test agent's node name
	verify(&UserEvent{NodeFilter: "^Node"}, true)

	// Service filter that does not match the registered service
	verify(&UserEvent{ServiceFilter: "foobar"}, false)

	// Service filter that matches "mysql"
	verify(&UserEvent{ServiceFilter: ".*sql"}, true)

	// Matching service, but a tag the service does not carry
	verify(&UserEvent{ServiceFilter: ".*sql", TagFilter: "replica"}, false)

	// Matching service and a tag the service does carry
	verify(&UserEvent{ServiceFilter: ".*sql", TagFilter: "primary"}, true)
}
// TestIngestUserEvent feeds 512 events with increasing LTime into a test
// agent and, after every ingest, checks that the newest event is reported as
// the last one, that at most 256 events are retained, and that the retained
// events are the most recent ones in ascending LTime order.
func TestIngestUserEvent(t *testing.T) {
	if testing.Short() {
		t.Skip("too slow for testing.Short")
	}

	t.Parallel()
	a := NewTestAgent(t, "")
	defer a.Shutdown()

	const (
		total    = 512 // number of events ingested
		retained = 256 // largest number of events the test expects back
	)

	for i := 0; i < total; i++ {
		ev := &UserEvent{LTime: uint64(i), Name: "test"}
		a.ingestUserEvent(ev)
		if a.LastUserEvent() != ev {
			t.Fatalf("bad: %#v", ev)
		}

		events := a.UserEvents()
		want := i + 1
		if want > retained {
			want = retained
		}
		if len(events) != want {
			t.Fatalf("bad: %d %d %d", i, want, len(events))
		}

		// The slice must end at LTime i, so its first entry carries
		// LTime i-(len-1) and each following entry is one higher.
		oldest := i - (len(events) - 1)
		for off := range events {
			if events[off].LTime != uint64(oldest+off) {
				t.Fatalf("bad: %#v", events)
			}
		}
	}
}
// TestFireReceiveEvent fires two "deploy" events from a test agent that only
// has a "mysql" service registered: one restricted to services matching "web"
// and one without any filter. It then waits until the agent reports exactly
// one received event and checks that it is the unfiltered one.
func TestFireReceiveEvent(t *testing.T) {
	if testing.Short() {
		t.Skip("too slow for testing.Short")
	}

	t.Parallel()
	a := NewTestAgent(t, "")
	defer a.Shutdown()

	mysql := &structs.NodeService{
		ID:      "mysql",
		Service: "mysql",
		Tags:    []string{"test", "foo", "bar", "primary"},
		Port:    5000,
	}
	a.State.AddServiceWithChecks(mysql, nil, "", false)

	// Restricted to "web"; the assertions below expect this one not to be
	// recorded by the agent.
	webOnly := &UserEvent{Name: "deploy", ServiceFilter: "web"}
	if err := a.UserEvent("dc1", "root", webOnly); err != nil {
		t.Fatalf("err: %v", err)
	}

	unfiltered := &UserEvent{Name: "deploy"}
	if err := a.UserEvent("dc1", "root", unfiltered); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Poll until the agent has recorded exactly one event.
	retry.Run(t, func(r *retry.R) {
		const want = 1
		if got := len(a.UserEvents()); got != want {
			r.Fatalf("got %d events want %d", got, want)
		}
	})

	if last := a.LastUserEvent(); last.ID != unfiltered.ID {
		t.Fatalf("bad: %#v", last)
	}
}
// TestUserEventToken starts an agent with a default-deny ACL policy, creates
// a token bound to testEventPolicy, and checks for several event names
// whether firing the event with that token is refused with a
// permission-denied error.
func TestUserEventToken(t *testing.T) {
	if testing.Short() {
		t.Skip("too slow for testing.Short")
	}

	t.Parallel()
	a := NewTestAgent(t, TestACLConfig()+`
acl_default_policy = "deny"
`)
	defer a.Shutdown()

	token := createToken(t, a, testEventPolicy)

	type tcase struct {
		name   string
		expect bool
	}
	// expect is true when firing the event must NOT be rejected as
	// permission denied. Only "baz" has a write rule in testEventPolicy;
	// "zip" has no rule and falls under the default policy.
	table := []tcase{
		{name: "foo", expect: false},
		{name: "bar", expect: false},
		{name: "baz", expect: true},
		{name: "zip", expect: false},
	}

	for _, tc := range table {
		err := a.UserEvent("dc1", token, &UserEvent{Name: tc.name})
		// A mismatch exists exactly when "denied" equals "expected allowed".
		if denied := acl.IsErrPermissionDenied(err); denied == tc.expect {
			t.Fatalf("bad: %#v result: %v", tc, !denied)
		}
	}
}
// RPC is the minimal interface createToken needs in order to issue RPC
// requests. Any value with a matching RPC method satisfies it; the tests in
// this file pass the test agent.
type RPC interface {
// RPC invokes the named method with args and stores the result in reply.
RPC ( ctx context . Context , method string , args interface { } , reply interface { } ) error
}
// createToken registers an ACL policy named "the-policy" with the given rules
// in datacenter "dc1", then creates a token with a freshly generated secret
// that is linked to that policy. Both writes are made with the "root" token.
// Any failure aborts the test. The new token's secret ID is returned.
func createToken(t *testing.T, rpc RPC, policyRules string) string {
	t.Helper()

	const (
		datacenter = "dc1"
		policyName = "the-policy"
		rootToken  = "root"
	)
	ctx := context.Background()

	policyReq := structs.ACLPolicySetRequest{
		Datacenter:   datacenter,
		Policy:       structs.ACLPolicy{Name: policyName, Rules: policyRules},
		WriteRequest: structs.WriteRequest{Token: rootToken},
	}
	var policyReply structs.ACLPolicy
	require.NoError(t, rpc.RPC(ctx, "ACL.PolicySet", &policyReq, &policyReply))

	secret, err := uuid.GenerateUUID()
	require.NoError(t, err)

	tokenReq := structs.ACLTokenSetRequest{
		Datacenter: datacenter,
		ACLToken: structs.ACLToken{
			SecretID: secret,
			Policies: []structs.ACLTokenPolicyLink{{Name: policyName}},
		},
		WriteRequest: structs.WriteRequest{Token: rootToken},
	}
	var tokenReply structs.ACLToken
	require.NoError(t, rpc.RPC(ctx, "ACL.TokenSet", &tokenReq, &tokenReply))

	return secret
}
// testEventPolicy is the ACL rule set handed to createToken by
// TestUserEventToken. It denies the "foo" event, grants read on "bar" and
// grants write on "baz"; event names not listed here have no rule.
const testEventPolicy = `
event "foo" {
policy = "deny"
}
event "bar" {
policy = "read"
}
event "baz" {
policy = "write"
}
`