diff --git a/common/buf/io.go b/common/buf/io.go index 3c47e334..e04cebae 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -53,6 +53,22 @@ type copyHandler struct { onWriteError func(error) error } +func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) { + mb, err := reader.Read() + if err != nil && h.onReadError != nil { + err = h.onReadError(err) + } + return mb, err +} + +func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error { + err := writer.Write(mb) + if err != nil && h.onWriteError != nil { + err = h.onWriteError(err) + } + return err +} + type CopyOption func(*copyHandler) func IgnoreReaderError() CopyOption { @@ -79,27 +95,25 @@ func UpdateActivity(timer signal.ActivityTimer) CopyOption { } } -func copyInternal(reader Reader, writer Writer, handler copyHandler) error { +func copyInternal(reader Reader, writer Writer, handler *copyHandler) error { for { - buffer, err := reader.Read() + buffer, err := handler.readFrom(reader) if err != nil { - if err = handler.onReadError(err); err != nil { - return err - } + return err } - handler.onData() - if buffer.IsEmpty() { buffer.Release() continue } - if err := writer.Write(buffer); err != nil { - if err = handler.onWriteError(err); err != nil { - buffer.Release() - return err - } + if handler.onData != nil { + handler.onData() + } + + if err := handler.writeTo(writer, buffer); err != nil { + buffer.Release() + return err } } } @@ -107,9 +121,9 @@ func copyInternal(reader Reader, writer Writer, handler copyHandler) error { // Copy dumps all payload from reader to writer or stops when an error occurs. // ActivityTimer gets updated as soon as there is a payload. func Copy(reader Reader, writer Writer, options ...CopyOption) error { - handler := copyHandler{} + handler := new(copyHandler) for _, option := range options { - option(&handler) + option(handler) } err := copyInternal(reader, writer, handler) if err != nil && errors.Cause(err) != io.EOF {