From 46f809d7111eea865d8ac125d4cddc3065d02008 Mon Sep 17 00:00:00 2001 From: fatedier Date: Mon, 18 Jan 2021 21:49:44 +0800 Subject: [PATCH] vhost: set DisableKeepAlives = false and fix websocket not work --- pkg/util/util/http.go | 2 +- pkg/util/vhost/http.go | 32 +++++++++++++++++++++--------- tests/ci/health/health_test.go | 36 ++++++++++++++++++++++++---------- 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/pkg/util/util/http.go b/pkg/util/util/http.go index e48ef4a..2d6089b 100644 --- a/pkg/util/util/http.go +++ b/pkg/util/util/http.go @@ -74,4 +74,4 @@ func hasPort(host string) bool { return true } return host[0] == '[' && strings.Contains(host, "]:") -} \ No newline at end of file +} diff --git a/pkg/util/vhost/http.go b/pkg/util/vhost/http.go index de105cb..ee2ab1a 100644 --- a/pkg/util/vhost/http.go +++ b/pkg/util/vhost/http.go @@ -17,6 +17,7 @@ package vhost import ( "bytes" "context" + "encoding/base64" "errors" "fmt" "log" @@ -59,20 +60,25 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) * req.URL.Scheme = "http" url := req.Context().Value(RouteInfoURL).(string) oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string)) - host := rp.GetRealHost(oldHost, url) - if host != "" { - req.Host = host + rc := rp.GetRouteConfig(oldHost, url) + if rc != nil { + if rc.RewriteHost != "" { + req.Host = rc.RewriteHost + } + // Set {domain}.{location} as URL host here to let http transport reuse connections. + req.URL.Host = rc.Domain + "." + base64.StdEncoding.EncodeToString([]byte(rc.Location)) + + for k, v := range rc.Headers { + req.Header.Set(k, v) + } + } else { + req.URL.Host = req.Host } - req.URL.Host = req.Host - headers := rp.GetHeaders(oldHost, url) - for k, v := range headers { - req.Header.Set(k, v) - } }, Transport: &http.Transport{ ResponseHeaderTimeout: rp.responseHeaderTimeout, - DisableKeepAlives: true, + IdleConnTimeout: 60 * time.Second, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { url := ctx.Value(RouteInfoURL).(string) host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string)) @@ -107,6 +113,14 @@ func (rp *HTTPReverseProxy) UnRegister(domain string, location string) { rp.vhostRouter.Del(domain, location) } +func (rp *HTTPReverseProxy) GetRouteConfig(domain string, location string) *RouteConfig { + vr, ok := rp.getVhost(domain, location) + if ok { + return vr.payload.(*RouteConfig) + } + return nil +} + func (rp *HTTPReverseProxy) GetRealHost(domain string, location string) (host string) { vr, ok := rp.getVhost(domain, location) if ok { diff --git a/tests/ci/health/health_test.go b/tests/ci/health/health_test.go index 0f48f46..0e75ad2 100644 --- a/tests/ci/health/health_test.go +++ b/tests/ci/health/health_test.go @@ -139,6 +139,7 @@ func TestHealthCheck(t *testing.T) { } httpSvc3 := mock.NewHTTPServer(15005, func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Second) w.Write([]byte("http3")) }) err = httpSvc3.Start() @@ -147,6 +148,7 @@ func TestHealthCheck(t *testing.T) { } httpSvc4 := mock.NewHTTPServer(15006, func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Second) w.Write([]byte("http4")) }) err = httpSvc4.Start() @@ -277,16 +279,30 @@ func TestHealthCheck(t *testing.T) { // ****** load balancing type http ****** result = make([]string, 0) - - code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "") - assert.NoError(err) - assert.Equal(200, code) - result = append(result, body) - - code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "") - assert.NoError(err) - assert.Equal(200, code) - result = append(result, body) + var wait sync.WaitGroup + var mu sync.Mutex + wait.Add(2) + + go func() { + defer wait.Done() + code, body, _, err := util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "") + assert.NoError(err) + assert.Equal(200, code) + mu.Lock() + result = append(result, body) + mu.Unlock() + }() + + go func() { + defer wait.Done() + code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "") + assert.NoError(err) + assert.Equal(200, code) + mu.Lock() + result = append(result, body) + mu.Unlock() + }() + wait.Wait() assert.Contains(result, "http3") assert.Contains(result, "http4")