package agent

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"testing"
	"time"

	"github.com/hashicorp/consul/consul/structs"
	"github.com/hashicorp/consul/testutil"
)

// makeHTTPServer creates a test agent using the unmodified configuration from
// nextConfig and returns the directory handed back by makeAgent together with
// the first HTTP server created for that agent.
func makeHTTPServer(t *testing.T) (string, *HTTPServer) {
	return makeHTTPServerWithConfig(t, nil)
}

// makeHTTPServerWithConfig is like makeHTTPServer but lets the caller adjust
// the configuration through cb (which may be nil) before the agent is created.
// A "ui" sub-directory is created inside the agent directory and set as
// conf.UiDir. The test is failed immediately if any setup step errors or if no
// HTTP server is produced. Callers are expected to remove the returned
// directory and shut down the server and its agent.
func makeHTTPServerWithConfig(t *testing.T, cb func(c *Config)) (string, *HTTPServer) {
	conf := nextConfig()
	if cb != nil {
		cb(conf)
	}

	dir, agent := makeAgent(t, conf)

	uiDir := filepath.Join(dir, "ui")
	// The mode must be an octal literal: decimal 755 would be 01363, which
	// sets the sticky bit and strips the owner's read permission.
	if err := os.Mkdir(uiDir, 0755); err != nil {
		t.Fatalf("err: %v", err)
	}
	conf.UiDir = uiDir

	servers, err := NewHTTPServers(agent, conf, agent.logOutput)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if len(servers) == 0 {
		t.Fatalf("Failed to make HTTP server")
	}
	return dir, servers[0]
}

// encodeReq JSON-encodes obj and returns the result as a request body.
// It panics if obj cannot be encoded, because silently returning an empty
// body would surface later as a confusing, unrelated test failure.
func encodeReq(obj interface{}) io.ReadCloser {
	buf := bytes.NewBuffer(nil)
	if err := json.NewEncoder(buf).Encode(obj); err != nil {
		panic(fmt.Sprintf("encodeReq: failed to encode %T: %v", obj, err))
	}
	return ioutil.NopCloser(buf)
}

// TestHTTPServer_UnixSocket checks that an HTTP server configured with a
// "unix://" address creates the socket file and answers requests sent over it.
// The test is skipped on Windows.
func TestHTTPServer_UnixSocket(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.SkipNow()
	}

	tempDir, err := ioutil.TempDir("", "consul")
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer os.RemoveAll(tempDir)
	socket := filepath.Join(tempDir, "test.sock")

	dir, srv := makeHTTPServerWithConfig(t, func(c *Config) {
		c.Addresses.HTTP = "unix://" + socket
	})
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()

	// Ensure the socket was created
	if _, err := os.Stat(socket); err != nil {
		t.Fatalf("err: %s", err)
	}

	// Ensure we can get a response from the socket.
	path, _ := unixSocketAddr(srv.agent.config.Addresses.HTTP)
	client := &http.Client{
		Transport: &http.Transport{
			// Ignore the requested network/address and always dial the socket.
			Dial: func(_, _ string) (net.Conn, error) {
				return net.Dial("unix", path)
			},
		},
	}

	// This URL doesn't look like it makes sense, but the scheme (http://) and
	// the host (127.0.0.1) are required by the HTTP client library. In reality
	// this will just use the custom dialer and talk to the socket.
	resp, err := client.Get("http://127.0.0.1/v1/agent/self")
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer resp.Body.Close()

	if body, err := ioutil.ReadAll(resp.Body); err != nil || len(body) == 0 {
		t.Fatalf("bad: %s %v", body, err)
	}
}

// TestHTTPServer_UnixSocket_FileExists checks that NewHTTPServers returns an
// error when the configured socket path is already occupied by a regular file.
// The test is skipped on Windows.
func TestHTTPServer_UnixSocket_FileExists(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.SkipNow()
	}

	tempDir, err := ioutil.TempDir("", "consul")
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer os.RemoveAll(tempDir)
	socket := filepath.Join(tempDir, "test.sock")

	// Create a regular file at the socket path
	if err := ioutil.WriteFile(socket, []byte("hello world"), 0644); err != nil {
		t.Fatalf("err: %s", err)
	}
	fi, err := os.Stat(socket)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	if !fi.Mode().IsRegular() {
		t.Fatalf("not a regular file: %s", socket)
	}

	conf := nextConfig()
	conf.Addresses.HTTP = "unix://" + socket

	dir, agent := makeAgent(t, conf)
	defer os.RemoveAll(dir)
	// Shut the agent down like every other test in this file does, so it is
	// not left running after this test finishes.
	defer agent.Shutdown()

	// Try to start the server with the same path anyways.
	if servers, err := NewHTTPServers(agent, conf, agent.logOutput); err == nil {
		// Unexpected success: stop whatever was started before failing.
		for _, server := range servers {
			server.Shutdown()
		}
		t.Fatalf("expected socket binding error")
	}
}

