mirror of https://github.com/hashicorp/consul
170 lines
4.5 KiB
Go
170 lines
4.5 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package bimapper
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/hashicorp/consul/internal/controller"
|
|
rtest "github.com/hashicorp/consul/internal/resource/resourcetest"
|
|
"github.com/hashicorp/consul/proto-public/pbresource"
|
|
"github.com/hashicorp/consul/proto/private/prototest"
|
|
)
|
|
|
|
// API group and version shared by the fake Foo/Bar resource types used in
// these tests.
const (
	fakeVersion   = "v1"
	fakeGroupName = "catalog"
)
|
|
|
|
var (
|
|
fakeFooType = &pbresource.Type{
|
|
Group: fakeGroupName,
|
|
GroupVersion: fakeVersion,
|
|
Kind: "Foo",
|
|
}
|
|
fakeBarType = &pbresource.Type{
|
|
Group: fakeGroupName,
|
|
GroupVersion: fakeVersion,
|
|
Kind: "Bar",
|
|
}
|
|
)
|
|
|
|
func TestMapper(t *testing.T) {
|
|
// Create an advance pointer to some services.
|
|
|
|
randoSvc := rtest.Resource(fakeBarType, "rando").Build()
|
|
apiSvc := rtest.Resource(fakeBarType, "api").Build()
|
|
fooSvc := rtest.Resource(fakeBarType, "foo").Build()
|
|
barSvc := rtest.Resource(fakeBarType, "bar").Build()
|
|
wwwSvc := rtest.Resource(fakeBarType, "www").Build()
|
|
|
|
fail1 := rtest.Resource(fakeFooType, "api").Build()
|
|
fail1_refs := []*pbresource.Reference{
|
|
newRef(fakeBarType, "api"),
|
|
newRef(fakeBarType, "foo"),
|
|
newRef(fakeBarType, "bar"),
|
|
}
|
|
|
|
fail2 := rtest.Resource(fakeFooType, "www").Build()
|
|
fail2_refs := []*pbresource.Reference{
|
|
newRef(fakeBarType, "www"),
|
|
newRef(fakeBarType, "foo"),
|
|
}
|
|
|
|
fail1_updated := rtest.Resource(fakeFooType, "api").Build()
|
|
fail1_updated_refs := []*pbresource.Reference{
|
|
newRef(fakeBarType, "api"),
|
|
newRef(fakeBarType, "bar"),
|
|
}
|
|
|
|
m := New(fakeFooType, fakeBarType)
|
|
|
|
// Nothing tracked yet so we assume nothing.
|
|
requireServicesTracked(t, m, randoSvc)
|
|
requireServicesTracked(t, m, apiSvc)
|
|
requireServicesTracked(t, m, fooSvc)
|
|
requireServicesTracked(t, m, barSvc)
|
|
requireServicesTracked(t, m, wwwSvc)
|
|
|
|
// no-ops
|
|
m.UntrackItem(fail1.Id)
|
|
|
|
// still nothing
|
|
requireServicesTracked(t, m, randoSvc)
|
|
requireServicesTracked(t, m, apiSvc)
|
|
requireServicesTracked(t, m, fooSvc)
|
|
requireServicesTracked(t, m, barSvc)
|
|
requireServicesTracked(t, m, wwwSvc)
|
|
|
|
// Actually insert some data.
|
|
m.TrackItem(fail1.Id, fail1_refs)
|
|
|
|
requireServicesTracked(t, m, randoSvc)
|
|
requireServicesTracked(t, m, apiSvc, fail1.Id)
|
|
requireServicesTracked(t, m, fooSvc, fail1.Id)
|
|
requireServicesTracked(t, m, barSvc, fail1.Id)
|
|
requireServicesTracked(t, m, wwwSvc)
|
|
|
|
// track it again, no change
|
|
m.TrackItem(fail1.Id, fail1_refs)
|
|
|
|
requireServicesTracked(t, m, randoSvc)
|
|
requireServicesTracked(t, m, apiSvc, fail1.Id)
|
|
requireServicesTracked(t, m, fooSvc, fail1.Id)
|
|
requireServicesTracked(t, m, barSvc, fail1.Id)
|
|
requireServicesTracked(t, m, wwwSvc)
|
|
|
|
// track new one that overlaps slightly
|
|
m.TrackItem(fail2.Id, fail2_refs)
|
|
|
|
requireServicesTracked(t, m, randoSvc)
|
|
requireServicesTracked(t, m, apiSvc, fail1.Id)
|
|
requireServicesTracked(t, m, fooSvc, fail1.Id, fail2.Id)
|
|
requireServicesTracked(t, m, barSvc, fail1.Id)
|
|
requireServicesTracked(t, m, wwwSvc, fail2.Id)
|
|
|
|
// update the original to change it
|
|
m.TrackItem(fail1_updated.Id, fail1_updated_refs)
|
|
|
|
requireServicesTracked(t, m, randoSvc)
|
|
requireServicesTracked(t, m, apiSvc, fail1.Id)
|
|
requireServicesTracked(t, m, fooSvc, fail2.Id)
|
|
requireServicesTracked(t, m, barSvc, fail1.Id)
|
|
requireServicesTracked(t, m, wwwSvc, fail2.Id)
|
|
|
|
// delete the original
|
|
m.UntrackItem(fail1.Id)
|
|
|
|
requireServicesTracked(t, m, randoSvc)
|
|
requireServicesTracked(t, m, apiSvc)
|
|
requireServicesTracked(t, m, fooSvc, fail2.Id)
|
|
requireServicesTracked(t, m, barSvc)
|
|
requireServicesTracked(t, m, wwwSvc, fail2.Id)
|
|
|
|
// delete the other one
|
|
m.UntrackItem(fail2.Id)
|
|
|
|
requireServicesTracked(t, m, randoSvc)
|
|
requireServicesTracked(t, m, apiSvc)
|
|
requireServicesTracked(t, m, fooSvc)
|
|
requireServicesTracked(t, m, barSvc)
|
|
requireServicesTracked(t, m, wwwSvc)
|
|
}
|
|
|
|
func requireServicesTracked(t *testing.T, mapper *Mapper, link *pbresource.Resource, items ...*pbresource.ID) {
|
|
t.Helper()
|
|
|
|
reqs, err := mapper.MapLink(
|
|
context.Background(),
|
|
controller.Runtime{},
|
|
link,
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, reqs, len(items))
|
|
|
|
for _, item := range items {
|
|
prototest.AssertContainsElement(t, reqs, controller.Request{ID: item})
|
|
}
|
|
}
|
|
|
|
func newRef(typ *pbresource.Type, name string) *pbresource.Reference {
|
|
return rtest.Resource(typ, name).Reference("")
|
|
}
|
|
|
|
func newID(typ *pbresource.Type, name string) *pbresource.ID {
|
|
return rtest.Resource(typ, name).ID()
|
|
}
|
|
|
|
func defaultTenancy() *pbresource.Tenancy {
|
|
return &pbresource.Tenancy{
|
|
Partition: "default",
|
|
Namespace: "default",
|
|
PeerName: "local",
|
|
}
|
|
}
|