mirror of https://github.com/k3s-io/k3s
commit
4a6ba9dae2
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
|
@ -538,6 +539,68 @@ func TestProxyUpgradeErrorResponse(t *testing.T) {
|
|||
assert.Contains(t, string(msg), expectedErr.Error())
|
||||
}
|
||||
|
||||
func TestProxyUpgradeErrorResponseTerminates(t *testing.T) {
|
||||
for _, intercept := range []bool{true, false} {
|
||||
for _, code := range []int{200, 400, 500} {
|
||||
t.Run(fmt.Sprintf("intercept=%v,code=%v", intercept, code), func(t *testing.T) {
|
||||
// Set up a backend server
|
||||
backend := http.NewServeMux()
|
||||
backend.Handle("/hello", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(code)
|
||||
w.Write([]byte(`some data`))
|
||||
}))
|
||||
backend.Handle("/there", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("request to /there")
|
||||
}))
|
||||
backendServer := httptest.NewServer(backend)
|
||||
defer backendServer.Close()
|
||||
backendServerURL, _ := url.Parse(backendServer.URL)
|
||||
backendServerURL.Path = "/hello"
|
||||
|
||||
// Set up a proxy pointing to a specific path on the backend
|
||||
proxyHandler := NewUpgradeAwareHandler(backendServerURL, nil, false, false, &noErrorsAllowed{t: t})
|
||||
proxyHandler.InterceptRedirects = intercept
|
||||
proxy := httptest.NewServer(proxyHandler)
|
||||
defer proxy.Close()
|
||||
proxyURL, _ := url.Parse(proxy.URL)
|
||||
|
||||
conn, err := net.Dial("tcp", proxyURL.Host)
|
||||
require.NoError(t, err)
|
||||
bufferedReader := bufio.NewReader(conn)
|
||||
|
||||
// Send upgrade request resulting in a non-101 response from the backend
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
|
||||
require.NoError(t, req.Write(conn))
|
||||
// Verify we get the correct response and full message body content
|
||||
resp, err := http.ReadResponse(bufferedReader, nil)
|
||||
require.NoError(t, err)
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, resp.StatusCode, code)
|
||||
require.Equal(t, data, []byte(`some data`))
|
||||
resp.Body.Close()
|
||||
|
||||
// try to read from the connection to verify it was closed
|
||||
b := make([]byte, 1)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
if _, err := conn.Read(b); err != io.EOF {
|
||||
t.Errorf("expected EOF, got %v", err)
|
||||
}
|
||||
|
||||
// Send another request to another endpoint to verify it is not received
|
||||
req, _ = http.NewRequest("GET", "/there", nil)
|
||||
req.Write(conn)
|
||||
// wait to ensure the handler does not receive the request
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// clean up
|
||||
conn.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultProxyTransport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name,
|
||||
|
|
Loading…
Reference in New Issue