diff --git a/main.go b/main.go index b4fccb4..6b1afab 100644 --- a/main.go +++ b/main.go @@ -72,6 +72,7 @@ func init() { os.Setenv("GOSUMDB", "off") downloadRoot = getDownloadRoot() + os.Setenv("GOMODCACHE", downloadRoot) } func main() { diff --git a/proxy/router.go b/proxy/router.go index fe24361..ca13662 100644 --- a/proxy/router.go +++ b/proxy/router.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/gzip" "crypto/tls" + "fmt" "io/ioutil" "log" "net/http" @@ -96,6 +97,53 @@ func NewRouter(srv *Server, opts *RouterOptions) *Router { } } } + + // support 302 status code. + if r.StatusCode == http.StatusFound { + loc := r.Header.Get("Location") + if loc == "" { + return fmt.Errorf("%d response missing Location header", r.StatusCode) + } + + // TODO: location is relative. + _, err := url.Parse(loc) + if err != nil { + return fmt.Errorf("failed to parse Location header %q: %v", loc, err) + } + resp, err := http.Get(loc) + if err != nil { + return err + } + defer resp.Body.Close() + + var buf []byte + if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") { + gr, err := gzip.NewReader(resp.Body) + if err != nil { + return err + } + defer gr.Close() + buf, err = ioutil.ReadAll(gr) + if err != nil { + return err + } + resp.Header.Del("Content-Encoding") + } else { + buf, err = ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + } + resp.Body = ioutil.NopCloser(bytes.NewReader(buf)) + if buf != nil { + file := filepath.Join(opts.DownloadRoot, r.Request.URL.Path) + os.MkdirAll(path.Dir(file), os.ModePerm) + err = renameio.WriteFile(file, buf, 0666) + if err != nil { + return err + } + } + } return nil } rt.pattern = opts.Pattern