mirror of https://github.com/hashicorp/consul
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
234 lines
5.5 KiB
234 lines
5.5 KiB
// Copyright (c) HashiCorp, Inc. |
|
// SPDX-License-Identifier: MPL-2.0 |
|
|
|
package freeport |
|
|
|
import ( |
|
"fmt" |
|
"io" |
|
"net" |
|
"testing" |
|
|
|
"github.com/hashicorp/consul/sdk/testutil/retry" |
|
) |
|
|
|
func TestTakeReturn(t *testing.T) { |
|
// NOTE: for global var reasons this cannot execute in parallel |
|
// t.Parallel() |
|
|
|
// Since this test is destructive (i.e. it leaks all ports) it means that |
|
// any other test cases in this package will not function after it runs. To |
|
// help out we reset the global state after we run this test. |
|
defer reset() |
|
|
|
// OK: do a simple take/return cycle to trigger the package initialization |
|
func() { |
|
ports, err := Take(1) |
|
if err != nil { |
|
t.Fatalf("err: %v", err) |
|
} |
|
defer Return(ports) |
|
|
|
if len(ports) != 1 { |
|
t.Fatalf("expected %d but got %d ports", 1, len(ports)) |
|
} |
|
}() |
|
|
|
waitForStatsReset := func() (numTotal int) { |
|
t.Helper() |
|
numTotal, numPending, numFree := stats() |
|
if numTotal != numFree+numPending { |
|
t.Fatalf("expected total (%d) and free+pending (%d) ports to match", numTotal, numFree+numPending) |
|
} |
|
retry.Run(t, func(r *retry.R) { |
|
numTotal, numPending, numFree = stats() |
|
if numPending != 0 { |
|
r.Fatalf("pending is still non zero: %d", numPending) |
|
} |
|
if numTotal != numFree { |
|
r.Fatalf("total (%d) does not equal free (%d)", numTotal, numFree) |
|
} |
|
}) |
|
return numTotal |
|
} |
|
|
|
// Reset |
|
numTotal := waitForStatsReset() |
|
|
|
// -------------------- |
|
// OK: take the max |
|
func() { |
|
ports, err := Take(numTotal) |
|
if err != nil { |
|
t.Fatalf("err: %v", err) |
|
} |
|
defer Return(ports) |
|
|
|
if len(ports) != numTotal { |
|
t.Fatalf("expected %d but got %d ports", numTotal, len(ports)) |
|
} |
|
}() |
|
|
|
// Reset |
|
numTotal = waitForStatsReset() |
|
|
|
expectError := func(expected string, got error) { |
|
t.Helper() |
|
if got == nil { |
|
t.Fatalf("expected error but was nil") |
|
} |
|
if got.Error() != expected { |
|
t.Fatalf("expected error %q but got %q", expected, got.Error()) |
|
} |
|
} |
|
|
|
// -------------------- |
|
// ERROR: take too many ports |
|
func() { |
|
ports, err := Take(numTotal + 1) |
|
defer Return(ports) |
|
expectError("freeport: block size too small", err) |
|
}() |
|
|
|
// -------------------- |
|
// ERROR: invalid ports request (negative) |
|
func() { |
|
_, err := Take(-1) |
|
expectError("freeport: cannot take -1 ports", err) |
|
}() |
|
|
|
// -------------------- |
|
// ERROR: invalid ports request (zero) |
|
func() { |
|
_, err := Take(0) |
|
expectError("freeport: cannot take 0 ports", err) |
|
}() |
|
|
|
// -------------------- |
|
// OK: Steal a port under the covers and let freeport detect the theft and compensate |
|
leakedPort := peekFree() |
|
func() { |
|
leakyListener, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", leakedPort)) |
|
if err != nil { |
|
t.Fatalf("err: %v", err) |
|
} |
|
defer leakyListener.Close() |
|
|
|
func() { |
|
ports, err := Take(3) |
|
if err != nil { |
|
t.Fatalf("err: %v", err) |
|
} |
|
defer Return(ports) |
|
|
|
if len(ports) != 3 { |
|
t.Fatalf("expected %d but got %d ports", 3, len(ports)) |
|
} |
|
|
|
for _, port := range ports { |
|
if port == leakedPort { |
|
t.Fatalf("did not expect for Take to return the leaked port") |
|
} |
|
} |
|
}() |
|
|
|
newNumTotal := waitForStatsReset() |
|
if newNumTotal != numTotal-1 { |
|
t.Fatalf("expected total to drop to %d but got %d", numTotal-1, newNumTotal) |
|
} |
|
numTotal = newNumTotal // update outer variable for later tests |
|
}() |
|
|
|
// -------------------- |
|
// OK: sequence it so that one Take must wait on another Take to Return. |
|
func() { |
|
mostPorts, err := Take(numTotal - 5) |
|
if err != nil { |
|
t.Fatalf("err: %v", err) |
|
} |
|
|
|
type reply struct { |
|
ports []int |
|
err error |
|
} |
|
ch := make(chan reply, 1) |
|
go func() { |
|
ports, err := Take(10) |
|
ch <- reply{ports: ports, err: err} |
|
}() |
|
|
|
Return(mostPorts) |
|
|
|
r := <-ch |
|
if r.err != nil { |
|
t.Fatalf("err: %v", r.err) |
|
} |
|
defer Return(r.ports) |
|
|
|
if len(r.ports) != 10 { |
|
t.Fatalf("expected %d ports but got %d", 10, len(r.ports)) |
|
} |
|
}() |
|
|
|
// Reset |
|
numTotal = waitForStatsReset() |
|
|
|
// -------------------- |
|
// ERROR: Now we end on the crazy "Ocean's 11" level port theft where we |
|
// orchestrate a situation where all ports are stolen and we don't find out |
|
// until Take. |
|
func() { |
|
// 1. Grab all of the ports. |
|
allPorts := peekAllFree() |
|
|
|
// 2. Leak all of the ports |
|
leaked := make([]io.Closer, 0, len(allPorts)) |
|
defer func() { |
|
for _, c := range leaked { |
|
c.Close() |
|
} |
|
}() |
|
for i, port := range allPorts { |
|
ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port)) |
|
if err != nil { |
|
t.Fatalf("%d err: %v", i, err) |
|
} |
|
leaked = append(leaked, ln) |
|
} |
|
|
|
// 3. Request 1 port which will detect the leaked ports and fail. |
|
_, err := Take(1) |
|
expectError("freeport: impossible to satisfy request; there are no actual free ports in the block anymore", err) |
|
|
|
// 4. Wait for the block to zero out. |
|
newNumTotal := waitForStatsReset() |
|
if newNumTotal != 0 { |
|
t.Fatalf("expected total to drop to %d but got %d", 0, newNumTotal) |
|
} |
|
}() |
|
} |
|
|
|
func TestIntervalOverlap(t *testing.T) { |
|
cases := []struct { |
|
min1, max1, min2, max2 int |
|
overlap bool |
|
}{ |
|
{0, 0, 0, 0, true}, |
|
{1, 1, 1, 1, true}, |
|
{1, 3, 1, 3, true}, // same |
|
{1, 3, 4, 6, false}, // serial |
|
{1, 4, 3, 6, true}, // inner overlap |
|
{1, 6, 3, 4, true}, // nest |
|
} |
|
|
|
for _, tc := range cases { |
|
t.Run(fmt.Sprintf("%d:%d vs %d:%d", tc.min1, tc.max1, tc.min2, tc.max2), func(t *testing.T) { |
|
if tc.overlap != intervalOverlap(tc.min1, tc.max1, tc.min2, tc.max2) { // 1 vs 2 |
|
t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap) |
|
} |
|
if tc.overlap != intervalOverlap(tc.min2, tc.max2, tc.min1, tc.max1) { // 2 vs 1 |
|
t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap) |
|
} |
|
}) |
|
} |
|
}
|
|
|