mirror of https://github.com/k3s-io/k3s
210 lines
6.5 KiB
Go
210 lines
6.5 KiB
Go
|
package runtime
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net/http"
|
||
|
"net/textproto"
|
||
|
|
||
|
"context"
|
||
|
"github.com/golang/protobuf/proto"
|
||
|
"github.com/grpc-ecosystem/grpc-gateway/internal"
|
||
|
"google.golang.org/grpc/grpclog"
|
||
|
)
|
||
|
|
||
|
var (
	// errEmptyResponse is the error handed to the stream error handler when a
	// streamed message is nil, so the client receives an error chunk instead of
	// an empty result chunk.
	errEmptyResponse = errors.New("empty response")
)
|
||
|
|
||
|
// ForwardResponseStream forwards the stream from gRPC server to REST client.
|
||
|
func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
|
||
|
f, ok := w.(http.Flusher)
|
||
|
if !ok {
|
||
|
grpclog.Infof("Flush not supported in %T", w)
|
||
|
http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
md, ok := ServerMetadataFromContext(ctx)
|
||
|
if !ok {
|
||
|
grpclog.Infof("Failed to extract ServerMetadata from context")
|
||
|
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
handleForwardResponseServerMetadata(w, mux, md)
|
||
|
|
||
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||
|
w.Header().Set("Content-Type", marshaler.ContentType())
|
||
|
if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
|
||
|
HTTPError(ctx, mux, marshaler, w, req, err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var delimiter []byte
|
||
|
if d, ok := marshaler.(Delimited); ok {
|
||
|
delimiter = d.Delimiter()
|
||
|
} else {
|
||
|
delimiter = []byte("\n")
|
||
|
}
|
||
|
|
||
|
var wroteHeader bool
|
||
|
for {
|
||
|
resp, err := recv()
|
||
|
if err == io.EOF {
|
||
|
return
|
||
|
}
|
||
|
if err != nil {
|
||
|
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
|
||
|
return
|
||
|
}
|
||
|
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
|
||
|
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
buf, err := marshaler.Marshal(streamChunk(ctx, resp, mux.streamErrorHandler))
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Failed to marshal response chunk: %v", err)
|
||
|
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
|
||
|
return
|
||
|
}
|
||
|
if _, err = w.Write(buf); err != nil {
|
||
|
grpclog.Infof("Failed to send response chunk: %v", err)
|
||
|
return
|
||
|
}
|
||
|
wroteHeader = true
|
||
|
if _, err = w.Write(delimiter); err != nil {
|
||
|
grpclog.Infof("Failed to send delimiter chunk: %v", err)
|
||
|
return
|
||
|
}
|
||
|
f.Flush()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
|
||
|
for k, vs := range md.HeaderMD {
|
||
|
if h, ok := mux.outgoingHeaderMatcher(k); ok {
|
||
|
for _, v := range vs {
|
||
|
w.Header().Add(h, v)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
|
||
|
for k := range md.TrailerMD {
|
||
|
tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
|
||
|
w.Header().Add("Trailer", tKey)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
|
||
|
for k, vs := range md.TrailerMD {
|
||
|
tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
|
||
|
for _, v := range vs {
|
||
|
w.Header().Add(tKey, v)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// responseBody is implemented by response structs generated with a
// `response_body` value in their `google.api.HttpRule`. XXX_ResponseBody returns
// the field to marshal as the HTTP response body in place of the whole message.
type responseBody interface{ XXX_ResponseBody() interface{} }
|
||
|
|
||
|
// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
|
||
|
func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
|
||
|
md, ok := ServerMetadataFromContext(ctx)
|
||
|
if !ok {
|
||
|
grpclog.Infof("Failed to extract ServerMetadata from context")
|
||
|
}
|
||
|
|
||
|
handleForwardResponseServerMetadata(w, mux, md)
|
||
|
handleForwardResponseTrailerHeader(w, md)
|
||
|
|
||
|
contentType := marshaler.ContentType()
|
||
|
// Check marshaler on run time in order to keep backwards compatability
|
||
|
// An interface param needs to be added to the ContentType() function on
|
||
|
// the Marshal interface to be able to remove this check
|
||
|
if httpBodyMarshaler, ok := marshaler.(*HTTPBodyMarshaler); ok {
|
||
|
contentType = httpBodyMarshaler.ContentTypeFromMessage(resp)
|
||
|
}
|
||
|
w.Header().Set("Content-Type", contentType)
|
||
|
|
||
|
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
|
||
|
HTTPError(ctx, mux, marshaler, w, req, err)
|
||
|
return
|
||
|
}
|
||
|
var buf []byte
|
||
|
var err error
|
||
|
if rb, ok := resp.(responseBody); ok {
|
||
|
buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
|
||
|
} else {
|
||
|
buf, err = marshaler.Marshal(resp)
|
||
|
}
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Marshal error: %v", err)
|
||
|
HTTPError(ctx, mux, marshaler, w, req, err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if _, err = w.Write(buf); err != nil {
|
||
|
grpclog.Infof("Failed to write response: %v", err)
|
||
|
}
|
||
|
|
||
|
handleForwardResponseTrailer(w, md)
|
||
|
}
|
||
|
|
||
|
func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
|
||
|
if len(opts) == 0 {
|
||
|
return nil
|
||
|
}
|
||
|
for _, opt := range opts {
|
||
|
if err := opt(ctx, w, resp); err != nil {
|
||
|
grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
|
||
|
serr := streamError(ctx, mux.streamErrorHandler, err)
|
||
|
if !wroteHeader {
|
||
|
w.WriteHeader(int(serr.HttpCode))
|
||
|
}
|
||
|
buf, merr := marshaler.Marshal(errorChunk(serr))
|
||
|
if merr != nil {
|
||
|
grpclog.Infof("Failed to marshal an error: %v", merr)
|
||
|
return
|
||
|
}
|
||
|
if _, werr := w.Write(buf); werr != nil {
|
||
|
grpclog.Infof("Failed to notify error to client: %v", werr)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// streamChunk returns a chunk in a response stream for the given result. The
|
||
|
// given errHandler is used to render an error chunk if result is nil.
|
||
|
func streamChunk(ctx context.Context, result proto.Message, errHandler StreamErrorHandlerFunc) map[string]proto.Message {
|
||
|
if result == nil {
|
||
|
return errorChunk(streamError(ctx, errHandler, errEmptyResponse))
|
||
|
}
|
||
|
return map[string]proto.Message{"result": result}
|
||
|
}
|
||
|
|
||
|
// streamError returns the payload for the final message in a response stream
|
||
|
// that represents the given err.
|
||
|
func streamError(ctx context.Context, errHandler StreamErrorHandlerFunc, err error) *StreamError {
|
||
|
serr := errHandler(ctx, err)
|
||
|
if serr != nil {
|
||
|
return serr
|
||
|
}
|
||
|
// TODO: log about misbehaving stream error handler?
|
||
|
return DefaultHTTPStreamErrorHandler(ctx, err)
|
||
|
}
|
||
|
|
||
|
func errorChunk(err *StreamError) map[string]proto.Message {
|
||
|
return map[string]proto.Message{"error": (*internal.StreamError)(err)}
|
||
|
}
|