diff --git a/pkg/util/ipset/ipset.go b/pkg/util/ipset/ipset.go index 608bba920c..e2971e7d24 100644 --- a/pkg/util/ipset/ipset.go +++ b/pkg/util/ipset/ipset.go @@ -88,6 +88,36 @@ type IPSet struct { PortRange string } +// Validate checks if a given ipset is valid or not. +func (set *IPSet) Validate() (bool, error) { + // Check if protocol is valid for `HashIPPort`, `HashIPPortIP` and `HashIPPortNet` type set. + if set.SetType == HashIPPort || set.SetType == HashIPPortIP || set.SetType == HashIPPortNet { + if valid := validateHashFamily(set.HashFamily); !valid { + return false, fmt.Errorf("currently supported ip set hash families are: [%s, %s], %s is not supported", ProtocolFamilyIPV4, ProtocolFamilyIPV6, set.HashFamily) + } + } + // check set type + if valid := validateIPSetType(set.SetType); !valid { + return false, fmt.Errorf("currently supported ipset types are: %v, %s is not supported", ValidIPSetTypes, set.SetType) + } + // check port range for bitmap type set + if set.SetType == BitmapPort { + if valid, err := validatePortRange(set.PortRange); !valid { + return false, err + } + } + // check hash size value of ipset + if set.HashSize <= 0 { + return false, fmt.Errorf("invalid hashsize value, should be >0") + } + // check max elem value of ipset + if set.MaxElem <= 0 { + return false, fmt.Errorf("invalid maxelem value, should be >0") + } + + return true, nil +} + // Entry represents a ipset entry. type Entry struct { // IP is the entry's IP. The IP address protocol corresponds to the HashFamily of IPSet. @@ -310,17 +340,41 @@ func getIPSetVersionString(exec utilexec.Interface) (string, error) { return match[0], nil } -func validatePortRange(portRange string) bool { +// checks if port range is valid. The begin port number is not necessarily less than +// end port number - ipset util can accept it. It means both 1-100 and 100-1 are valid. +func validatePortRange(portRange string) (bool, error) { strs := strings.Split(portRange, "-") if len(strs) != 2 { - return false + return false, fmt.Errorf("port range should be in the format of `a-b`") } for i := range strs { - if _, err := strconv.Atoi(strs[i]); err != nil { - return false + num, err := strconv.Atoi(strs[i]) + if err != nil { + return false, err + } + if num < 0 { + return false, fmt.Errorf("port number %d should be >=0", num) } } - return true + return true, nil +} + +// checks if the given ipset type is valid. +func validateIPSetType(set Type) bool { + for _, valid := range ValidIPSetTypes { + if set == valid { + return true + } + } + return false +} + +// checks if given hash family is supported in ipset +func validateHashFamily(family string) bool { + if family == ProtocolFamilyIPV4 || family == ProtocolFamilyIPV6 { + return true + } + return false } // IsNotFoundError returns true if the error indicates "not found". It parses diff --git a/pkg/util/ipset/ipset_test.go b/pkg/util/ipset/ipset_test.go index e740a1ab55..59a5712185 100644 --- a/pkg/util/ipset/ipset_test.go +++ b/pkg/util/ipset/ipset_test.go @@ -485,3 +485,172 @@ baz` t.Errorf("expected sets: %v, got: %v", expected, list) } } + +func Test_validIPSetType(t *testing.T) { + testCases := []struct { + setType Type + valid bool + }{ + { // case[0] + setType: Type("foo"), + valid: false, + }, + { // case[1] + setType: HashIPPortNet, + valid: true, + }, + { // case[2] + setType: HashIPPort, + valid: true, + }, + { // case[3] + setType: HashIPPortIP, + valid: true, + }, + { // case[4] + setType: BitmapPort, + valid: true, + }, + { // case[5] + setType: Type(""), + valid: false, + }, + } + for i := range testCases { + valid := validateIPSetType(testCases[i].setType) + if valid != testCases[i].valid { + t.Errorf("case [%d]: unexpected mismatch, expect valid[%v], got valid[%v]", i, testCases[i].valid, valid) + } + } +} + +func Test_validatePortRange(t *testing.T) { + testCases := []struct { + portRange string + valid bool + desc string + }{ + { // case[0] + portRange: "a-b", + valid: false, + desc: "invalid port number", + }, + { // case[1] + portRange: "1-2", + valid: true, + desc: "valid", + }, + { // case[2] + portRange: "90-1", + valid: true, + desc: "ipset util can accept the input of begin port number can be less than end port number", + }, + { // case[3] + portRange: DefaultPortRange, + valid: true, + desc: "default port range is valid, of course", + }, + { // case[4] + portRange: "12", + valid: false, + desc: "a single number is invalid", + }, + { // case[5] + portRange: "1-", + valid: false, + desc: "should specify end port", + }, + { // case[6] + portRange: "-100", + valid: false, + desc: "should specify begin port", + }, + { // case[7] + portRange: "1:100", + valid: false, + desc: "delimiter should be -", + }, + { // case[8] + portRange: "1~100", + valid: false, + desc: "delimiter should be -", + }, + { // case[9] + portRange: "1,100", + valid: false, + desc: "delimiter should be -", + }, + { // case[10] + portRange: "100-100", + valid: true, + desc: "begin port number can be equal to end port number", + }, + { // case[11] + portRange: "", + valid: false, + desc: "empty string is invalid", + }, + { // case[12] + portRange: "-1-12", + valid: false, + desc: "port number can not be negative value", + }, + { // case[13] + portRange: "-1--8", + valid: false, + desc: "port number can not be negative value", + }, + } + for i := range testCases { + valid, _ := validatePortRange(testCases[i].portRange) + if valid != testCases[i].valid { + t.Errorf("case [%d]: unexpected mismatch, expect valid[%v], got valid[%v], desc: %s", i, testCases[i].valid, valid, testCases[i].desc) + } + } +} + +func Test_validateFamily(t *testing.T) { + testCases := []struct { + family string + valid bool + }{ + { // case[0] + family: "foo", + valid: false, + }, + { // case[1] + family: ProtocolFamilyIPV4, + valid: true, + }, + { // case[2] + family: ProtocolFamilyIPV6, + valid: true, + }, + { // case[3] + family: "ipv4", + valid: false, + }, + { // case[4] + family: "ipv6", + valid: false, + }, + { // case[5] + family: "tcp", + valid: false, + }, + { // case[6] + family: "udp", + valid: false, + }, + { // case[7] + family: "", + valid: false, + }, + } + for i := range testCases { + valid := validateHashFamily(testCases[i].family) + if valid != testCases[i].valid { + t.Errorf("case [%d]: unexpected mismatch, expect valid[%v], got valid[%v]", i, testCases[i].valid, valid) + } + } +}