diff --git a/main.go b/main.go index 87f73b6..41e3fd7 100644 --- a/main.go +++ b/main.go @@ -26,9 +26,11 @@ import ( "net/http" "os" "os/exec" + "os/signal" "path" "path/filepath" "strings" + "syscall" "time" "github.com/goproxyio/goproxy/proxy" @@ -68,22 +70,6 @@ func init() { func main() { log.SetPrefix("goproxy.io: ") log.SetFlags(0) - // TODO flags - var env struct { - GOPATH string - } - if cacheDir != "" { - downloadRoot = filepath.Join(cacheDir, "pkg/mod/cache/download") - os.Setenv("GOPATH", cacheDir) - } - if err := goJSON(&env, "go", "env", "-json", "GOPATH"); err != nil { - log.Fatal(err) - } - list := filepath.SplitList(env.GOPATH) - if len(list) == 0 || list[0] == "" { - log.Fatalf("missing $GOPATH") - } - downloadRoot = filepath.Join(list[0], "pkg/mod/cache/download") var handle http.Handler if proxyHost != "" { @@ -94,12 +80,51 @@ func main() { handle = &logger{proxy.NewRouter(proxy.NewServer(new(ops)), &proxy.RouterOptions{ Pattern: excludeHost, Proxy: proxyHost, - DownloadRoot: downloadRoot, + DownloadRoot: getDownloadRoot(), })} } else { handle = &logger{proxy.NewServer(new(ops))} } - log.Fatal(http.ListenAndServe(listen, handle)) + + server := &http.Server{Addr: listen, Handler: handle} + go func() { + if err := server.ListenAndServe(); err != nil { + if err != http.ErrServerClosed { + log.Fatal(err) + } + } + }() + + s := make(chan os.Signal, 1) + signal.Notify(s, os.Interrupt, syscall.SIGTERM, syscall.SIGTERM) + <-s + log.Println("Making a graceful shutdown...") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + err := server.Shutdown(ctx) + if err != nil { + log.Fatalf("Error while shutting down the server: %v", err) + } + log.Println("Successful server shutdown.") +} + +func getDownloadRoot() string { + var env struct { + GOPATH string + } + if cacheDir != "" { + downloadRoot = filepath.Join(cacheDir, "pkg/mod/cache/download") + os.Setenv("GOPATH", cacheDir) + return downloadRoot + } + if err := goJSON(&env, "go", "env", "-json", "GOPATH"); err != nil { + log.Fatal(err) + } + list := filepath.SplitList(env.GOPATH) + if len(list) == 0 || list[0] == "" { + log.Fatalf("missing $GOPATH") + } + return filepath.Join(list[0], "pkg/mod/cache/download") } // goJSON runs the go command and parses its JSON output into dst.