From 7e661012185d16d27d9da510345fc4b5288e1927 Mon Sep 17 00:00:00 2001 From: hangaoshuai Date: Thu, 26 Jul 2018 10:53:55 +0800 Subject: [PATCH 1/2] refactor some hard code in pkg/util/ipset/ipset.go --- pkg/util/ipset/ipset.go | 108 +++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 56 deletions(-) diff --git a/pkg/util/ipset/ipset.go b/pkg/util/ipset/ipset.go index e7c7991ea9..3f0c7dcb0f 100644 --- a/pkg/util/ipset/ipset.go +++ b/pkg/util/ipset/ipset.go @@ -52,7 +52,7 @@ type Interface interface { GetVersion() (string, error) } -// IPSetCmd represents the ipset util. We use ipset command for ipset execute. +// IPSetCmd represents the ipset util. We use ipset command for ipset execute. const IPSetCmd = "ipset" // EntryMemberPattern is the regular expression pattern of ipset member list. @@ -72,7 +72,7 @@ var EntryMemberPattern = "(?m)^(.*\n)*Members:\n" // ipset version output is similar to "v6.10". var VersionPattern = "v[0-9]+\\.[0-9]+" -// IPSet implements an Interface to an set. +// IPSet implements an Interface to a set. type IPSet struct { // Name is the set name. Name string @@ -123,6 +123,28 @@ func (set *IPSet) Validate() bool { return true } +//setIPSetDefaults sets some IPSet fields if not present to their default values. +func (set *IPSet) setIPSetDefaults() { + // Setting default values if not present + if set.HashSize == 0 { + set.HashSize = 1024 + } + if set.MaxElem == 0 { + set.MaxElem = 65536 + } + // Default protocol is IPv4 + if set.HashFamily == "" { + set.HashFamily = ProtocolFamilyIPV4 + } + // Default ipset type is "hash:ip,port" + if len(set.SetType) == 0 { + set.SetType = HashIPPort + } + if len(set.PortRange) == 0 { + set.PortRange = DefaultPortRange + } +} + // Entry represents a ipset entry. type Entry struct { // IP is the entry's IP. The IP address protocol corresponds to the HashFamily of IPSet. @@ -150,31 +172,13 @@ func (e *Entry) Validate(set *IPSet) bool { } switch e.SetType { case HashIPPort: - // set default protocol to tcp if empty - if len(e.Protocol) == 0 { - e.Protocol = ProtocolTCP - } - - if net.ParseIP(e.IP) == nil { - glog.Errorf("Error parsing entry %v ip address %v for ipset %v", e, e.IP, set) - return false - } - - if valid := validateProtocol(e.Protocol); !valid { + //check if IP and Protocol of Entry is valid. + if valid := e.checkIPandProtocol(set); !valid { return false } case HashIPPortIP: - // set default protocol to tcp if empty - if len(e.Protocol) == 0 { - e.Protocol = ProtocolTCP - } - - if net.ParseIP(e.IP) == nil { - glog.Errorf("Error parsing entry %v ip address %v for ipset %v", e, e.IP, set) - return false - } - - if valid := validateProtocol(e.Protocol); !valid { + //check if IP and Protocol of Entry is valid. + if valid := e.checkIPandProtocol(set); !valid { return false } @@ -184,23 +188,14 @@ func (e *Entry) Validate(set *IPSet) bool { return false } case HashIPPortNet: - // set default protocol to tcp if empty - if len(e.Protocol) == 0 { - e.Protocol = ProtocolTCP - } - - if net.ParseIP(e.IP) == nil { - glog.Errorf("Error parsing entry %v ip address %v for ipset %v", e, e.IP, set) - return false - } - - if valid := validateProtocol(e.Protocol); !valid { + //check if IP and Protocol of Entry is valid. + if valid := e.checkIPandProtocol(set); !valid { return false } // Net can not be empty for `hash:ip,port,net` type ip set - if _, ipNet, _ := net.ParseCIDR(e.Net); ipNet == nil { - glog.Errorf("Error parsing entry %v ip net %v for ipset %v", e, e.Net, set) + if _, ipNet, err := net.ParseCIDR(e.Net); ipNet == nil { + glog.Errorf("Error parsing entry %v ip net %v for ipset %v, error: %v", e, e.Net, set, err) return false } case BitmapPort: @@ -246,6 +241,23 @@ func (e *Entry) String() string { return "" } +// checkIPandProtocol checks if IP and Protocol of Entry is valid. +func (e *Entry) checkIPandProtocol(set *IPSet) bool { + // set default protocol to tcp if empty + if len(e.Protocol) == 0 { + e.Protocol = ProtocolTCP + } else if !validateProtocol(e.Protocol) { + return false + } + + if net.ParseIP(e.IP) == nil { + glog.Errorf("Error parsing entry %v ip address %v for ipset %v", e, e.IP, set) + return false + } + + return true +} + type runner struct { exec utilexec.Interface } @@ -257,26 +269,10 @@ func New(exec utilexec.Interface) Interface { } } -// CreateSet creates a new set, it will ignore error when the set already exists if ignoreExistErr=true. +// CreateSet creates a new set, it will ignore error when the set already exists if ignoreExistErr=true. func (runner *runner) CreateSet(set *IPSet, ignoreExistErr bool) error { - // Setting default values if not present - if set.HashSize == 0 { - set.HashSize = 1024 - } - if set.MaxElem == 0 { - set.MaxElem = 65536 - } - // Default protocol is IPv4 - if set.HashFamily == "" { - set.HashFamily = ProtocolFamilyIPV4 - } - // Default ipset type is "hash:ip,port" - if len(set.SetType) == 0 { - set.SetType = HashIPPort - } - if len(set.PortRange) == 0 { - set.PortRange = DefaultPortRange - } + // sets some IPSet fields if not present to their default values. + set.setIPSetDefaults() // Validate ipset before creating valid := set.Validate() From 5dfb0a2d6095b9c783a013f6464041d129c61b93 Mon Sep 17 00:00:00 2001 From: hangaoshuai Date: Thu, 26 Jul 2018 10:56:22 +0800 Subject: [PATCH 2/2] add unit tests for checkIPandProtocol and setIPSetDefaults --- pkg/util/ipset/ipset_test.go | 144 +++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/pkg/util/ipset/ipset_test.go b/pkg/util/ipset/ipset_test.go index 192b9beb57..c3d34bb98a 100644 --- a/pkg/util/ipset/ipset_test.go +++ b/pkg/util/ipset/ipset_test.go @@ -904,6 +904,150 @@ func TestValidateIPSet(t *testing.T) { } } +func Test_setIPSetDefaults(t *testing.T) { + testCases := []struct { + name string + set *IPSet + expect *IPSet + }{ + { + name: "test all the IPSet fields not present", + set: &IPSet{ + Name: "test1", + }, + expect: &IPSet{ + Name: "test1", + SetType: HashIPPort, + HashFamily: ProtocolFamilyIPV4, + HashSize: 1024, + MaxElem: 65536, + PortRange: DefaultPortRange, + }, + }, + { + name: "test all the IPSet fields present", + set: &IPSet{ + Name: "test2", + SetType: BitmapPort, + HashFamily: ProtocolFamilyIPV6, + HashSize: 65535, + MaxElem: 2048, + PortRange: DefaultPortRange, + }, + expect: &IPSet{ + Name: "test2", + SetType: BitmapPort, + HashFamily: ProtocolFamilyIPV6, + HashSize: 65535, + MaxElem: 2048, + PortRange: DefaultPortRange, + }, + }, + { + name: "test part of the IPSet fields present", + set: &IPSet{ + Name: "test3", + SetType: BitmapPort, + HashFamily: ProtocolFamilyIPV6, + HashSize: 65535, + }, + expect: &IPSet{ + Name: "test3", + SetType: BitmapPort, + HashFamily: ProtocolFamilyIPV6, + HashSize: 65535, + MaxElem: 65536, + PortRange: DefaultPortRange, + }, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + test.set.setIPSetDefaults() + if !reflect.DeepEqual(test.set, test.expect) { + t.Errorf("expected ipset struct: %v, got ipset struct: %v", test.expect, test.set) + } + }) + } +} + +func Test_checkIPandProtocol(t *testing.T) { + testset := &IPSet{ + Name: "test1", + SetType: HashIPPort, + HashFamily: ProtocolFamilyIPV4, + HashSize: 1024, + MaxElem: 65536, + PortRange: DefaultPortRange, + } + + testCases := []struct { + name string + entry *Entry + valid bool + }{ + { + name: "valid IP with ProtocolTCP", + entry: &Entry{ + SetType: HashIPPort, + IP: "1.2.3.4", + Protocol: ProtocolTCP, + Port: 8080, + }, + valid: true, + }, + { + name: "valid IP with ProtocolUDP", + entry: &Entry{ + SetType: HashIPPort, + IP: "1.2.3.4", + Protocol: ProtocolUDP, + Port: 8080, + }, + valid: true, + }, + { + name: "valid IP with nil Protocol", + entry: &Entry{ + SetType: HashIPPort, + IP: "1.2.3.4", + Port: 8080, + }, + valid: true, + }, + { + name: "valid IP with invalid Protocol", + entry: &Entry{ + SetType: HashIPPort, + IP: "1.2.3.4", + Protocol: "invalidProtocol", + Port: 8080, + }, + valid: false, + }, + { + name: "invalid IP with ProtocolTCP", + entry: &Entry{ + SetType: HashIPPort, + IP: "1.2.3.423", + Protocol: ProtocolTCP, + Port: 8080, + }, + valid: false, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + result := test.entry.checkIPandProtocol(testset) + if result != test.valid { + t.Errorf("expected valid: %v, got valid: %v", test.valid, result) + } + }) + } +} + func Test_parsePortRange(t *testing.T) { testCases := []struct { portRange string