diff --git a/common/buf/multi_buffer.go b/common/buf/multi_buffer.go index a245123c..48cc185f 100644 --- a/common/buf/multi_buffer.go +++ b/common/buf/multi_buffer.go @@ -30,6 +30,26 @@ func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) { } } +// ReadSizeToMultiBuffer reads specific number of bytes from reader into a MultiBuffer. +func ReadSizeToMultiBuffer(reader io.Reader, size int32) (MultiBuffer, error) { + mb := NewMultiBufferCap(32) + + for size > 0 { + bSize := size + if bSize > Size { + bSize = Size + } + b := NewSize(uint32(bSize)) + if err := b.Reset(ReadFullFrom(reader, int(bSize))); err != nil { + mb.Release() + return nil, err + } + size -= bSize + mb.Append(b) + } + return mb, nil +} + // ReadAllToBytes reads all content from the reader into a byte array, until EOF. func ReadAllToBytes(reader io.Reader) ([]byte, error) { mb, err := ReadAllToMultiBuffer(reader) diff --git a/proxy/http/server.go b/proxy/http/server.go index 2e939803..ec30bd11 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -173,11 +173,11 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade } if reader.Buffered() > 0 { - payload := buf.NewSize(uint32(reader.Buffered())) - common.Must(payload.Reset(func(b []byte) (int, error) { - return reader.Read(b[:reader.Buffered()]) - })) - if err := ray.InboundInput().WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil { + payload, err := buf.ReadSizeToMultiBuffer(reader, int32(reader.Buffered())) + if err != nil { + return err + } + if err := ray.InboundInput().WriteMultiBuffer(payload); err != nil { return err } reader = nil