Don't include user data in CRI streaming redirect URLs

pull/6/head
Tim St. Clair 2016-12-13 18:27:05 -08:00
parent 330c922706
commit c17f3ee367
No known key found for this signature in database
GPG Key ID: 434D16BCEF479EAB
8 changed files with 650 additions and 188 deletions

View File

@ -606,7 +606,7 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response)
podFullName := kubecontainer.GetPodFullName(pod)
redirect, err := s.host.GetAttach(podFullName, params.podUID, params.containerName, *streamOpts)
if err != nil {
response.WriteError(streaming.HTTPStatus(err), err)
streaming.WriteError(err, response.ResponseWriter)
return
}
if redirect != nil {
@ -644,7 +644,7 @@ func (s *Server) getExec(request *restful.Request, response *restful.Response) {
podFullName := kubecontainer.GetPodFullName(pod)
redirect, err := s.host.GetExec(podFullName, params.podUID, params.containerName, params.cmd, *streamOpts)
if err != nil {
response.WriteError(streaming.HTTPStatus(err), err)
streaming.WriteError(err, response.ResponseWriter)
return
}
if redirect != nil {
@ -714,7 +714,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID)
if err != nil {
response.WriteError(streaming.HTTPStatus(err), err)
streaming.WriteError(err, response.ResponseWriter)
return
}
if redirect != nil {

View File

@ -12,14 +12,15 @@ go_library(
name = "go_default_library",
srcs = [
"errors.go",
"request_cache.go",
"server.go",
],
tags = ["automanaged"],
deps = [
"//pkg/api:go_default_library",
"//pkg/kubelet/api/v1alpha1/runtime:go_default_library",
"//pkg/kubelet/server/portforward:go_default_library",
"//pkg/kubelet/server/remotecommand:go_default_library",
"//pkg/util/clock:go_default_library",
"//pkg/util/term:go_default_library",
"//vendor:github.com/emicklei/go-restful",
"//vendor:google.golang.org/grpc",
@ -30,7 +31,10 @@ go_library(
go_test(
name = "go_default_test",
srcs = ["server_test.go"],
srcs = [
"request_cache_test.go",
"server_test.go",
],
library = ":go_default_library",
tags = ["automanaged"],
deps = [
@ -43,6 +47,7 @@ go_test(
"//vendor:github.com/stretchr/testify/assert",
"//vendor:github.com/stretchr/testify/require",
"//vendor:k8s.io/client-go/pkg/api",
"//vendor:k8s.io/client-go/pkg/util/clock",
],
)

View File

@ -19,6 +19,7 @@ package streaming
import (
"fmt"
"net/http"
"strconv"
"time"
"google.golang.org/grpc"
@ -33,12 +34,27 @@ func ErrorTimeout(op string, timeout time.Duration) error {
return grpc.Errorf(codes.DeadlineExceeded, fmt.Sprintf("%s timed out after %s", op, timeout.String()))
}
// Translates a CRI streaming error into an HTTP status code.
func HTTPStatus(err error) int {
// The error returned when the maximum number of in-flight requests is exceeded.
func ErrorTooManyInFlight() error {
return grpc.Errorf(codes.ResourceExhausted, "maximum number of in-flight requests exceeded")
}
// Translates a CRI streaming error into an appropriate HTTP response.
func WriteError(err error, w http.ResponseWriter) error {
var status int
switch grpc.Code(err) {
case codes.NotFound:
return http.StatusNotFound
status = http.StatusNotFound
case codes.ResourceExhausted:
// We only expect to hit this if there is a DoS, so we just wait the full TTL.
// If this is ever hit in steady-state operations, consider increasing the MaxInFlight requests,
// or plumbing through the time to next expiration.
w.Header().Set("Retry-After", strconv.Itoa(int(CacheTTL.Seconds())))
status = http.StatusTooManyRequests
default:
return http.StatusInternalServerError
status = http.StatusInternalServerError
}
w.WriteHeader(status)
_, writeErr := w.Write([]byte(err.Error()))
return writeErr
}

View File

@ -0,0 +1,146 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package streaming
import (
"container/list"
"crypto/rand"
"encoding/base64"
"fmt"
"math"
"sync"
"time"
"k8s.io/kubernetes/pkg/util/clock"
)
var (
// Timeout after which tokens become invalid.
CacheTTL = 1 * time.Minute
// The maximum number of in-flight requests to allow.
MaxInFlight = 1000
// Length of the random base64 encoded token identifying the request.
TokenLen = 8
)
// requestCache caches streaming (exec/attach/port-forward) requests and generates a single-use
// random token for their retrieval. The requestCache is used for building streaming URLs without
// the need to encode every request parameter in the URL.
type requestCache struct {
// clock is used to obtain the current time
clock clock.Clock
// tokens maps the generate token to the request for fast retrieval.
tokens map[string]*list.Element
// ll maintains an age-ordered request list for faster garbage collection of expired requests.
ll *list.List
lock sync.Mutex
}
// Type representing an *ExecRequest, *AttachRequest, or *PortForwardRequest.
type request interface{}
type cacheEntry struct {
token string
req request
expireTime time.Time
}
func newRequestCache() *requestCache {
return &requestCache{
clock: clock.RealClock{},
ll: list.New(),
tokens: make(map[string]*list.Element),
}
}
// Insert the given request into the cache and returns the token used for fetching it out.
func (c *requestCache) Insert(req request) (token string, err error) {
c.lock.Lock()
defer c.lock.Unlock()
// Remove expired entries.
c.gc()
// If the cache is full, reject the request.
if c.ll.Len() == MaxInFlight {
return "", ErrorTooManyInFlight()
}
token, err = c.uniqueToken()
if err != nil {
return "", err
}
ele := c.ll.PushFront(&cacheEntry{token, req, c.clock.Now().Add(CacheTTL)})
c.tokens[token] = ele
return token, nil
}
// Consume the token (remove it from the cache) and return the cached request, if found.
func (c *requestCache) Consume(token string) (req request, found bool) {
c.lock.Lock()
defer c.lock.Unlock()
ele, ok := c.tokens[token]
if !ok {
return nil, false
}
c.ll.Remove(ele)
delete(c.tokens, token)
entry := ele.Value.(*cacheEntry)
if c.clock.Now().After(entry.expireTime) {
// Entry already expired.
return nil, false
}
return entry.req, true
}
// uniqueToken generates a random URL-safe token and ensures uniqueness.
func (c *requestCache) uniqueToken() (string, error) {
const maxTries = 10
// Number of bytes to be TokenLen when base64 encoded.
tokenSize := math.Ceil(float64(TokenLen) * 6 / 8)
rawToken := make([]byte, int(tokenSize))
for i := 0; i < maxTries; i++ {
if _, err := rand.Read(rawToken); err != nil {
return "", err
}
encoded := base64.RawURLEncoding.EncodeToString(rawToken)
token := encoded[:TokenLen]
// If it's unique, return it. Otherwise retry.
if _, exists := c.tokens[encoded]; !exists {
return token, nil
}
}
return "", fmt.Errorf("failed to generate unique token")
}
// Must be write-locked prior to calling.
func (c *requestCache) gc() {
now := c.clock.Now()
for c.ll.Len() > 0 {
oldest := c.ll.Back()
entry := oldest.Value.(*cacheEntry)
if !now.After(entry.expireTime) {
return
}
// Oldest value is expired; remove it.
c.ll.Remove(oldest)
delete(c.tokens, entry.token)
}
}

View File

@ -0,0 +1,221 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package streaming
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/client-go/pkg/util/clock"
)
func TestInsert(t *testing.T) {
c, _ := newTestCache()
// Insert normal
oldestTok, err := c.Insert(nextRequest())
require.NoError(t, err)
assert.Len(t, oldestTok, TokenLen)
assertCacheSize(t, c, 1)
// Insert until full
for i := 0; i < MaxInFlight-2; i++ {
tok, err := c.Insert(nextRequest())
require.NoError(t, err)
assert.Len(t, tok, TokenLen)
}
assertCacheSize(t, c, MaxInFlight-1)
newestReq := nextRequest()
newestTok, err := c.Insert(newestReq)
require.NoError(t, err)
assert.Len(t, newestTok, TokenLen)
assertCacheSize(t, c, MaxInFlight)
require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached")
// Consume newest token.
req, ok := c.Consume(newestTok)
assert.True(t, ok, "newest request should still be cached")
assert.Equal(t, newestReq, req)
require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached")
// Insert again (still full)
tok, err := c.Insert(nextRequest())
require.NoError(t, err)
assert.Len(t, tok, TokenLen)
assertCacheSize(t, c, MaxInFlight)
// Insert again (should evict)
_, err = c.Insert(nextRequest())
assert.Error(t, err, "should reject further requests")
errResponse := httptest.NewRecorder()
require.NoError(t, WriteError(err, errResponse))
assert.Equal(t, errResponse.Code, http.StatusTooManyRequests)
assert.Equal(t, strconv.Itoa(int(CacheTTL.Seconds())), errResponse.HeaderMap.Get("Retry-After"))
assertCacheSize(t, c, MaxInFlight)
_, ok = c.Consume(oldestTok)
assert.True(t, ok, "oldest request should be valid")
}
func TestConsume(t *testing.T) {
c, clock := newTestCache()
{ // Insert & consume.
req := nextRequest()
tok, err := c.Insert(req)
require.NoError(t, err)
assertCacheSize(t, c, 1)
cachedReq, ok := c.Consume(tok)
assert.True(t, ok)
assert.Equal(t, req, cachedReq)
assertCacheSize(t, c, 0)
}
{ // Insert & consume out of order
req1 := nextRequest()
tok1, err := c.Insert(req1)
require.NoError(t, err)
assertCacheSize(t, c, 1)
req2 := nextRequest()
tok2, err := c.Insert(req2)
require.NoError(t, err)
assertCacheSize(t, c, 2)
cachedReq2, ok := c.Consume(tok2)
assert.True(t, ok)
assert.Equal(t, req2, cachedReq2)
assertCacheSize(t, c, 1)
cachedReq1, ok := c.Consume(tok1)
assert.True(t, ok)
assert.Equal(t, req1, cachedReq1)
assertCacheSize(t, c, 0)
}
{ // Consume a second time
req := nextRequest()
tok, err := c.Insert(req)
require.NoError(t, err)
assertCacheSize(t, c, 1)
cachedReq, ok := c.Consume(tok)
assert.True(t, ok)
assert.Equal(t, req, cachedReq)
assertCacheSize(t, c, 0)
_, ok = c.Consume(tok)
assert.False(t, ok)
assertCacheSize(t, c, 0)
}
{ // Consume without insert
_, ok := c.Consume("fooBAR")
assert.False(t, ok)
assertCacheSize(t, c, 0)
}
{ // Consume expired
tok, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 1)
clock.Step(2 * CacheTTL)
_, ok := c.Consume(tok)
assert.False(t, ok)
assertCacheSize(t, c, 0)
}
}
func TestGC(t *testing.T) {
c, clock := newTestCache()
// When empty
c.gc()
assertCacheSize(t, c, 0)
tok1, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 1)
clock.Step(10 * time.Second)
tok2, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 2)
// expired: tok1, tok2
// non-expired: tok3, tok4
clock.Step(2 * CacheTTL)
tok3, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 1)
clock.Step(10 * time.Second)
tok4, err := c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 2)
_, ok := c.Consume(tok1)
assert.False(t, ok)
_, ok = c.Consume(tok2)
assert.False(t, ok)
_, ok = c.Consume(tok3)
assert.True(t, ok)
_, ok = c.Consume(tok4)
assert.True(t, ok)
// When full, nothing is expired.
for i := 0; i < MaxInFlight; i++ {
_, err := c.Insert(nextRequest())
require.NoError(t, err)
}
assertCacheSize(t, c, MaxInFlight)
// When everything is expired
clock.Step(2 * CacheTTL)
_, err = c.Insert(nextRequest())
require.NoError(t, err)
assertCacheSize(t, c, 1)
}
func newTestCache() (*requestCache, *clock.FakeClock) {
c := newRequestCache()
fakeClock := clock.NewFakeClock(time.Now())
c.clock = fakeClock
return c, fakeClock
}
func assertCacheSize(t *testing.T, cache *requestCache, expectedSize int) {
tokenLen := len(cache.tokens)
llLen := cache.ll.Len()
assert.Equal(t, tokenLen, llLen, "inconsistent cache size! len(tokens)=%d; len(ll)=%d", tokenLen, llLen)
assert.Equal(t, expectedSize, tokenLen, "unexpected cache size!")
}
var requestUID = 0
func nextRequest() interface{} {
requestUID++
return requestUID
}

View File

@ -25,10 +25,12 @@ import (
"path"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
restful "github.com/emicklei/go-restful"
"k8s.io/apimachinery/pkg/types"
"k8s.io/kubernetes/pkg/api"
runtimeapi "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime"
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
@ -97,6 +99,7 @@ func NewServer(config Config, runtime Runtime) (Server, error) {
s := &server{
config: config,
runtime: &criAdapter{runtime},
cache: newRequestCache(),
}
if s.config.BaseURL == nil {
@ -114,9 +117,9 @@ func NewServer(config Config, runtime Runtime) (Server, error) {
path string
handler restful.RouteFunction
}{
{"/exec/{containerID}", s.serveExec},
{"/attach/{containerID}", s.serveAttach},
{"/portforward/{podSandboxID}", s.servePortForward},
{"/exec/{token}", s.serveExec},
{"/attach/{token}", s.serveAttach},
{"/portforward/{token}", s.servePortForward},
}
// If serving relative to a base path, set that here.
pathPrefix := path.Dir(s.config.BaseURL.Path)
@ -139,37 +142,45 @@ type server struct {
config Config
runtime *criAdapter
handler http.Handler
cache *requestCache
}
func (s *server) GetExec(req *runtimeapi.ExecRequest) (*runtimeapi.ExecResponse, error) {
url := s.buildURL("exec", req.GetContainerId(), streamOpts{
stdin: req.GetStdin(),
stdout: true,
stderr: !req.GetTty(), // For TTY connections, both stderr is combined with stdout.
tty: req.GetTty(),
command: req.GetCmd(),
})
if req.GetContainerId() == "" {
return nil, grpc.Errorf(codes.InvalidArgument, "missing required container_id")
}
token, err := s.cache.Insert(req)
if err != nil {
return nil, err
}
return &runtimeapi.ExecResponse{
Url: &url,
Url: s.buildURL("exec", token),
}, nil
}
func (s *server) GetAttach(req *runtimeapi.AttachRequest) (*runtimeapi.AttachResponse, error) {
url := s.buildURL("attach", req.GetContainerId(), streamOpts{
stdin: req.GetStdin(),
stdout: true,
stderr: !req.GetTty(), // For TTY connections, both stderr is combined with stdout.
tty: req.GetTty(),
})
if req.GetContainerId() == "" {
return nil, grpc.Errorf(codes.InvalidArgument, "missing required container_id")
}
token, err := s.cache.Insert(req)
if err != nil {
return nil, err
}
return &runtimeapi.AttachResponse{
Url: &url,
Url: s.buildURL("attach", token),
}, nil
}
func (s *server) GetPortForward(req *runtimeapi.PortForwardRequest) (*runtimeapi.PortForwardResponse, error) {
url := s.buildURL("portforward", req.GetPodSandboxId(), streamOpts{})
if req.GetPodSandboxId() == "" {
return nil, grpc.Errorf(codes.InvalidArgument, "missing required pod_sandbox_id")
}
token, err := s.cache.Insert(req)
if err != nil {
return nil, err
}
return &runtimeapi.PortForwardResponse{
Url: &url,
Url: s.buildURL("portforward", token),
}, nil
}
@ -200,63 +211,32 @@ func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.handler.ServeHTTP(w, r)
}
type streamOpts struct {
stdin bool
stdout bool
stderr bool
tty bool
command []string
port []int32
}
const (
urlParamStdin = api.ExecStdinParam
urlParamStdout = api.ExecStdoutParam
urlParamStderr = api.ExecStderrParam
urlParamTTY = api.ExecTTYParam
urlParamCommand = api.ExecCommandParamm
)
func (s *server) buildURL(method, id string, opts streamOpts) string {
loc := &url.URL{
Path: path.Join(method, id),
}
query := url.Values{}
if opts.stdin {
query.Add(urlParamStdin, "1")
}
if opts.stdout {
query.Add(urlParamStdout, "1")
}
if opts.stderr {
query.Add(urlParamStderr, "1")
}
if opts.tty {
query.Add(urlParamTTY, "1")
}
for _, c := range opts.command {
query.Add(urlParamCommand, c)
}
loc.RawQuery = query.Encode()
return s.config.BaseURL.ResolveReference(loc).String()
func (s *server) buildURL(method, token string) *string {
loc := s.config.BaseURL.ResolveReference(&url.URL{
Path: path.Join(method, token),
}).String()
return &loc
}
func (s *server) serveExec(req *restful.Request, resp *restful.Response) {
containerID := req.PathParameter("containerID")
if containerID == "" {
resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter"))
token := req.PathParameter("token")
cachedRequest, ok := s.cache.Consume(token)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
exec, ok := cachedRequest.(*runtimeapi.ExecRequest)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
streamOpts, err := remotecommand.NewOptions(req.Request)
if err != nil {
resp.WriteError(http.StatusBadRequest, err)
return
streamOpts := &remotecommand.Options{
Stdin: exec.GetStdin(),
Stdout: true,
Stderr: !exec.GetTty(),
TTY: exec.GetTty(),
}
cmd := req.Request.URL.Query()[api.ExecCommandParamm]
remotecommand.ServeExec(
resp.ResponseWriter,
@ -264,8 +244,8 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) {
s.runtime,
"", // unused: podName
"", // unusued: podUID
containerID,
cmd,
exec.GetContainerId(),
exec.GetCmd(),
streamOpts,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
@ -273,25 +253,31 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) {
}
func (s *server) serveAttach(req *restful.Request, resp *restful.Response) {
containerID := req.PathParameter("containerID")
if containerID == "" {
resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter"))
token := req.PathParameter("token")
cachedRequest, ok := s.cache.Consume(token)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
attach, ok := cachedRequest.(*runtimeapi.AttachRequest)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
streamOpts, err := remotecommand.NewOptions(req.Request)
if err != nil {
resp.WriteError(http.StatusBadRequest, err)
return
streamOpts := &remotecommand.Options{
Stdin: attach.GetStdin(),
Stdout: true,
Stderr: !attach.GetTty(),
TTY: attach.GetTty(),
}
remotecommand.ServeAttach(
resp.ResponseWriter,
req.Request,
s.runtime,
"", // unused: podName
"", // unusued: podUID
containerID,
attach.GetContainerId(),
streamOpts,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
@ -299,9 +285,15 @@ func (s *server) serveAttach(req *restful.Request, resp *restful.Response) {
}
func (s *server) servePortForward(req *restful.Request, resp *restful.Response) {
podSandboxID := req.PathParameter("podSandboxID")
if podSandboxID == "" {
resp.WriteError(http.StatusBadRequest, errors.New("missing required podSandboxID path parameter"))
token := req.PathParameter("token")
cachedRequest, ok := s.cache.Consume(token)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
pf, ok := cachedRequest.(*runtimeapi.PortForwardRequest)
if !ok {
http.NotFound(resp.ResponseWriter, req.Request)
return
}
@ -309,7 +301,7 @@ func (s *server) servePortForward(req *restful.Request, resp *restful.Response)
resp.ResponseWriter,
req.Request,
s.runtime,
podSandboxID,
pf.GetPodSandboxId(),
"", // unused: podUID
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout)

View File

@ -18,12 +18,12 @@ package streaming
import (
"crypto/tls"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"testing"
@ -46,18 +46,18 @@ const (
)
func TestGetExec(t *testing.T) {
testcases := []struct {
type testcase struct {
cmd []string
tty bool
stdin bool
expectedQuery string
}{
{[]string{"echo", "foo"}, false, false, "?command=echo&command=foo&error=1&output=1"},
{[]string{"date"}, true, false, "?command=date&output=1&tty=1"},
{[]string{"date"}, false, true, "?command=date&error=1&input=1&output=1"},
{[]string{"date"}, true, true, "?command=date&input=1&output=1&tty=1"},
}
server, err := NewServer(Config{
testcases := []testcase{
{[]string{"echo", "foo"}, false, false},
{[]string{"date"}, true, false},
{[]string{"date"}, false, true},
{[]string{"date"}, true, true},
}
serv, err := NewServer(Config{
Addr: testAddr,
}, nil)
assert.NoError(t, err)
@ -79,6 +79,14 @@ func TestGetExec(t *testing.T) {
}, nil)
assert.NoError(t, err)
assertRequestToken := func(test testcase, cache *requestCache, token string) {
req, ok := cache.Consume(token)
require.True(t, ok, "token %s not found! testcase=%+v", token, test)
assert.Equal(t, testContainerID, req.(*runtimeapi.ExecRequest).GetContainerId(), "testcase=%+v", test)
assert.Equal(t, test.cmd, req.(*runtimeapi.ExecRequest).GetCmd(), "testcase=%+v", test)
assert.Equal(t, test.tty, req.(*runtimeapi.ExecRequest).GetTty(), "testcase=%+v", test)
assert.Equal(t, test.stdin, req.(*runtimeapi.ExecRequest).GetStdin(), "testcase=%+v", test)
}
containerID := testContainerID
for _, test := range testcases {
request := &runtimeapi.ExecRequest{
@ -87,38 +95,47 @@ func TestGetExec(t *testing.T) {
Tty: &test.tty,
Stdin: &test.stdin,
}
// Non-TLS
resp, err := server.GetExec(request)
{ // Non-TLS
resp, err := serv.GetExec(request)
assert.NoError(t, err, "testcase=%+v", test)
expectedURL := "http://" + testAddr + "/exec/" + testContainerID + test.expectedQuery
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
expectedURL := "http://" + testAddr + "/exec/"
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
assertRequestToken(test, serv.(*server).cache, token)
}
// TLS
resp, err = tlsServer.GetExec(request)
{ // TLS
resp, err := tlsServer.GetExec(request)
assert.NoError(t, err, "testcase=%+v", test)
expectedURL = "https://" + testAddr + "/exec/" + testContainerID + test.expectedQuery
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
expectedURL := "https://" + testAddr + "/exec/"
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
assertRequestToken(test, tlsServer.(*server).cache, token)
}
// Path prefix
resp, err = prefixServer.GetExec(request)
{ // Path prefix
resp, err := prefixServer.GetExec(request)
assert.NoError(t, err, "testcase=%+v", test)
expectedURL = "http://" + testAddr + "/" + pathPrefix + "/exec/" + testContainerID + test.expectedQuery
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/"
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
assertRequestToken(test, prefixServer.(*server).cache, token)
}
}
}
func TestGetAttach(t *testing.T) {
testcases := []struct {
type testcase struct {
tty bool
stdin bool
expectedQuery string
}{
{false, false, "?error=1&output=1"},
{true, false, "?output=1&tty=1"},
{false, true, "?error=1&input=1&output=1"},
{true, true, "?input=1&output=1&tty=1"},
}
server, err := NewServer(Config{
testcases := []testcase{
{false, false},
{true, false},
{false, true},
{true, true},
}
serv, err := NewServer(Config{
Addr: testAddr,
}, nil)
assert.NoError(t, err)
@ -129,6 +146,13 @@ func TestGetAttach(t *testing.T) {
}, nil)
assert.NoError(t, err)
assertRequestToken := func(test testcase, cache *requestCache, token string) {
req, ok := cache.Consume(token)
require.True(t, ok, "token %s not found! testcase=%+v", token, test)
assert.Equal(t, testContainerID, req.(*runtimeapi.AttachRequest).GetContainerId(), "testcase=%+v", test)
assert.Equal(t, test.tty, req.(*runtimeapi.AttachRequest).GetTty(), "testcase=%+v", test)
assert.Equal(t, test.stdin, req.(*runtimeapi.AttachRequest).GetStdin(), "testcase=%+v", test)
}
containerID := testContainerID
for _, test := range testcases {
request := &runtimeapi.AttachRequest{
@ -136,17 +160,23 @@ func TestGetAttach(t *testing.T) {
Stdin: &test.stdin,
Tty: &test.tty,
}
// Non-TLS
resp, err := server.GetAttach(request)
{ // Non-TLS
resp, err := serv.GetAttach(request)
assert.NoError(t, err, "testcase=%+v", test)
expectedURL := "http://" + testAddr + "/attach/" + testContainerID + test.expectedQuery
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
expectedURL := "http://" + testAddr + "/attach/"
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
assertRequestToken(test, serv.(*server).cache, token)
}
// TLS
resp, err = tlsServer.GetAttach(request)
{ // TLS
resp, err := tlsServer.GetAttach(request)
assert.NoError(t, err, "testcase=%+v", test)
expectedURL = "https://" + testAddr + "/attach/" + testContainerID + test.expectedQuery
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
expectedURL := "https://" + testAddr + "/attach/"
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
assertRequestToken(test, tlsServer.(*server).cache, token)
}
}
}
@ -157,26 +187,36 @@ func TestGetPortForward(t *testing.T) {
Port: []int32{1, 2, 3, 4},
}
// Non-TLS
server, err := NewServer(Config{
{ // Non-TLS
serv, err := NewServer(Config{
Addr: testAddr,
}, nil)
assert.NoError(t, err)
resp, err := server.GetPortForward(request)
resp, err := serv.GetPortForward(request)
assert.NoError(t, err)
expectedURL := "http://" + testAddr + "/portforward/" + testPodSandboxID
assert.Equal(t, expectedURL, resp.GetUrl())
expectedURL := "http://" + testAddr + "/portforward/"
assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL))
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
req, ok := serv.(*server).cache.Consume(token)
require.True(t, ok, "token %s not found!", token)
assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId())
}
// TLS
{ // TLS
tlsServer, err := NewServer(Config{
Addr: testAddr,
TLSConfig: &tls.Config{},
}, nil)
assert.NoError(t, err)
resp, err = tlsServer.GetPortForward(request)
resp, err := tlsServer.GetPortForward(request)
assert.NoError(t, err)
expectedURL = "https://" + testAddr + "/portforward/" + testPodSandboxID
assert.Equal(t, expectedURL, resp.GetUrl())
expectedURL := "https://" + testAddr + "/portforward/"
assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL))
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
req, ok := tlsServer.(*server).cache.Consume(token)
require.True(t, ok, "token %s not found!", token)
assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId())
}
}
func TestServeExec(t *testing.T) {
@ -188,21 +228,18 @@ func TestServeAttach(t *testing.T) {
}
func TestServePortForward(t *testing.T) {
rt := newFakeRuntime(t)
s, err := NewServer(DefaultConfig, rt)
require.NoError(t, err)
testServer := httptest.NewServer(s)
s, testServer := startTestServer(t)
defer testServer.Close()
testURL, err := url.Parse(testServer.URL)
podSandboxID := testPodSandboxID
resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{
PodSandboxId: &podSandboxID,
})
require.NoError(t, err)
reqURL, err := url.Parse(resp.GetUrl())
require.NoError(t, err)
loc := &url.URL{
Scheme: testURL.Scheme,
Host: testURL.Host,
}
loc.Path = fmt.Sprintf("/%s/%s", "portforward", testPodSandboxID)
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc)
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL)
require.NoError(t, err)
streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name)
require.NoError(t, err)
@ -227,22 +264,30 @@ func TestServePortForward(t *testing.T) {
// Run the remote command test.
// commandType is either "exec" or "attach".
func runRemoteCommandTest(t *testing.T, commandType string) {
rt := newFakeRuntime(t)
s, err := NewServer(DefaultConfig, rt)
require.NoError(t, err)
testServer := httptest.NewServer(s)
s, testServer := startTestServer(t)
defer testServer.Close()
testURL, err := url.Parse(testServer.URL)
var reqURL *url.URL
stdin := true
containerID := testContainerID
switch commandType {
case "exec":
resp, err := s.GetExec(&runtimeapi.ExecRequest{
ContainerId: &containerID,
Cmd: []string{"echo"},
Stdin: &stdin,
})
require.NoError(t, err)
reqURL, err = url.Parse(resp.GetUrl())
require.NoError(t, err)
case "attach":
resp, err := s.GetAttach(&runtimeapi.AttachRequest{
ContainerId: &containerID,
Stdin: &stdin,
})
require.NoError(t, err)
reqURL, err = url.Parse(resp.GetUrl())
require.NoError(t, err)
query := url.Values{}
query.Add(urlParamStdin, "1")
query.Add(urlParamStdout, "1")
query.Add(urlParamStderr, "1")
loc := &url.URL{
Scheme: testURL.Scheme,
Host: testURL.Host,
RawQuery: query.Encode(),
}
wg := sync.WaitGroup{}
@ -254,8 +299,7 @@ func runRemoteCommandTest(t *testing.T, commandType string) {
go func() {
defer wg.Done()
loc.Path = fmt.Sprintf("/%s/%s", commandType, testContainerID)
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc)
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL)
require.NoError(t, err)
opts := remotecommand.StreamOptions{
@ -275,6 +319,36 @@ func runRemoteCommandTest(t *testing.T, commandType string) {
}()
wg.Wait()
// Repeat request with the same URL should be a 404.
resp, err := http.Get(reqURL.String())
require.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}
func startTestServer(t *testing.T) (Server, *httptest.Server) {
var s Server
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.ServeHTTP(w, r)
}))
cleanup := true
defer func() {
if cleanup {
testServer.Close()
}
}()
testURL, err := url.Parse(testServer.URL)
require.NoError(t, err)
rt := newFakeRuntime(t)
config := DefaultConfig
config.BaseURL = testURL
s, err = NewServer(config, rt)
require.NoError(t, err)
cleanup = false // Caller must close the test server.
return s, testServer
}
const (

View File

@ -391,6 +391,14 @@ var _ = framework.KubeDescribe("Kubectl client", func() {
framework.Failf("Unexpected kubectl exec output. Wanted %q, got %q", e, a)
}
By("executing a very long command in the container")
veryLongData := make([]rune, 20000)
for i := 0; i < len(veryLongData); i++ {
veryLongData[i] = 'a'
}
execOutput = framework.RunKubectlOrDie("exec", fmt.Sprintf("--namespace=%v", ns), simplePodName, "echo", string(veryLongData))
Expect(string(veryLongData)).To(Equal(strings.TrimSpace(execOutput)), "Unexpected kubectl exec output")
By("executing a command in the container with noninteractive stdin")
execOutput = framework.NewKubectlCommand("exec", fmt.Sprintf("--namespace=%v", ns), "-i", simplePodName, "cat").
WithStdinData("abcd1234").