|
|
|
@ -3,6 +3,7 @@ package http
|
|
|
|
|
//go:generate errorgen
|
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"bufio" |
|
|
|
|
"bytes" |
|
|
|
|
"context" |
|
|
|
|
"io" |
|
|
|
@ -28,6 +29,8 @@ const (
|
|
|
|
|
|
|
|
|
|
var ( |
|
|
|
|
ErrHeaderToLong = newError("Header too long.") |
|
|
|
|
|
|
|
|
|
ErrHeaderMisMatch = newError("Header Mismatch.") |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
type Reader interface { |
|
|
|
@ -51,12 +54,22 @@ func (NoOpWriter) Write(io.Writer) error {
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type HeaderReader struct { |
|
|
|
|
req *http.Request |
|
|
|
|
expectedHeader *RequestConfig |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (h *HeaderReader) ExpectThisRequest(expectedHeader *RequestConfig) *HeaderReader { |
|
|
|
|
h.expectedHeader = expectedHeader |
|
|
|
|
return h |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) { |
|
|
|
|
func (h *HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) { |
|
|
|
|
buffer := buf.New() |
|
|
|
|
totalBytes := int32(0) |
|
|
|
|
endingDetected := false |
|
|
|
|
|
|
|
|
|
var headerBuf bytes.Buffer |
|
|
|
|
|
|
|
|
|
for totalBytes < maxHeaderLength { |
|
|
|
|
_, err := buffer.ReadFrom(reader) |
|
|
|
|
if err != nil { |
|
|
|
@ -64,6 +77,7 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
|
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
if n := bytes.Index(buffer.Bytes(), []byte(ENDING)); n != -1 { |
|
|
|
|
headerBuf.Write(buffer.BytesRange(0, int32(n+len(ENDING)))) |
|
|
|
|
buffer.Advance(int32(n + len(ENDING))) |
|
|
|
|
endingDetected = true |
|
|
|
|
break |
|
|
|
@ -71,19 +85,52 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
|
|
|
|
|
lenEnding := int32(len(ENDING)) |
|
|
|
|
if buffer.Len() >= lenEnding { |
|
|
|
|
totalBytes += buffer.Len() - lenEnding |
|
|
|
|
headerBuf.Write(buffer.BytesRange(0, buffer.Len()-lenEnding)) |
|
|
|
|
leftover := buffer.BytesFrom(-lenEnding) |
|
|
|
|
buffer.Clear() |
|
|
|
|
copy(buffer.Extend(lenEnding), leftover) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if buffer.IsEmpty() { |
|
|
|
|
buffer.Release() |
|
|
|
|
return nil, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if !endingDetected { |
|
|
|
|
buffer.Release() |
|
|
|
|
return nil, ErrHeaderToLong |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if h.expectedHeader == nil { |
|
|
|
|
if buffer.IsEmpty() { |
|
|
|
|
buffer.Release() |
|
|
|
|
return nil, nil |
|
|
|
|
} |
|
|
|
|
return buffer, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
//Parse the request
|
|
|
|
|
|
|
|
|
|
if req, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes())), false); err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} else { |
|
|
|
|
h.req = req |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
//Check req
|
|
|
|
|
path := h.req.URL.Path |
|
|
|
|
hasThisUri := false |
|
|
|
|
for _, u := range h.expectedHeader.Uri { |
|
|
|
|
if u == path { |
|
|
|
|
hasThisUri = true |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if hasThisUri == false { |
|
|
|
|
return nil, ErrHeaderMisMatch |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if buffer.IsEmpty() { |
|
|
|
|
buffer.Release() |
|
|
|
|
return nil, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return buffer, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -110,18 +157,24 @@ func (w *HeaderWriter) Write(writer io.Writer) error {
|
|
|
|
|
type HttpConn struct { |
|
|
|
|
net.Conn |
|
|
|
|
|
|
|
|
|
readBuffer *buf.Buffer |
|
|
|
|
oneTimeReader Reader |
|
|
|
|
oneTimeWriter Writer |
|
|
|
|
errorWriter Writer |
|
|
|
|
readBuffer *buf.Buffer |
|
|
|
|
oneTimeReader Reader |
|
|
|
|
oneTimeWriter Writer |
|
|
|
|
errorWriter Writer |
|
|
|
|
errorMismatchWriter Writer |
|
|
|
|
errorTooLongWriter Writer |
|
|
|
|
|
|
|
|
|
errReason error |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func NewHttpConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer) *HttpConn { |
|
|
|
|
func NewHttpConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer, errorMismatchWriter Writer, errorTooLongWriter Writer) *HttpConn { |
|
|
|
|
return &HttpConn{ |
|
|
|
|
Conn: conn, |
|
|
|
|
oneTimeReader: reader, |
|
|
|
|
oneTimeWriter: writer, |
|
|
|
|
errorWriter: errorWriter, |
|
|
|
|
Conn: conn, |
|
|
|
|
oneTimeReader: reader, |
|
|
|
|
oneTimeWriter: writer, |
|
|
|
|
errorWriter: errorWriter, |
|
|
|
|
errorMismatchWriter: errorMismatchWriter, |
|
|
|
|
errorTooLongWriter: errorTooLongWriter, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -129,6 +182,7 @@ func (c *HttpConn) Read(b []byte) (int, error) {
|
|
|
|
|
if c.oneTimeReader != nil { |
|
|
|
|
buffer, err := c.oneTimeReader.Read(c.Conn) |
|
|
|
|
if err != nil { |
|
|
|
|
c.errReason = err |
|
|
|
|
return 0, err |
|
|
|
|
} |
|
|
|
|
c.readBuffer = buffer |
|
|
|
@ -165,7 +219,16 @@ func (c *HttpConn) Close() error {
|
|
|
|
|
if c.oneTimeWriter != nil && c.errorWriter != nil { |
|
|
|
|
// Connection is being closed but header wasn't sent. This means the client request
|
|
|
|
|
// is probably not valid. Sending back a server error header in this case.
|
|
|
|
|
c.errorWriter.Write(c.Conn) |
|
|
|
|
|
|
|
|
|
//Write response based on error reason
|
|
|
|
|
|
|
|
|
|
if c.errReason == ErrHeaderMisMatch { |
|
|
|
|
c.errorMismatchWriter.Write(c.Conn) |
|
|
|
|
} else if c.errReason == ErrHeaderToLong { |
|
|
|
|
c.errorTooLongWriter.Write(c.Conn) |
|
|
|
|
} else { |
|
|
|
|
c.errorWriter.Write(c.Conn) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return c.Conn.Close() |
|
|
|
@ -230,36 +293,17 @@ func (a HttpAuthenticator) Client(conn net.Conn) net.Conn {
|
|
|
|
|
if a.config.Response != nil { |
|
|
|
|
writer = a.GetClientWriter() |
|
|
|
|
} |
|
|
|
|
return NewHttpConn(conn, reader, writer, NoOpWriter{}) |
|
|
|
|
return NewHttpConn(conn, reader, writer, NoOpWriter{}, NoOpWriter{}, NoOpWriter{}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (a HttpAuthenticator) Server(conn net.Conn) net.Conn { |
|
|
|
|
if a.config.Request == nil && a.config.Response == nil { |
|
|
|
|
return conn |
|
|
|
|
} |
|
|
|
|
return NewHttpConn(conn, new(HeaderReader), a.GetServerWriter(), formResponseHeader(&ResponseConfig{ |
|
|
|
|
Version: &Version{ |
|
|
|
|
Value: "1.1", |
|
|
|
|
}, |
|
|
|
|
Status: &Status{ |
|
|
|
|
Code: "500", |
|
|
|
|
Reason: "Internal Server Error", |
|
|
|
|
}, |
|
|
|
|
Header: []*Header{ |
|
|
|
|
{ |
|
|
|
|
Name: "Connection", |
|
|
|
|
Value: []string{"close"}, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
Name: "Cache-Control", |
|
|
|
|
Value: []string{"private"}, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
Name: "Content-Length", |
|
|
|
|
Value: []string{"0"}, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
})) |
|
|
|
|
return NewHttpConn(conn, new(HeaderReader).ExpectThisRequest(a.config.Request), a.GetServerWriter(), |
|
|
|
|
formResponseHeader(resp400), |
|
|
|
|
formResponseHeader(resp404), |
|
|
|
|
formResponseHeader(resp400)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func NewHttpAuthenticator(ctx context.Context, config *Config) (HttpAuthenticator, error) { |
|
|
|
|