diff --git a/server/webdav/buffered_response_writer.go b/server/webdav/buffered_response_writer.go new file mode 100644 index 00000000..ed653eae --- /dev/null +++ b/server/webdav/buffered_response_writer.go @@ -0,0 +1,46 @@ +package webdav + +import ( + "net/http" +) + +type bufferedResponseWriter struct { + statusCode int + data []byte + header http.Header +} + +func (w *bufferedResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *bufferedResponseWriter) Write(bytes []byte) (int, error) { + w.data = append(w.data, bytes...) + return len(bytes), nil +} + +func (w *bufferedResponseWriter) WriteHeader(statusCode int) { + if w.statusCode == 0 { + w.statusCode = statusCode + } +} + +func (w *bufferedResponseWriter) WriteToResponse(rw http.ResponseWriter) (int, error) { + h := rw.Header() + for k, vs := range w.header { + for _, v := range vs { + h.Add(k, v) + } + } + rw.WriteHeader(w.statusCode) + return rw.Write(w.data) +} + +func newBufferedResponseWriter() *bufferedResponseWriter { + return &bufferedResponseWriter{ + statusCode: 0, + } +} diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index ac9d975a..7ab8dbf5 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -45,30 +45,33 @@ func (h *Handler) stripPrefix(p string) (string, int, error) { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { status, err := http.StatusBadRequest, errUnsupportedMethod + brw := newBufferedResponseWriter() + useBufferedWriter := true if h.LockSystem == nil { status, err = http.StatusInternalServerError, errNoLockSystem } else { switch r.Method { case "OPTIONS": - status, err = h.handleOptions(w, r) + status, err = h.handleOptions(brw, r) case "GET", "HEAD", "POST": + useBufferedWriter = false status, err = h.handleGetHeadPost(w, r) case "DELETE": - status, err = h.handleDelete(w, r) + status, err = h.handleDelete(brw, r) case "PUT": - status, err = h.handlePut(w, r) + status, err = h.handlePut(brw, r) case "MKCOL": - status, err = h.handleMkcol(w, r) + status, err = h.handleMkcol(brw, r) case "COPY", "MOVE": - status, err = h.handleCopyMove(w, r) + status, err = h.handleCopyMove(brw, r) case "LOCK": - status, err = h.handleLock(w, r) + status, err = h.handleLock(brw, r) case "UNLOCK": - status, err = h.handleUnlock(w, r) + status, err = h.handleUnlock(brw, r) case "PROPFIND": - status, err = h.handlePropfind(w, r) + status, err = h.handlePropfind(brw, r) case "PROPPATCH": - status, err = h.handleProppatch(w, r) + status, err = h.handleProppatch(brw, r) } } @@ -77,6 +80,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if status != http.StatusNoContent { w.Write([]byte(StatusText(status))) } + } else if useBufferedWriter { + brw.WriteToResponse(w) } if h.Logger != nil && err != nil { h.Logger(r, err)