// TestSetIndex checks that setIndex writes X-Consul-Index and that calling it
// twice leaves a single header value rather than appending a second one.
func TestSetIndex(t *testing.T) {
	resp := httptest.NewRecorder()
	setIndex(resp, 1000)
	header := resp.Header().Get("X-Consul-Index")
	if header != "1000" {
		t.Fatalf("Bad: %v", header)
	}

	setIndex(resp, 2000)
	if v := resp.Header()["X-Consul-Index"]; len(v) != 1 {
		t.Fatalf("bad: %#v", v)
	}
}

// TestSetKnownLeader checks that setKnownLeader writes "true" or "false" to
// the X-Consul-KnownLeader header.
func TestSetKnownLeader(t *testing.T) {
	resp := httptest.NewRecorder()
	setKnownLeader(resp, true)
	header := resp.Header().Get("X-Consul-KnownLeader")
	if header != "true" {
		t.Fatalf("Bad: %v", header)
	}

	resp = httptest.NewRecorder()
	setKnownLeader(resp, false)
	header = resp.Header().Get("X-Consul-KnownLeader")
	if header != "false" {
		t.Fatalf("Bad: %v", header)
	}
}

// TestSetLastContact checks that a last-contact duration of 123456
// microseconds is written to X-Consul-LastContact as "123".
func TestSetLastContact(t *testing.T) {
	resp := httptest.NewRecorder()
	setLastContact(resp, 123456*time.Microsecond)
	header := resp.Header().Get("X-Consul-LastContact")
	if header != "123" {
		t.Fatalf("Bad: %v", header)
	}
}

// TestSetMeta checks that setMeta writes the index, known-leader and
// last-contact headers from a structs.QueryMeta.
func TestSetMeta(t *testing.T) {
	meta := structs.QueryMeta{
		Index:       1000,
		KnownLeader: true,
		LastContact: 123456 * time.Microsecond,
	}
	resp := httptest.NewRecorder()
	setMeta(resp, &meta)

	header := resp.Header().Get("X-Consul-Index")
	if header != "1000" {
		t.Fatalf("Bad: %v", header)
	}
	header = resp.Header().Get("X-Consul-KnownLeader")
	if header != "true" {
		t.Fatalf("Bad: %v", header)
	}
	header = resp.Header().Get("X-Consul-LastContact")
	if header != "123" {
		t.Fatalf("Bad: %v", header)
	}
}

// TestHTTPAPIResponseHeaders checks that headers configured in
// HTTPAPIResponseHeaders are present on a response produced by a handler
// passed through srv.wrap.
func TestHTTPAPIResponseHeaders(t *testing.T) {
	dir, srv := makeHTTPServer(t)
	srv.agent.config.HTTPAPIResponseHeaders = map[string]string{
		"Access-Control-Allow-Origin": "*",
		"X-XSS-Protection":            "1; mode=block",
	}
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()

	resp := httptest.NewRecorder()
	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
		return nil, nil
	}

	req, err := http.NewRequest("GET", "/v1/agent/self", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	srv.wrap(handler)(resp, req)

	origin := resp.Header().Get("Access-Control-Allow-Origin")
	if origin != "*" {
		t.Fatalf("bad Access-Control-Allow-Origin: expected %q, got %q", "*", origin)
	}

	xss := resp.Header().Get("X-XSS-Protection")
	if xss != "1; mode=block" {
		t.Fatalf("bad X-XSS-Protection header: expected %q, got %q", "1; mode=block", xss)
	}
}

// TestContentTypeIsJSON checks that a wrapped handler returning a value gets
// a Content-Type of "application/json" on its response.
func TestContentTypeIsJSON(t *testing.T) {
	dir, srv := makeHTTPServer(t)
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()

	resp := httptest.NewRecorder()
	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
		// stub out a DirEntry so that it will be encoded as JSON
		return &structs.DirEntry{Key: "key"}, nil
	}

	req, err := http.NewRequest("GET", "/v1/kv/key", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	srv.wrap(handler)(resp, req)

	contentType := resp.Header().Get("Content-Type")
	if contentType != "application/json" {
		t.Fatalf("Content-Type header was not 'application/json'")
	}
}

// TestPrettyPrint exercises pretty-printing with the "pretty=1" query form.
func TestPrettyPrint(t *testing.T) {
	testPrettyPrint("pretty=1", t)
}

// TestPrettyPrintBare exercises pretty-printing with a bare "pretty" query
// parameter that carries no value.
func TestPrettyPrintBare(t *testing.T) {
	testPrettyPrint("pretty", t)
}

// testPrettyPrint issues a request whose query string is the given pretty
// parameter and checks that the wrapped handler's response body equals the
// json.MarshalIndent encoding of the returned value.
func testPrettyPrint(pretty string, t *testing.T) {
	dir, srv := makeHTTPServer(t)
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()

	r := &structs.DirEntry{Key: "key"}

	resp := httptest.NewRecorder()
	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
		return r, nil
	}

	urlStr := "/v1/kv/key?" + pretty
	req, err := http.NewRequest("GET", urlStr, nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	srv.wrap(handler)(resp, req)

	// NOTE(review): the indent string here has to match whatever the server
	// uses for pretty output; confirm against the wrap implementation.
	expected, err := json.MarshalIndent(r, "", " ")
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	actual, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	if !bytes.Equal(expected, actual) {
		t.Fatalf("bad: %q", string(actual))
	}
}

