Merge pull request #587 from jefferai/unixlisten

RPC and HTTP interfaces fully generically-sockified so Unix is supported...
pull/610/head
Ryan Uber 2015-01-15 15:37:40 -08:00
commit 36f9924e51
12 changed files with 563 additions and 88 deletions

View File

@ -5,9 +5,12 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
)
@ -111,11 +114,17 @@ type Config struct {
// DefaultConfig returns a default configuration for the client
func DefaultConfig() *Config {
return &Config{
config := &Config{
Address: "127.0.0.1:8500",
Scheme: "http",
HttpClient: http.DefaultClient,
}
if len(os.Getenv("CONSUL_HTTP_ADDR")) > 0 {
config.Address = os.Getenv("CONSUL_HTTP_ADDR")
}
return config
}
// Client provides a client to the Consul API
@ -128,7 +137,11 @@ func NewClient(config *Config) (*Client, error) {
// bootstrap the config
defConfig := DefaultConfig()
if len(config.Address) == 0 {
switch {
case len(config.Address) != 0:
case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0:
config.Address = os.Getenv("CONSUL_HTTP_ADDR")
default:
config.Address = defConfig.Address
}
@ -140,6 +153,16 @@ func NewClient(config *Config) (*Client, error) {
config.HttpClient = defConfig.HttpClient
}
if strings.HasPrefix(config.Address, "unix://") {
shortStr := strings.TrimPrefix(config.Address, "unix://")
t := &http.Transport{}
t.Dial = func(_, _ string) (net.Conn, error) {
return net.Dial("unix", shortStr)
}
config.HttpClient.Transport = t
config.Address = shortStr
}
client := &Client{
config: *config,
}
@ -206,9 +229,6 @@ func (r *request) toHTTP() (*http.Request, error) {
// Encode the query parameters
r.url.RawQuery = r.params.Encode()
// Get the url sring
urlRaw := r.url.String()
// Check if we should encode the body
if r.body == nil && r.obj != nil {
if b, err := encodeBody(r.obj); err != nil {
@ -219,14 +239,21 @@ func (r *request) toHTTP() (*http.Request, error) {
}
// Create the HTTP request
req, err := http.NewRequest(r.method, urlRaw, r.body)
req, err := http.NewRequest(r.method, r.url.RequestURI(), r.body)
if err != nil {
return nil, err
}
req.URL.Host = r.url.Host
req.URL.Scheme = r.url.Scheme
req.Host = r.url.Host
// Setup auth
if err == nil && r.config.HttpAuth != nil {
if r.config.HttpAuth != nil {
req.SetBasicAuth(r.config.HttpAuth.Username, r.config.HttpAuth.Password)
}
return req, err
return req, nil
}
// newRequest is used to create a new request

View File

@ -2,6 +2,7 @@ package api
import (
crand "crypto/rand"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
@ -13,27 +14,50 @@ import (
"github.com/hashicorp/consul/testutil"
)
var consulConfig = `{
"ports": {
"dns": 19000,
"http": 18800,
"rpc": 18600,
"serf_lan": 18200,
"serf_wan": 18400,
"server": 18000
},
"data_dir": "%s",
"bootstrap": true,
"log_level": "debug",
"server": true
}`
type testServer struct {
pid int
dataDir string
configFile string
}
type testPortConfig struct {
DNS int `json:"dns,omitempty"`
HTTP int `json:"http,omitempty"`
RPC int `json:"rpc,omitempty"`
SerfLan int `json:"serf_lan,omitempty"`
SerfWan int `json:"serf_wan,omitempty"`
Server int `json:"server,omitempty"`
}
type testAddressConfig struct {
HTTP string `json:"http,omitempty"`
}
type testServerConfig struct {
Bootstrap bool `json:"bootstrap,omitempty"`
Server bool `json:"server,omitempty"`
DataDir string `json:"data_dir,omitempty"`
LogLevel string `json:"log_level,omitempty"`
Addresses *testAddressConfig `json:"addresses,omitempty"`
Ports testPortConfig `json:"ports,omitempty"`
}
func defaultConfig() *testServerConfig {
return &testServerConfig{
Bootstrap: true,
Server: true,
LogLevel: "debug",
Ports: testPortConfig{
DNS: 19000,
HTTP: 18800,
RPC: 18600,
SerfLan: 18200,
SerfWan: 18400,
Server: 18000,
},
}
}
func (s *testServer) stop() {
defer os.RemoveAll(s.dataDir)
defer os.RemoveAll(s.configFile)
@ -45,6 +69,10 @@ func (s *testServer) stop() {
}
func newTestServer(t *testing.T) *testServer {
return newTestServerWithConfig(t, func(c *testServerConfig) {})
}
func newTestServerWithConfig(t *testing.T, cb func(c *testServerConfig)) *testServer {
if path, err := exec.LookPath("consul"); err != nil || path == "" {
t.Log("consul not found on $PATH, skipping")
t.SkipNow()
@ -66,8 +94,18 @@ func newTestServer(t *testing.T) *testServer {
if err != nil {
t.Fatalf("err: %s", err)
}
configContent := fmt.Sprintf(consulConfig, dataDir)
if _, err := configFile.WriteString(configContent); err != nil {
consulConfig := defaultConfig()
consulConfig.DataDir = dataDir
cb(consulConfig)
configContent, err := json.Marshal(consulConfig)
if err != nil {
t.Fatalf("err: %s", err)
}
if _, err := configFile.Write(configContent); err != nil {
t.Fatalf("err: %s", err)
}
configFile.Close()
@ -80,10 +118,32 @@ func newTestServer(t *testing.T) *testServer {
t.Fatalf("err: %s", err)
}
return &testServer{
pid: cmd.Process.Pid,
dataDir: dataDir,
configFile: configFile.Name(),
}
}
func makeClient(t *testing.T) (*Client, *testServer) {
return makeClientWithConfig(t, func(c *Config) {
c.Address = "127.0.0.1:18800"
}, func(c *testServerConfig) {})
}
func makeClientWithConfig(t *testing.T, clientConfig func(c *Config), serverConfig func(c *testServerConfig)) (*Client, *testServer) {
server := newTestServerWithConfig(t, serverConfig)
conf := DefaultConfig()
clientConfig(conf)
client, err := NewClient(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
// Allow the server some time to start, and verify we have a leader.
client := new(http.Client)
testutil.WaitForResult(func() (bool, error) {
resp, err := client.Get("http://127.0.0.1:18800/v1/catalog/nodes")
req := client.newRequest("GET", "/v1/catalog/nodes")
_, resp, err := client.doRequest(req)
if err != nil {
return false, err
}
@ -102,21 +162,6 @@ func newTestServer(t *testing.T) *testServer {
t.Fatalf("err: %s", err)
})
return &testServer{
pid: cmd.Process.Pid,
dataDir: dataDir,
configFile: configFile.Name(),
}
}
func makeClient(t *testing.T) (*Client, *testServer) {
server := newTestServer(t)
conf := DefaultConfig()
conf.Address = "127.0.0.1:18800"
client, err := NewClient(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
return client, server
}
@ -205,7 +250,7 @@ func TestRequestToHTTP(t *testing.T) {
if req.Method != "DELETE" {
t.Fatalf("bad: %v", req)
}
if req.URL.String() != "http://127.0.0.1:18800/v1/kv/foo?dc=foo" {
if req.URL.RequestURI() != "/v1/kv/foo?dc=foo" {
t.Fatalf("bad: %v", req)
}
}

View File

@ -1,10 +1,13 @@
package api
import (
"io/ioutil"
"os/user"
"runtime"
"testing"
)
func TestStatusLeader(t *testing.T) {
func TestStatusLeaderTCP(t *testing.T) {
c, s := makeClient(t)
defer s.stop()
@ -19,6 +22,48 @@ func TestStatusLeader(t *testing.T) {
}
}
func TestStatusLeaderUnix(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
tempdir, err := ioutil.TempDir("", "consul-test-")
if err != nil {
t.Fatal("Could not create a working directory")
}
socket := "unix://" + tempdir + "/unix-http-test.sock"
clientConfig := func(c *Config) {
c.Address = socket
}
serverConfig := func(c *testServerConfig) {
user, err := user.Current()
if err != nil {
t.Fatal("Could not get current user")
}
if c.Addresses == nil {
c.Addresses = &testAddressConfig{}
}
c.Addresses.HTTP = socket + ";" + user.Uid + ";" + user.Gid + ";640"
}
c, s := makeClientWithConfig(t, clientConfig, serverConfig)
defer s.stop()
status := c.Status()
leader, err := status.Leader()
if err != nil {
t.Fatalf("err: %v", err)
}
if leader == "" {
t.Fatalf("Expected leader")
}
}
func TestStatusPeers(t *testing.T) {
c, s := makeClient(t)
defer s.stop()

View File

@ -7,8 +7,10 @@ import (
"io"
"io/ioutil"
"os"
"os/user"
"path/filepath"
"reflect"
"runtime"
"sync/atomic"
"testing"
"time"
@ -123,7 +125,7 @@ func TestAgentStartStop(t *testing.T) {
}
}
func TestAgent_RPCPing(t *testing.T) {
func TestAgent_RPCPingTCP(t *testing.T) {
dir, agent := makeAgent(t, nextConfig())
defer os.RemoveAll(dir)
defer agent.Shutdown()
@ -134,6 +136,35 @@ func TestAgent_RPCPing(t *testing.T) {
}
}
func TestAgent_RPCPingUnix(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
nextConf := nextConfig()
tempdir, err := ioutil.TempDir("", "consul-test-")
if err != nil {
t.Fatal("Could not create a working directory")
}
user, err := user.Current()
if err != nil {
t.Fatal("Could not get current user")
}
nextConf.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640"
dir, agent := makeAgent(t, nextConf)
defer os.RemoveAll(dir)
defer agent.Shutdown()
var out struct{}
if err := agent.RPC("Status.Ping", struct{}{}, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestAgent_AddService(t *testing.T) {
dir, agent := makeAgent(t, nextConfig())
defer os.RemoveAll(dir)

View File

@ -295,13 +295,26 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
return err
}
rpcListener, err := net.Listen("tcp", rpcAddr.String())
if _, ok := rpcAddr.(*net.UnixAddr); ok {
// Remove the socket if it exists, or we'll get a bind error
_ = os.Remove(rpcAddr.String())
}
rpcListener, err := net.Listen(rpcAddr.Network(), rpcAddr.String())
if err != nil {
agent.Shutdown()
c.Ui.Error(fmt.Sprintf("Error starting RPC listener: %s", err))
return err
}
if _, ok := rpcAddr.(*net.UnixAddr); ok {
if err := adjustUnixSocketPermissions(config.Addresses.RPC); err != nil {
agent.Shutdown()
c.Ui.Error(fmt.Sprintf("Error adjusting Unix socket permissions: %s", err))
return err
}
}
// Start the IPC layer
c.Ui.Output("Starting Consul agent RPC...")
c.rpcServer = NewAgentRPC(agent, rpcListener, logOutput, logWriter)
@ -319,6 +332,7 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
if config.Ports.DNS > 0 {
dnsAddr, err := config.ClientListener(config.Addresses.DNS, config.Ports.DNS)
if err != nil {
agent.Shutdown()
c.Ui.Error(fmt.Sprintf("Invalid DNS bind address: %s", err))
return err
}
@ -575,7 +589,7 @@ func (c *Command) Run(args []string) int {
}
// Get the new client http listener addr
httpAddr, err := config.ClientListenerAddr(config.Addresses.HTTP, config.Ports.HTTP)
httpAddr, err := config.ClientListener(config.Addresses.HTTP, config.Ports.HTTP)
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to determine HTTP address: %v", err))
}
@ -585,7 +599,7 @@ func (c *Command) Run(args []string) int {
go func(wp *watch.WatchPlan) {
wp.Handler = makeWatchHandler(logOutput, wp.Exempt["handler"])
wp.LogOutput = c.logOutput
if err := wp.Run(httpAddr); err != nil {
if err := wp.Run(httpAddr.String()); err != nil {
c.Ui.Error(fmt.Sprintf("Error running watch: %v", err))
}
}(wp)
@ -744,7 +758,7 @@ func (c *Command) handleReload(config *Config) *Config {
}
// Get the new client listener addr
httpAddr, err := newConf.ClientListenerAddr(config.Addresses.HTTP, config.Ports.HTTP)
httpAddr, err := newConf.ClientListener(config.Addresses.HTTP, config.Ports.HTTP)
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to determine HTTP address: %v", err))
}
@ -759,7 +773,7 @@ func (c *Command) handleReload(config *Config) *Config {
go func(wp *watch.WatchPlan) {
wp.Handler = makeWatchHandler(c.logOutput, wp.Exempt["handler"])
wp.LogOutput = c.logOutput
if err := wp.Run(httpAddr); err != nil {
if err := wp.Run(httpAddr.String()); err != nil {
c.Ui.Error(fmt.Sprintf("Error running watch: %v", err))
}
}(wp)

View File

@ -7,8 +7,11 @@ import (
"io"
"net"
"os"
"os/user"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"time"
@ -345,6 +348,91 @@ type Config struct {
WatchPlans []*watch.WatchPlan `mapstructure:"-" json:"-"`
}
// UnixSocket contains the parameters for a Unix socket interface
type UnixSocket struct {
// Path to the socket on-disk
Path string
// uid of the owner of the socket
Uid int
// gid of the group of the socket
Gid int
// Permissions for the socket file
Permissions os.FileMode
}
func populateUnixSocket(addr string) (*UnixSocket, error) {
if !strings.HasPrefix(addr, "unix://") {
return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr)
}
splitAddr := strings.Split(strings.TrimPrefix(addr, "unix://"), ";")
if len(splitAddr) != 4 {
return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr)
}
ret := &UnixSocket{Path: splitAddr[0]}
var userVal *user.User
var err error
regex := regexp.MustCompile("[\\d]+")
if regex.MatchString(splitAddr[1]) {
userVal, err = user.LookupId(splitAddr[1])
} else {
userVal, err = user.Lookup(splitAddr[1])
}
if err != nil {
return nil, fmt.Errorf("Invalid user given for Unix socket ownership: %v", splitAddr[1])
}
if uid64, err := strconv.ParseInt(userVal.Uid, 10, 32); err != nil {
return nil, fmt.Errorf("Failed to parse given user ID of %v into integer", userVal.Uid)
} else {
ret.Uid = int(uid64)
}
// Go doesn't currently have a way to look up gid from group name,
// so require a numeric gid; see
// https://codereview.appspot.com/101310044
if gid64, err := strconv.ParseInt(splitAddr[2], 10, 32); err != nil {
return nil, fmt.Errorf("Socket group must be given as numeric gid. Failed to parse given group ID of %v into integer", splitAddr[2])
} else {
ret.Gid = int(gid64)
}
if mode, err := strconv.ParseUint(splitAddr[3], 8, 32); err != nil {
return nil, fmt.Errorf("Failed to parse given mode of %v into integer", splitAddr[3])
} else {
if mode > 0777 {
return nil, fmt.Errorf("Given mode is invalid; must be an octal number between 0 and 777")
} else {
ret.Permissions = os.FileMode(mode)
}
}
return ret, nil
}
func adjustUnixSocketPermissions(addr string) error {
sock, err := populateUnixSocket(addr)
if err != nil {
return err
}
if err = os.Chown(sock.Path, sock.Uid, sock.Gid); err != nil {
return fmt.Errorf("Error attempting to change socket permissions to userid %v and groupid %v: %v", sock.Uid, sock.Gid, err)
}
if err = os.Chmod(sock.Path, sock.Permissions); err != nil {
return fmt.Errorf("Error attempting to change socket permissions to mode %v: %v", sock.Permissions, err)
}
return nil
}
type dirEnts []os.FileInfo
// DefaultConfig is used to return a sane default configuration
@ -389,31 +477,39 @@ func (c *Config) EncryptBytes() ([]byte, error) {
// ClientListener is used to format a listener for a
// port on a ClientAddr
func (c *Config) ClientListener(override string, port int) (*net.TCPAddr, error) {
func (c *Config) ClientListener(override string, port int) (net.Addr, error) {
var addr string
if override != "" {
addr = override
} else {
addr = c.ClientAddr
}
ip := net.ParseIP(addr)
if ip == nil {
return nil, fmt.Errorf("Failed to parse IP: %v", addr)
}
return &net.TCPAddr{IP: ip, Port: port}, nil
}
// ClientListenerAddr is used to format an address for a
// port on a ClientAddr, handling the zero IP.
func (c *Config) ClientListenerAddr(override string, port int) (string, error) {
addr, err := c.ClientListener(override, port)
if err != nil {
return "", err
switch {
case strings.HasPrefix(addr, "unix://"):
sock, err := populateUnixSocket(addr)
if err != nil {
return nil, err
}
return &net.UnixAddr{Name: sock.Path, Net: "unix"}, nil
default:
ip := net.ParseIP(addr)
if ip == nil {
return nil, fmt.Errorf("Failed to parse IP: %v", addr)
}
if ip.IsUnspecified() {
ip = net.ParseIP("127.0.0.1")
}
if ip == nil {
return nil, fmt.Errorf("Failed to parse IP 127.0.0.1")
}
return &net.TCPAddr{IP: ip, Port: port}, nil
}
if addr.IP.IsUnspecified() {
addr.IP = net.ParseIP("127.0.0.1")
}
return addr.String(), nil
}
// DecodeConfig reads the configuration from the given reader in JSON

View File

@ -4,9 +4,12 @@ import (
"bytes"
"encoding/base64"
"io/ioutil"
"net"
"os"
"os/user"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"
"time"
@ -1068,3 +1071,109 @@ func TestReadConfigPaths_dir(t *testing.T) {
t.Fatalf("bad: %#v", config)
}
}
func TestUnixSockets(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
usr, err := user.Current()
if err != nil {
t.Fatal("Could not get current user: ", err)
}
tempdir, err := ioutil.TempDir("", "consul-test-")
if err != nil {
t.Fatal("Could not create a working directory: ", err)
}
type SocketTestData struct {
Path string
Uid string
Gid string
Mode string
}
testUnixSocketPopulation := func(s SocketTestData) (*UnixSocket, error) {
return populateUnixSocket("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode)
}
testUnixSocketPermissions := func(s SocketTestData) error {
return adjustUnixSocketPermissions("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode)
}
_, err = populateUnixSocket("tcp://abc123")
if err == nil {
t.Fatal("Should have rejected invalid scheme")
}
_, err = populateUnixSocket("unix://x;y;z")
if err == nil {
t.Fatal("Should have rejected invalid number of parameters in Unix socket definition")
}
std := SocketTestData{
Path: tempdir + "/unix-config-test.sock",
Uid: usr.Uid,
Gid: usr.Gid,
Mode: "640",
}
std.Uid = "orasdfdsnfoinweroiu"
_, err = testUnixSocketPopulation(std)
if err == nil {
t.Fatal("Did not error on invalid username")
}
std.Uid = usr.Username
std.Gid = "foinfphawepofhewof"
_, err = testUnixSocketPopulation(std)
if err == nil {
t.Fatal("Did not error on invalid group (a name, must be gid)")
}
std.Gid = usr.Gid
std.Mode = "999"
_, err = testUnixSocketPopulation(std)
if err == nil {
t.Fatal("Did not error on invalid socket mode")
}
std.Uid = usr.Username
std.Mode = "640"
_, err = testUnixSocketPopulation(std)
if err != nil {
t.Fatal("Unix socket test failed (using username): ", err)
}
std.Uid = usr.Uid
sock, err := testUnixSocketPopulation(std)
if err != nil {
t.Fatal("Unix socket test failed (using uid): ", err)
}
addr := &net.UnixAddr{Name: sock.Path, Net: "unix"}
_, err = net.Listen(addr.Network(), addr.String())
if err != nil {
t.Fatal("Error creating socket for futher tests: ", err)
}
std.Uid = "-999999"
err = testUnixSocketPermissions(std)
if err == nil {
t.Fatal("Did not error on invalid uid")
}
std.Uid = usr.Uid
std.Gid = "-999999"
err = testUnixSocketPermissions(std)
if err == nil {
t.Fatal("Did not error on invalid uid")
}
std.Gid = usr.Gid
err = testUnixSocketPermissions(std)
if err != nil {
t.Fatal("Adjusting socket permissions failed: ", err)
}
}

View File

@ -9,6 +9,7 @@ import (
"net"
"net/http"
"net/http/pprof"
"os"
"strconv"
"strings"
"time"
@ -34,7 +35,7 @@ type HTTPServer struct {
func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPServer, error) {
var tlsConfig *tls.Config
var list net.Listener
var httpAddr *net.TCPAddr
var httpAddr net.Addr
var err error
var servers []*HTTPServer
@ -58,12 +59,29 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS
return nil, err
}
ln, err := net.Listen("tcp", httpAddr.String())
if err != nil {
return nil, err
if _, ok := httpAddr.(*net.UnixAddr); ok {
// Remove the socket if it exists, or we'll get a bind error
_ = os.Remove(httpAddr.String())
}
list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig)
ln, err := net.Listen(httpAddr.Network(), httpAddr.String())
if err != nil {
return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err)
}
switch httpAddr.(type) {
case *net.UnixAddr:
if err := adjustUnixSocketPermissions(config.Addresses.HTTPS); err != nil {
return nil, err
}
list = tls.NewListener(ln, tlsConfig)
case *net.TCPAddr:
list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig)
default:
return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err)
}
// Create the mux
mux := http.NewServeMux()
@ -90,13 +108,29 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS
return nil, fmt.Errorf("Failed to get ClientListener address:port: %v", err)
}
// Create non-TLS listener
ln, err := net.Listen("tcp", httpAddr.String())
if _, ok := httpAddr.(*net.UnixAddr); ok {
// Remove the socket if it exists, or we'll get a bind error
_ = os.Remove(httpAddr.String())
}
ln, err := net.Listen(httpAddr.Network(), httpAddr.String())
if err != nil {
return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err)
}
list = tcpKeepAliveListener{ln.(*net.TCPListener)}
switch httpAddr.(type) {
case *net.UnixAddr:
if err := adjustUnixSocketPermissions(config.Addresses.HTTP); err != nil {
return nil, err
}
list = ln
case *net.TCPAddr:
list = tcpKeepAliveListener{ln.(*net.TCPListener)}
default:
return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err)
}
// Create the mux
mux := http.NewServeMux()

View File

@ -7,6 +7,8 @@ import (
"github.com/hashicorp/logutils"
"log"
"net"
"os"
"strings"
"sync"
"sync/atomic"
)
@ -34,7 +36,7 @@ type seqHandler interface {
type RPCClient struct {
seq uint64
conn *net.TCPConn
conn net.Conn
reader *bufio.Reader
writer *bufio.Writer
dec *codec.Decoder
@ -79,8 +81,23 @@ func (c *RPCClient) send(header *requestHeader, obj interface{}) error {
// NewRPCClient is used to create a new RPC client given the address.
// This will properly dial, handshake, and start listening
func NewRPCClient(addr string) (*RPCClient, error) {
sanedAddr := os.Getenv("CONSUL_RPC_ADDR")
if len(sanedAddr) == 0 {
sanedAddr = addr
}
mode := "tcp"
if strings.HasPrefix(sanedAddr, "unix://") {
sanedAddr = strings.TrimPrefix(sanedAddr, "unix://")
}
if strings.HasPrefix(sanedAddr, "/") {
mode = "unix"
}
// Try to dial to agent
conn, err := net.Dial("tcp", addr)
conn, err := net.Dial(mode, sanedAddr)
if err != nil {
return nil, err
}
@ -88,7 +105,7 @@ func NewRPCClient(addr string) (*RPCClient, error) {
// Create the client
client := &RPCClient{
seq: 0,
conn: conn.(*net.TCPConn),
conn: conn,
reader: bufio.NewReader(conn),
writer: bufio.NewWriter(conn),
dispatch: make(map[uint64]seqHandler),

View File

@ -6,8 +6,11 @@ import (
"github.com/hashicorp/consul/testutil"
"github.com/hashicorp/serf/serf"
"io"
"io/ioutil"
"net"
"os"
"os/user"
"runtime"
"strings"
"testing"
"time"
@ -34,17 +37,22 @@ func testRPCClient(t *testing.T) *rpcParts {
}
func testRPCClientWithConfig(t *testing.T, cb func(c *Config)) *rpcParts {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %s", err)
}
lw := NewLogWriter(512)
mult := io.MultiWriter(os.Stderr, lw)
conf := nextConfig()
cb(conf)
rpcAddr, err := conf.ClientListener(conf.Addresses.RPC, conf.Ports.RPC)
if err != nil {
t.Fatalf("err: %s", err)
}
l, err := net.Listen(rpcAddr.Network(), rpcAddr.String())
if err != nil {
t.Fatalf("err: %s", err)
}
dir, agent := makeAgentLog(t, conf, mult)
rpc := NewAgentRPC(agent, l, mult, lw)
@ -208,6 +216,41 @@ func TestRPCClientStats(t *testing.T) {
}
}
func TestRPCClientStatsUnix(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
tempdir, err := ioutil.TempDir("", "consul-test-")
if err != nil {
t.Fatal("Could not create a working directory: ", err)
}
user, err := user.Current()
if err != nil {
t.Fatal("Could not get current user: ", err)
}
cb := func(c *Config) {
c.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640"
}
p1 := testRPCClientWithConfig(t, cb)
stats, err := p1.client.Stats()
if err != nil {
t.Fatalf("err: %s", err)
}
if _, ok := stats["agent"]; !ok {
t.Fatalf("bad: %#v", stats)
}
if _, ok := stats["consul"]; !ok {
t.Fatalf("bad: %#v", stats)
}
}
func TestRPCClientLeave(t *testing.T) {
p1 := testRPCClient(t)
defer p1.Close()

View File

@ -8,8 +8,8 @@ import (
"github.com/hashicorp/consul/command/agent"
)
// RPCAddrEnvName defines the environment variable name, which can set
// a default RPC address in case there is no -rpc-addr specified.
// RPCAddrEnvName defines an environment variable name which sets
// an RPC address if there is no -rpc-addr specified.
const RPCAddrEnvName = "CONSUL_RPC_ADDR"
// RPCAddrFlag returns a pointer to a string that will be populated
@ -43,7 +43,12 @@ func HTTPClient(addr string) (*consulapi.Client, error) {
// HTTPClientDC returns a new Consul HTTP client with the given address and datacenter
func HTTPClientDC(addr, dc string) (*consulapi.Client, error) {
conf := consulapi.DefaultConfig()
conf.Address = addr
switch {
case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0:
conf.Address = os.Getenv("CONSUL_HTTP_ADDR")
default:
conf.Address = addr
}
conf.Datacenter = dc
return consulapi.NewClient(conf)
}

View File

@ -239,8 +239,17 @@ definitions support being updated during a reload.
However, because the caches are not actively invalidated, ACL policy may be stale
up to the TTL value.
* `addresses` - This is a nested object that allows setting the bind address
for the following keys:
* `addresses` - This is a nested object that allows setting bind addresses. For `rpc`
and `http`, a Unix socket can be specified in the following form:
unix://[/path/to/socket];[username|uid];[gid];[mode]. The socket will be created
in the specified location with the given username or uid, gid, and mode. The
user Consul is running as must have appropriate permissions to change the socket
ownership to the given uid or gid. When running Consul agent commands against
Unix socket interfaces, use the `-rpc-addr` or `-http-addr` arguments to specify
the path to the socket, e.g. "unix://path/to/socket". You can also place the desired
values in `CONSUL_RPC_ADDR` and `CONSUL_HTTP_ADDR` environment variables. For TCP
addresses, these should be in the form ip:port.
The following keys are valid:
* `dns` - The DNS server. Defaults to `client_addr`
* `http` - The HTTP API. Defaults to `client_addr`
* `rpc` - The RPC endpoint. Defaults to `client_addr`