// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package freeport
import (
"fmt"
"io"
"net"
"testing"
"github.com/hashicorp/consul/sdk/testutil/retry"
)
// TestTakeReturn exercises Take and Return end to end against the package's
// global port block. It runs these steps in order:
//   - a simple take/return;
//   - taking every port at once;
//   - the error paths for oversized, negative and zero-sized requests;
//   - a single port stolen by an outside listener;
//   - a Take that needs ports another caller still holds;
//   - every free port being stolen at once.
//
// The steps share global state and must run exactly in the order written.
// Between steps the test waits (waitForStatsReset) for the block's counters to
// settle before relying on the total port count.
func TestTakeReturn(t *testing.T) {
	// NOTE: for global var reasons this cannot execute in parallel
	// t.Parallel()

	// This test is destructive: it leaks every port, so other test cases in
	// this package would not work after it runs. To help them out, the global
	// state is reset when this test finishes.
	defer reset()

	// OK: do a simple take/return cycle to trigger the package initialization
	func() {
		ports, err := Take(1)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer Return(ports)

		if len(ports) != 1 {
			t.Fatalf("expected %d but got %d ports", 1, len(ports))
		}
	}()

	// waitForStatsReset first asserts that every port is currently counted as
	// either free or pending. It then polls stats() via retry.Run until no
	// ports are pending and the free count equals the total. It returns the
	// total reported by the last poll.
	waitForStatsReset := func() (numTotal int) {
		t.Helper()
		numTotal, numPending, numFree := stats()
		if numTotal != numFree+numPending {
			t.Fatalf("expected total (%d) and free+pending (%d) ports to match", numTotal, numFree+numPending)
		}

		retry.Run(t, func(r *retry.R) {
			// Assigns the outer numTotal/numPending/numFree, so the value
			// returned below comes from the final poll.
			numTotal, numPending, numFree = stats()
			if numPending != 0 {
				r.Fatalf("pending is still non zero: %d", numPending)
			}
			if numTotal != numFree {
				r.Fatalf("total (%d) does not equal free (%d)", numTotal, numFree)
			}
		})
		return numTotal
	}

	// Reset
	numTotal := waitForStatsReset()

	// --------------------
	// OK: take the max
	func() {
		ports, err := Take(numTotal)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer Return(ports)

		if len(ports) != numTotal {
			t.Fatalf("expected %d but got %d ports", numTotal, len(ports))
		}
	}()

	// Reset
	numTotal = waitForStatsReset()

	// expectError fails the test unless got is non-nil and its message equals
	// expected exactly.
	expectError := func(expected string, got error) {
		t.Helper()
		if got == nil {
			t.Fatalf("expected error but was nil")
		}
		if got.Error() != expected {
			t.Fatalf("expected error %q but got %q", expected, got.Error())
		}
	}

	// --------------------
	// ERROR: take too many ports (one more than the block holds)
	func() {
		ports, err := Take(numTotal + 1)
		// NOTE(review): Return is deferred even though Take is expected to fail
		// here; presumably Return tolerates whatever Take hands back on error
		// (likely a nil/empty slice) — confirm.
		defer Return(ports)
		expectError("freeport: block size too small", err)
	}()

	// --------------------
	// ERROR: invalid ports request (negative)
	func() {
		_, err := Take(-1)
		expectError("freeport: cannot take -1 ports", err)
	}()

	// --------------------
	// ERROR: invalid ports request (zero)
	func() {
		_, err := Take(0)
		expectError("freeport: cannot take 0 ports", err)
	}()

	// --------------------
	// OK: steal a port under the covers and let freeport detect the theft and
	// compensate. An outside listener binds one port that freeport still
	// considers free. Take must then skip that port, and once stats settle the
	// total must have shrunk by exactly one.
	leakedPort := peekFree()
	func() {
		leakyListener, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", leakedPort))
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		// The listener is held open until this closure returns, i.e. after the
		// stats check below.
		defer leakyListener.Close()

		func() {
			ports, err := Take(3)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			defer Return(ports)

			if len(ports) != 3 {
				t.Fatalf("expected %d but got %d ports", 3, len(ports))
			}

			for _, port := range ports {
				if port == leakedPort {
					t.Fatalf("did not expect for Take to return the leaked port")
				}
			}
		}()

		newNumTotal := waitForStatsReset()
		if newNumTotal != numTotal-1 {
			t.Fatalf("expected total to drop to %d but got %d", numTotal-1, newNumTotal)
		}
		numTotal = newNumTotal // update outer variable for later tests
	}()

	// --------------------
	// OK: sequence it so that one Take must wait on another Take to Return.
	// Taking all but 5 ports leaves too few for a request of 10. The 10 only
	// become available once mostPorts is returned.
	func() {
		mostPorts, err := Take(numTotal - 5)
		if err != nil {
			t.Fatalf("err: %v", err)
		}

		type reply struct {
			ports []int
			err   error
		}

		// Buffered so the goroutine's send never blocks, even if the test has
		// already failed and stopped receiving.
		ch := make(chan reply, 1)
		go func() {
			ports, err := Take(10)
			ch <- reply{ports: ports, err: err}
		}()

		// NOTE(review): nothing forces the goroutine's Take to start before this
		// Return, so the "must wait" path is exercised only when the scheduler
		// happens to run the goroutine first.
		Return(mostPorts)

		// NOTE(review): no timeout here; if Take(10) never completes, the test
		// hangs until the overall `go test` timeout.
		r := <-ch
		if r.err != nil {
			t.Fatalf("err: %v", r.err)
		}
		defer Return(r.ports)

		if len(r.ports) != 10 {
			t.Fatalf("expected %d ports but got %d", 10, len(r.ports))
		}
	}()

	// Reset
	// NOTE(review): numTotal is not read after this point; the call still
	// matters because it waits for the block to settle before the last step.
	numTotal = waitForStatsReset()

	// --------------------
	// ERROR: Now we end on the crazy "Ocean's 11" level port theft. Every port
	// is stolen out from under freeport, and freeport doesn't find out until
	// the next Take.
	func() {
		// 1. Grab all of the ports (as reported by peekAllFree; presumably
		//    this does not mark them as taken — confirm).
		allPorts := peekAllFree()

		// 2. Leak all of the ports by binding an outside listener on each one.
		//    All listeners are closed when this closure returns.
		leaked := make([]io.Closer, 0, len(allPorts))
		defer func() {
			for _, c := range leaked {
				c.Close()
			}
		}()

		for i, port := range allPorts {
			ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port))
			if err != nil {
				t.Fatalf("%d err: %v", i, err)
			}
			leaked = append(leaked, ln)
		}

		// 3. Request 1 port which will detect the leaked ports and fail.
		_, err := Take(1)
		expectError("freeport: impossible to satisfy request; there are no actual free ports in the block anymore", err)

		// 4. Wait for the block to zero out.
		newNumTotal := waitForStatsReset()
		if newNumTotal != 0 {
			t.Fatalf("expected total to drop to %d but got %d", 0, newNumTotal)
		}
	}()
}
// TestIntervalOverlap checks intervalOverlap against a table of interval
// pairs. Each interval is given as its two endpoints (lo, hi). Overlap is
// symmetric, so every pair is checked twice: first as (a, b), then with the
// arguments swapped as (b, a).
func TestIntervalOverlap(t *testing.T) {
	type interval struct{ lo, hi int }
	type overlapCase struct {
		a, b interval
		want bool
	}

	table := []overlapCase{
		{interval{0, 0}, interval{0, 0}, true},
		{interval{1, 1}, interval{1, 1}, true},
		{interval{1, 3}, interval{1, 3}, true},  // identical intervals
		{interval{1, 3}, interval{4, 6}, false}, // one after the other, no shared point
		{interval{1, 4}, interval{3, 6}, true},  // partial overlap
		{interval{1, 6}, interval{3, 4}, true},  // second nested inside the first
	}

	for _, c := range table {
		name := fmt.Sprintf("%d:%d vs %d:%d", c.a.lo, c.a.hi, c.b.lo, c.b.hi)
		t.Run(name, func(t *testing.T) {
			orders := [][2]interval{{c.a, c.b}, {c.b, c.a}}
			for _, pair := range orders {
				first, second := pair[0], pair[1]
				got := intervalOverlap(first.lo, first.hi, second.lo, second.hi)
				if got != c.want {
					t.Fatalf("expected %v but got %v", c.want, got)
				}
			}
		})
	}
}