From 1d3e660248a12153dfec0ea805b94e9be128b6d6 Mon Sep 17 00:00:00 2001 From: Kelsey Hightower Date: Sun, 3 Aug 2014 12:23:15 -0700 Subject: [PATCH] proxy: cleanup and minor refactoring This change includes minor refactoring and cleanup of the proxy package including the following items: * Rename source files with misspelling of round robin * Remove unnecessary and redundant comments * Update comments for clarity * Add locking when updating the round-robin index * Improve method receiver names * Rename the LoadBalance method to NextEndpoint to add clarity No changes in behaviour have been introduced. --- pkg/proxy/loadbalancer.go | 12 +- pkg/proxy/proxier.go | 95 +++++++------- pkg/proxy/proxier_test.go | 3 +- pkg/proxy/roundrobbin.go | 110 ---------------- pkg/proxy/roundrobin.go | 118 ++++++++++++++++++ ...roundrobbin_test.go => roundrobin_test.go} | 16 +-- 6 files changed, 175 insertions(+), 179 deletions(-) delete mode 100644 pkg/proxy/roundrobbin.go create mode 100644 pkg/proxy/roundrobin.go rename pkg/proxy/{roundrobbin_test.go => roundrobin_test.go} (93%) diff --git a/pkg/proxy/loadbalancer.go b/pkg/proxy/loadbalancer.go index 6771d8bfcc..f8c33673a4 100644 --- a/pkg/proxy/loadbalancer.go +++ b/pkg/proxy/loadbalancer.go @@ -14,19 +14,15 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Loadbalancer interface. Implementations use loadbalancer_ naming. - package proxy import ( "net" ) -// LoadBalancer represents a load balancer that decides where to route -// the incoming services for a particular service to. +// A LoadBalancer distributes incoming requests to service endpoints. type LoadBalancer interface { - // LoadBalance takes an incoming request and figures out where to route it to. - // Determination is based on destination service (for example, 'mysql') as - // well as the source making the connection. - LoadBalance(service string, srcAddr net.Addr) (string, error) + // NextEndpoint returns the endpoint to handle a request for the given + // service and source address. + NextEndpoint(service string, srcAddr net.Addr) (string, error) } diff --git a/pkg/proxy/proxier.go b/pkg/proxy/proxier.go index a300fb5310..ee16edac63 100644 --- a/pkg/proxy/proxier.go +++ b/pkg/proxy/proxier.go @@ -30,40 +30,40 @@ import ( ) type serviceInfo struct { + name string port int - active bool listener net.Listener - lock sync.Mutex + mu sync.Mutex // protects active + active bool } -// Proxier is a simple proxy for tcp connections between a localhost:lport and services that provide -// the actual implementations. +// Proxier is a simple proxy for TCP connections between a localhost:lport +// and services that provide the actual implementations. type Proxier struct { loadBalancer LoadBalancer + mu sync.Mutex // protects serviceMap serviceMap map[string]*serviceInfo - // protects 'serviceMap' - serviceLock sync.Mutex } -// NewProxier returns a newly created and correctly initialized instance of Proxier. +// NewProxier returns a new Proxier given a LoadBalancer. func NewProxier(loadBalancer LoadBalancer) *Proxier { - return &Proxier{loadBalancer: loadBalancer, serviceMap: make(map[string]*serviceInfo)} + return &Proxier{ + loadBalancer: loadBalancer, + serviceMap: make(map[string]*serviceInfo), + } } func copyBytes(in, out *net.TCPConn) { glog.Infof("Copying from %v <-> %v <-> %v <-> %v", in.RemoteAddr(), in.LocalAddr(), out.LocalAddr(), out.RemoteAddr()) - _, err := io.Copy(in, out) - if err != nil { + if _, err := io.Copy(in, out); err != nil { glog.Errorf("I/O error: %v", err) } - in.CloseRead() out.CloseWrite() } -// proxyConnection creates a bidirectional byte shuffler. -// It copies bytes to/from each connection. +// proxyConnection proxies data bidirectionally between in and out. func proxyConnection(in, out *net.TCPConn) { glog.Infof("Creating proxy between %v <-> %v <-> %v <-> %v", in.RemoteAddr(), in.LocalAddr(), out.LocalAddr(), out.RemoteAddr()) @@ -71,39 +71,43 @@ func proxyConnection(in, out *net.TCPConn) { go copyBytes(out, in) } -// StopProxy stops a proxy for the named service. It stops the proxy loop and closes the socket. +// StopProxy stops the proxy for the named service. func (proxier *Proxier) StopProxy(service string) error { // TODO: delete from map here? info, found := proxier.getServiceInfo(service) if !found { return fmt.Errorf("unknown service: %s", service) } - info.lock.Lock() - defer info.lock.Unlock() return proxier.stopProxyInternal(info) } -// Requires that info.lock be held before calling. func (proxier *Proxier) stopProxyInternal(info *serviceInfo) error { + info.mu.Lock() + defer info.mu.Unlock() + if !info.active { + return nil + } + glog.Infof("Removing service: %s", info.name) info.active = false return info.listener.Close() } func (proxier *Proxier) getServiceInfo(service string) (*serviceInfo, bool) { - proxier.serviceLock.Lock() - defer proxier.serviceLock.Unlock() + proxier.mu.Lock() + defer proxier.mu.Unlock() info, ok := proxier.serviceMap[service] return info, ok } func (proxier *Proxier) setServiceInfo(service string, info *serviceInfo) { - proxier.serviceLock.Lock() - defer proxier.serviceLock.Unlock() + proxier.mu.Lock() + defer proxier.mu.Unlock() + info.name = service proxier.serviceMap[service] = info } -// AcceptHandler begins accepting incoming connections from listener and proxying the connections to the load-balanced endpoints. -// It never returns. +// AcceptHandler proxies incoming connections for the specified service +// to the load-balanced service endpoints. func (proxier *Proxier) AcceptHandler(service string, listener net.Listener) { info, found := proxier.getServiceInfo(service) if !found { @@ -111,31 +115,26 @@ func (proxier *Proxier) AcceptHandler(service string, listener net.Listener) { return } for { - info.lock.Lock() + info.mu.Lock() if !info.active { - info.lock.Unlock() + info.mu.Unlock() break } - info.lock.Unlock() + info.mu.Unlock() inConn, err := listener.Accept() if err != nil { glog.Errorf("Accept failed: %v", err) continue } glog.Infof("Accepted connection from: %v to %v", inConn.RemoteAddr(), inConn.LocalAddr()) - - // Figure out where this request should go. - endpoint, err := proxier.loadBalancer.LoadBalance(service, inConn.RemoteAddr()) + endpoint, err := proxier.loadBalancer.NextEndpoint(service, inConn.RemoteAddr()) if err != nil { glog.Errorf("Couldn't find an endpoint for %s %v", service, err) inConn.Close() continue } - glog.Infof("Mapped service %s to endpoint %s", service, endpoint) outConn, err := net.DialTimeout("tcp", endpoint, time.Duration(5)*time.Second) - // We basically need to take everything from inConn and send to outConn - // and anything coming from outConn needs to be sent to inConn. if err != nil { glog.Errorf("Dial failed: %v", err) inConn.Close() @@ -145,9 +144,10 @@ func (proxier *Proxier) AcceptHandler(service string, listener net.Listener) { } } -// addService starts listening for a new service on a given port. +// addService creates and registers a service proxy for the given service on +// the specified port. +// It returns the net.Listener of the service proxy. func (proxier *Proxier) addService(service string, port int) (net.Listener, error) { - // Make sure we can start listening on the port before saying all's well. l, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { return nil, err @@ -156,7 +156,7 @@ func (proxier *Proxier) addService(service string, port int) (net.Listener, erro return l, nil } -// used to globally lock around unused ports. Only used in testing. +// used to globally lock around unused ports. Only used in testing. var unusedPortLock sync.Mutex // addService starts listening for a new service, returning the port it's using. @@ -164,7 +164,6 @@ var unusedPortLock sync.Mutex func (proxier *Proxier) addServiceOnUnusedPort(service string) (string, error) { unusedPortLock.Lock() defer unusedPortLock.Unlock() - // Make sure we can start listening on the port before saying all's well. l, err := net.Listen("tcp", ":0") if err != nil { return "", err @@ -188,24 +187,22 @@ func (proxier *Proxier) addServiceOnUnusedPort(service string) (string, error) { func (proxier *Proxier) addServiceCommon(service string, l net.Listener) { glog.Infof("Listening for %s on %s", service, l.Addr().String()) - // If that succeeds, start the accepting loop. go proxier.AcceptHandler(service, l) } -// OnUpdate receives update notices for the updated services and start listening newly added services. -// It implements "github.com/GoogleCloudPlatform/kubernetes/pkg/proxy/config".ServiceConfigHandler.OnUpdate. -func (proxier *Proxier) OnUpdate(services []api.Service) { +// OnUpdate manages the active set of service proxies. +// Active service proxies are reinitialized if found in the update set or +// shutdown if missing from the update set. +func (proxier Proxier) OnUpdate(services []api.Service) { glog.Infof("Received update notice: %+v", services) - serviceNames := util.StringSet{} - + activeServices := util.StringSet{} for _, service := range services { - serviceNames.Insert(service.ID) + activeServices.Insert(service.ID) info, exists := proxier.getServiceInfo(service.ID) if exists && info.port == service.Port { continue } if exists { - // Stop the old proxier. proxier.StopProxy(service.ID) } glog.Infof("Adding a new service %s on port %d", service.ID, service.Port) @@ -220,15 +217,11 @@ func (proxier *Proxier) OnUpdate(services []api.Service) { listener: listener, }) } - - proxier.serviceLock.Lock() - defer proxier.serviceLock.Unlock() + proxier.mu.Lock() + defer proxier.mu.Unlock() for name, info := range proxier.serviceMap { - info.lock.Lock() - if !serviceNames.Has(name) && info.active { - glog.Infof("Removing service: %s", name) + if !activeServices.Has(name) { proxier.stopProxyInternal(info) } - info.lock.Unlock() } } diff --git a/pkg/proxy/proxier_test.go b/pkg/proxy/proxier_test.go index 5a94b623fb..aa90097622 100644 --- a/pkg/proxy/proxier_test.go +++ b/pkg/proxy/proxier_test.go @@ -68,9 +68,8 @@ func testEchoConnection(t *testing.T, address, port string) { func TestProxy(t *testing.T) { port, err := echoServer(t, "127.0.0.1:0") if err != nil { - t.Fatal(err) + t.Fatalf("Unexpected error: %v", err) } - lb := NewLoadBalancerRR() lb.OnUpdate([]api.Endpoints{ {JSONBase: api.JSONBase{ID: "echo"}, Endpoints: []string{net.JoinHostPort("127.0.0.1", port)}}}) diff --git a/pkg/proxy/roundrobbin.go b/pkg/proxy/roundrobbin.go deleted file mode 100644 index b459318d10..0000000000 --- a/pkg/proxy/roundrobbin.go +++ /dev/null @@ -1,110 +0,0 @@ -/* -Copyright 2014 Google Inc. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// RoundRobin Loadbalancer - -package proxy - -import ( - "errors" - "net" - "reflect" - "strconv" - "sync" - - "github.com/GoogleCloudPlatform/kubernetes/pkg/api" - "github.com/golang/glog" -) - -// LoadBalancerRR is a round-robin load balancer. It implements LoadBalancer. -type LoadBalancerRR struct { - lock sync.RWMutex - endpointsMap map[string][]string - rrIndex map[string]int -} - -// NewLoadBalancerRR returns a newly created and correctly initialized instance of LoadBalancerRR. -func NewLoadBalancerRR() *LoadBalancerRR { - return &LoadBalancerRR{endpointsMap: make(map[string][]string), rrIndex: make(map[string]int)} -} - -// LoadBalance selects an endpoint of the service by round-robin algorithm. -func (impl LoadBalancerRR) LoadBalance(service string, srcAddr net.Addr) (string, error) { - impl.lock.RLock() - endpoints, exists := impl.endpointsMap[service] - index := impl.rrIndex[service] - impl.lock.RUnlock() - if !exists { - return "", errors.New("no service entry for: " + service) - } - if len(endpoints) == 0 { - return "", errors.New("no endpoints for: " + service) - } - endpoint := endpoints[index] - impl.rrIndex[service] = (index + 1) % len(endpoints) - return endpoint, nil -} - -func (impl LoadBalancerRR) isValid(spec string) bool { - _, port, err := net.SplitHostPort(spec) - if err != nil { - return false - } - value, err := strconv.Atoi(port) - if err != nil { - return false - } - return value > 0 -} - -func (impl LoadBalancerRR) filterValidEndpoints(endpoints []string) []string { - var result []string - for _, spec := range endpoints { - if impl.isValid(spec) { - result = append(result, spec) - } - } - return result -} - -// OnUpdate updates the registered endpoints with the new -// endpoint information, removes the registered endpoints -// no longer present in the provided endpoints. -func (impl LoadBalancerRR) OnUpdate(endpoints []api.Endpoints) { - tmp := make(map[string]bool) - impl.lock.Lock() - defer impl.lock.Unlock() - // First update / add all new endpoints for services. - for _, value := range endpoints { - existingEndpoints, exists := impl.endpointsMap[value.ID] - validEndpoints := impl.filterValidEndpoints(value.Endpoints) - if !exists || !reflect.DeepEqual(existingEndpoints, validEndpoints) { - glog.Infof("LoadBalancerRR: Setting endpoints for %s to %+v", value.ID, value.Endpoints) - impl.endpointsMap[value.ID] = validEndpoints - // Start RR from the beginning if added or updated. - impl.rrIndex[value.ID] = 0 - } - tmp[value.ID] = true - } - // Then remove any endpoints no longer relevant - for key, value := range impl.endpointsMap { - _, exists := tmp[key] - if !exists { - glog.Infof("LoadBalancerRR: Removing endpoints for %s -> %+v", key, value) - delete(impl.endpointsMap, key) - } - } -} diff --git a/pkg/proxy/roundrobin.go b/pkg/proxy/roundrobin.go new file mode 100644 index 0000000000..c7a69d2e63 --- /dev/null +++ b/pkg/proxy/roundrobin.go @@ -0,0 +1,118 @@ +/* +Copyright 2014 Google Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "errors" + "net" + "reflect" + "strconv" + "sync" + + "github.com/GoogleCloudPlatform/kubernetes/pkg/api" + "github.com/golang/glog" +) + +var ( + ErrMissingServiceEntry = errors.New("missing service entry") + ErrMissingEndpoints = errors.New("missing endpoints") +) + +// LoadBalancerRR is a round-robin load balancer. +type LoadBalancerRR struct { + lock sync.RWMutex + endpointsMap map[string][]string + rrIndex map[string]int +} + +// NewLoadBalancerRR returns a new LoadBalancerRR. +func NewLoadBalancerRR() *LoadBalancerRR { + return &LoadBalancerRR{ + endpointsMap: make(map[string][]string), + rrIndex: make(map[string]int), + } +} + +// NextEndpoint returns a service endpoint. +// The service endpoint is chosen using the round-robin algorithm. +func (lb LoadBalancerRR) NextEndpoint(service string, srcAddr net.Addr) (string, error) { + lb.lock.RLock() + endpoints, exists := lb.endpointsMap[service] + index := lb.rrIndex[service] + lb.lock.RUnlock() + if !exists { + return "", ErrMissingServiceEntry + } + if len(endpoints) == 0 { + return "", ErrMissingEndpoints + } + endpoint := endpoints[index] + lb.lock.Lock() + lb.rrIndex[service] = (index + 1) % len(endpoints) + lb.lock.Unlock() + return endpoint, nil +} + +func (lb LoadBalancerRR) isValid(spec string) bool { + _, port, err := net.SplitHostPort(spec) + if err != nil { + return false + } + value, err := strconv.Atoi(port) + if err != nil { + return false + } + return value > 0 +} + +func (lb LoadBalancerRR) filterValidEndpoints(endpoints []string) []string { + var result []string + for _, spec := range endpoints { + if lb.isValid(spec) { + result = append(result, spec) + } + } + return result +} + +// OnUpdate manages the registered service endpoints. +// Registered endpoints are updated if found in the update set or +// unregistered if missing from the update set. +func (lb LoadBalancerRR) OnUpdate(endpoints []api.Endpoints) { + registeredEndpoints := make(map[string]bool) + lb.lock.Lock() + defer lb.lock.Unlock() + // Update endpoints for services. + for _, endpoint := range endpoints { + existingEndpoints, exists := lb.endpointsMap[endpoint.ID] + validEndpoints := lb.filterValidEndpoints(endpoint.Endpoints) + if !exists || !reflect.DeepEqual(existingEndpoints, validEndpoints) { + glog.Infof("LoadBalancerRR: Setting endpoints for %s to %+v", endpoint.ID, endpoint.Endpoints) + lb.endpointsMap[endpoint.ID] = validEndpoints + // Reset the round-robin index. + lb.rrIndex[endpoint.ID] = 0 + } + registeredEndpoints[endpoint.ID] = true + } + // Remove endpoints missing from the update. + for k, v := range lb.endpointsMap { + if _, exists := registeredEndpoints[k]; !exists { + glog.Infof("LoadBalancerRR: Removing endpoints for %s -> %+v", k, v) + delete(lb.endpointsMap, k) + } + } +} diff --git a/pkg/proxy/roundrobbin_test.go b/pkg/proxy/roundrobin_test.go similarity index 93% rename from pkg/proxy/roundrobbin_test.go rename to pkg/proxy/roundrobin_test.go index 55f0760bcb..e4777b355f 100644 --- a/pkg/proxy/roundrobbin_test.go +++ b/pkg/proxy/roundrobin_test.go @@ -61,7 +61,7 @@ func TestLoadBalanceFailsWithNoEndpoints(t *testing.T) { loadBalancer := NewLoadBalancerRR() var endpoints []api.Endpoints loadBalancer.OnUpdate(endpoints) - endpoint, err := loadBalancer.LoadBalance("foo", nil) + endpoint, err := loadBalancer.NextEndpoint("foo", nil) if err == nil { t.Errorf("Didn't fail with non-existent service") } @@ -71,7 +71,7 @@ func TestLoadBalanceFailsWithNoEndpoints(t *testing.T) { } func expectEndpoint(t *testing.T, loadBalancer *LoadBalancerRR, service string, expected string) { - endpoint, err := loadBalancer.LoadBalance(service, nil) + endpoint, err := loadBalancer.NextEndpoint(service, nil) if err != nil { t.Errorf("Didn't find a service for %s, expected %s, failed with: %v", service, expected, err) } @@ -82,7 +82,7 @@ func expectEndpoint(t *testing.T, loadBalancer *LoadBalancerRR, service string, func TestLoadBalanceWorksWithSingleEndpoint(t *testing.T) { loadBalancer := NewLoadBalancerRR() - endpoint, err := loadBalancer.LoadBalance("foo", nil) + endpoint, err := loadBalancer.NextEndpoint("foo", nil) if err == nil || len(endpoint) != 0 { t.Errorf("Didn't fail with non-existent service") } @@ -100,7 +100,7 @@ func TestLoadBalanceWorksWithSingleEndpoint(t *testing.T) { func TestLoadBalanceWorksWithMultipleEndpoints(t *testing.T) { loadBalancer := NewLoadBalancerRR() - endpoint, err := loadBalancer.LoadBalance("foo", nil) + endpoint, err := loadBalancer.NextEndpoint("foo", nil) if err == nil || len(endpoint) != 0 { t.Errorf("Didn't fail with non-existent service") } @@ -118,7 +118,7 @@ func TestLoadBalanceWorksWithMultipleEndpoints(t *testing.T) { func TestLoadBalanceWorksWithMultipleEndpointsAndUpdates(t *testing.T) { loadBalancer := NewLoadBalancerRR() - endpoint, err := loadBalancer.LoadBalance("foo", nil) + endpoint, err := loadBalancer.NextEndpoint("foo", nil) if err == nil || len(endpoint) != 0 { t.Errorf("Didn't fail with non-existent service") } @@ -147,7 +147,7 @@ func TestLoadBalanceWorksWithMultipleEndpointsAndUpdates(t *testing.T) { endpoints[0] = api.Endpoints{JSONBase: api.JSONBase{ID: "foo"}, Endpoints: []string{}} loadBalancer.OnUpdate(endpoints) - endpoint, err = loadBalancer.LoadBalance("foo", nil) + endpoint, err = loadBalancer.NextEndpoint("foo", nil) if err == nil || len(endpoint) != 0 { t.Errorf("Didn't fail with non-existent service") } @@ -155,7 +155,7 @@ func TestLoadBalanceWorksWithMultipleEndpointsAndUpdates(t *testing.T) { func TestLoadBalanceWorksWithServiceRemoval(t *testing.T) { loadBalancer := NewLoadBalancerRR() - endpoint, err := loadBalancer.LoadBalance("foo", nil) + endpoint, err := loadBalancer.NextEndpoint("foo", nil) if err == nil || len(endpoint) != 0 { t.Errorf("Didn't fail with non-existent service") } @@ -183,7 +183,7 @@ func TestLoadBalanceWorksWithServiceRemoval(t *testing.T) { // Then update the configuration by removing foo loadBalancer.OnUpdate(endpoints[1:]) - endpoint, err = loadBalancer.LoadBalance("foo", nil) + endpoint, err = loadBalancer.NextEndpoint("foo", nil) if err == nil || len(endpoint) != 0 { t.Errorf("Didn't fail with non-existent service") }