From 6ea355b5b957efe64d5f996f13bad0a2f4eb8e73 Mon Sep 17 00:00:00 2001 From: kun Date: Mon, 3 Sep 2018 12:55:23 +0800 Subject: [PATCH] support redirect to correct version URL --- main.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 14 deletions(-) diff --git a/main.go b/main.go index 0df66cf..a33de5f 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,9 @@ package main import ( - "bytes" "flag" "fmt" + "io/ioutil" "net/http" "os" "os/exec" @@ -37,13 +37,18 @@ func main() { func mainHandler(inner http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := os.Stat(filepath.Join(cacheDir, r.URL.Path)); err != nil { - if strings.HasSuffix(r.URL.Path, ".info") { + var suffix string + if strings.HasSuffix(r.URL.Path, ".info") || strings.HasSuffix(r.URL.Path, ".mod") { + suffix = ".mod" + if strings.HasSuffix(r.URL.Path, ".info") { + suffix = ".info" + } mod := strings.Split(r.URL.Path, "/@v/") if len(mod) != 2 { ReturnServerError(w, fmt.Errorf("bad module path:%s", r.URL.Path)) return } - version := strings.TrimSuffix(mod[1], ".info") + version := strings.TrimSuffix(mod[1], suffix) version, err = module.DecodeVersion(version) if err != nil { ReturnServerError(w, err) @@ -55,9 +60,9 @@ func mainHandler(inner http.Handler) http.Handler { ReturnServerError(w, err) return } - stdout, stderr, err := goGet(path + "@" + version) + err = goGet(path, version, suffix, w, r) if err != nil { - ReturnServerError(w, fmt.Errorf("stdout: %s stderr: %s", stdout, stderr)) + ReturnServerError(w, err) return } } @@ -71,13 +76,55 @@ func mainHandler(inner http.Handler) http.Handler { }) } -func goGet(path string) (string, string, error) { - fmt.Fprintf(os.Stdout, "goproxy: download %s\n", path) - cmd := exec.Command("go", "get", "-d", path) - var stdout bytes.Buffer - var stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - err := cmd.Run() - return string(stdout.Bytes()), string(stderr.Bytes()), err +func goGet(path, version, suffix string, w http.ResponseWriter, r *http.Request) error { + cmd := exec.Command("go", "get", "-d", path+"@"+version) + stdout, err := cmd.StdoutPipe() + if err != nil { + return err + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return err + } + + if err := cmd.Start(); err != nil { + return err + } + + bytesErr, err := ioutil.ReadAll(stderr) + if err != nil { + return err + } + + bytesOut, err := ioutil.ReadAll(stdout) + if err != nil { + return err + } + + if err := cmd.Wait(); err != nil { + return err + } + out := fmt.Sprintf("%s", bytesErr) + + for _, line := range strings.Split(out, "\n") { + f := strings.Fields(line) + if len(f) != 4 { + continue + } + if f[1] == "downloading" && f[2] == path && f[3] != version { + h := r.Host + mod := strings.Split(r.URL.Path, "/@v/") + p := fmt.Sprintf("%s/@v/%s%s", mod[0], f[3], suffix) + scheme := "http:" + if r.TLS != nil { + scheme = "https:" + } + url := fmt.Sprintf("%s//%s/%s", scheme, h, p) + http.Redirect(w, r, url, 302) + } + } + + fmt.Fprintf(os.Stdout, "goproxy: download %s stdout: %s stderr: %s\n", path, string(bytesOut), string(bytesErr)) + return nil }