// TestParseWait checks that parseWait reads the "wait" and "index" query
// parameters into MaxQueryTime and MinQueryIndex without reporting done.
func TestParseWait(t *testing.T) {
	resp := httptest.NewRecorder()
	var b structs.QueryOptions

	req, err := http.NewRequest("GET", "/v1/catalog/nodes?wait=60s&index=1000", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	if d := parseWait(resp, req, &b); d {
		t.Fatalf("unexpected done")
	}

	if b.MinQueryIndex != 1000 {
		t.Fatalf("Bad: %v", b)
	}
	if b.MaxQueryTime != 60*time.Second {
		t.Fatalf("Bad: %v", b)
	}
}

// TestParseWait_InvalidTime checks that an unparsable "wait" value makes
// parseWait report done and respond with status 400.
func TestParseWait_InvalidTime(t *testing.T) {
	resp := httptest.NewRecorder()
	var b structs.QueryOptions

	req, err := http.NewRequest("GET", "/v1/catalog/nodes?wait=60foo&index=1000", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	if d := parseWait(resp, req, &b); !d {
		t.Fatalf("expected done")
	}

	if resp.Code != 400 {
		t.Fatalf("bad code: %v", resp.Code)
	}
}

// TestParseWait_InvalidIndex checks that an unparsable "index" value makes
// parseWait report done and respond with status 400.
func TestParseWait_InvalidIndex(t *testing.T) {
	resp := httptest.NewRecorder()
	var b structs.QueryOptions

	req, err := http.NewRequest("GET", "/v1/catalog/nodes?wait=60s&index=foo", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	if d := parseWait(resp, req, &b); !d {
		t.Fatalf("expected done")
	}

	if resp.Code != 400 {
		t.Fatalf("bad code: %v", resp.Code)
	}
}

// TestParseConsistency checks that the "stale" and "consistent" query
// parameters set AllowStale and RequireConsistent respectively, and only the
// matching flag in each case.
func TestParseConsistency(t *testing.T) {
	resp := httptest.NewRecorder()
	var b structs.QueryOptions

	req, err := http.NewRequest("GET", "/v1/catalog/nodes?stale", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	if d := parseConsistency(resp, req, &b); d {
		t.Fatalf("unexpected done")
	}

	if !b.AllowStale {
		t.Fatalf("Bad: %v", b)
	}
	if b.RequireConsistent {
		t.Fatalf("Bad: %v", b)
	}

	// Start from clean options for the second mode.
	b = structs.QueryOptions{}
	req, err = http.NewRequest("GET", "/v1/catalog/nodes?consistent", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	if d := parseConsistency(resp, req, &b); d {
		t.Fatalf("unexpected done")
	}

	if b.AllowStale {
		t.Fatalf("Bad: %v", b)
	}
	if !b.RequireConsistent {
		t.Fatalf("Bad: %v", b)
	}
}

// TestParseConsistency_Invalid checks that supplying both "stale" and
// "consistent" makes parseConsistency report done and respond with status 400.
func TestParseConsistency_Invalid(t *testing.T) {
	resp := httptest.NewRecorder()
	var b structs.QueryOptions

	req, err := http.NewRequest("GET", "/v1/catalog/nodes?stale&consistent", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	if d := parseConsistency(resp, req, &b); !d {
		t.Fatalf("expected done")
	}

	if resp.Code != 400 {
		t.Fatalf("bad code: %v", resp.Code)
	}
}

// assertIndex tests that X-Consul-Index is set and non-zero
func assertIndex(t *testing.T, resp *httptest.ResponseRecorder) {
	// Delegate to checkIndex so both helpers apply the same rule and message.
	if err := checkIndex(resp); err != nil {
		t.Fatalf("%v", err)
	}
}

// checkIndex is like assertIndex but returns an error
func checkIndex(resp *httptest.ResponseRecorder) error {
	header := resp.Header().Get("X-Consul-Index")
	if header == "" || header == "0" {
		return fmt.Errorf("Bad: %v", header)
	}
	return nil
}

// getIndex parses X-Consul-Index
//
// The test is failed if the header is missing or is not a valid unsigned
// 64-bit decimal number.
func getIndex(t *testing.T, resp *httptest.ResponseRecorder) uint64 {
	header := resp.Header().Get("X-Consul-Index")
	if header == "" {
		t.Fatalf("Bad: %v", header)
	}
	// ParseUint covers the full uint64 range on every platform and rejects
	// negative values instead of wrapping them into huge indexes.
	val, err := strconv.ParseUint(header, 10, 64)
	if err != nil {
		t.Fatalf("Bad: %v", header)
	}
	return val
}

// httpTest starts a test HTTP server, waits via testutil.WaitForLeader for a
// leader in "dc1", runs f against the server, and then shuts down the agent
// and server and removes the agent directory.
func httpTest(t *testing.T, f func(srv *HTTPServer)) {
	dir, srv := makeHTTPServer(t)
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()
	testutil.WaitForLeader(t, srv.agent.RPC, "dc1")
	f(srv)
}