diff --git a/consul/state_store.go b/consul/state_store.go index 762b3c7ade..4f4ba69d7c 100644 --- a/consul/state_store.go +++ b/consul/state_store.go @@ -1,33 +1,18 @@ package consul import ( - "database/sql" + "bytes" "fmt" + "github.com/armon/gomdb" "github.com/hashicorp/consul/rpc" - _ "github.com/mattn/go-sqlite3" - "log" - "sync/atomic" - "time" + "io/ioutil" + "os" ) -// nextDBIndex is used to generate a new ID -// using sync/atomic to ensure it is safe -var nextDBIndex uint32 = 0 - -type namedQuery uint8 - const ( - queryEnsureNode namedQuery = iota - queryNode - queryNodes - queryEnsureService - queryNodeServices - queryDeleteNodeService - queryDeleteNode - queryServices - queryServiceNodes - queryServiceTagNodes - queryAllServices + dbNodes = "nodes" // Maps node -> addr + dbServices = "services" // Maps node||serv -> rpc.NodeService + dbServiceIndex = "serviceIndex" // Maps serv||tag||node -> rpc.ServiceNode ) // The StateStore is responsible for maintaining all the Consul @@ -35,34 +20,36 @@ const ( // through the use of Raft. The goals of the StateStore are to provide // high concurrency for read operations without blocking writes, and // to provide write availability in the face of reads. The current -// implementation uses an in-memory SQLite database. This reduced the -// GC pressure on Go, and also gives us Multi-Version Concurrency Control -// for "free". +// implementation uses the Lightning Memory-Mapped Database (MDB). +// This gives us Multi-Version Concurrency Control for "free" type StateStore struct { - db *sql.DB - prepared map[namedQuery]*sql.Stmt + path string + env *mdb.Env } // NewStateStore is used to create a new state store func NewStateStore() (*StateStore, error) { - // Get the DB ID - id := atomic.AddUint32(&nextDBIndex, 1) - path := fmt.Sprintf("file:StateStore-%d?mode=memory&cache=shared", id) + // Create a new temp dir + path, err := ioutil.TempDir("", "consul") + if err != nil { + return nil, err + } - // Open the db - db, err := sql.Open("sqlite3", path) + // Open the env + env, err := mdb.NewEnv() if err != nil { - return nil, fmt.Errorf("failed to open db: %v", err) + return nil, err } s := &StateStore{ - db: db, - prepared: make(map[namedQuery]*sql.Stmt), + path: path, + env: env, } // Ensure we can initialize if err := s.initialize(); err != nil { - db.Close() + env.Close() + os.RemoveAll(path) return nil, err } return s, nil @@ -70,155 +57,249 @@ func NewStateStore() (*StateStore, error) { // Close is used to safely shutdown the state store func (s *StateStore) Close() error { - return s.db.Close() + s.env.Close() + os.RemoveAll(s.path) + return nil } -// initialize is used to setup the sqlite store for use +// initialize is used to setup the store for use func (s *StateStore) initialize() error { - // Set the pragma first - pragmas := []string{ - "pragma journal_mode=memory;", - "pragma foreign_keys=ON;", - "pragma read_uncommitted=true;", - } - for _, p := range pragmas { - if _, err := s.db.Exec(p); err != nil { - return fmt.Errorf("Failed to set '%s': %v", p, err) - } + // Setup the Env first + if err := s.env.SetMaxDBs(mdb.DBI(16)); err != nil { + return err } - // Create the tables - tables := []string{ - `CREATE TABLE nodes (name text unique, address text);`, - `CREATE TABLE services (node text REFERENCES nodes(name) ON DELETE CASCADE, service text, tag text, port integer);`, - `CREATE INDEX servName ON services(service, tag);`, - `CREATE INDEX nodeName ON services(node);`, - } - for _, t := range tables { - if _, err := s.db.Exec(t); err != nil { - return fmt.Errorf("Failed to call '%s': %v", t, err) - } + // Optimize our flags for speed over safety, since the Raft log + snapshots + // are durable. We treat this as an ephemeral in-memory DB, since we nuke + // the data anyways. + var flags uint = mdb.NOMETASYNC | mdb.NOSYNC | mdb.NOTLS + if err := s.env.Open(s.path, flags, 0755); err != nil { + return err } - // Prepare the queries - queries := map[namedQuery]string{ - queryEnsureNode: "INSERT OR REPLACE INTO nodes (name, address) VALUES (?, ?)", - queryNode: "SELECT address FROM nodes where name=?", - queryNodes: "SELECT * FROM nodes", - queryEnsureService: "INSERT OR REPLACE INTO services (node, service, tag, port) VALUES (?, ?, ?, ?)", - queryNodeServices: "SELECT service, tag, port from services where node=?", - queryDeleteNodeService: "DELETE FROM services WHERE node=? AND service=?", - queryDeleteNode: "DELETE FROM nodes WHERE name=?", - queryServices: "SELECT DISTINCT service, tag FROM services", - queryServiceNodes: "SELECT n.name, n.address, s.tag, s.port from nodes n, services s WHERE s.service=? AND s.node=n.name", - queryServiceTagNodes: "SELECT n.name, n.address, s.tag, s.port from nodes n, services s WHERE s.service=? AND s.tag=? AND s.node=n.name", - queryAllServices: "SELECT * FROM services", - } - for name, query := range queries { - stmt, err := s.db.Prepare(query) - if err != nil { - return fmt.Errorf("Failed to prepare '%s': %v", query, err) - } - s.prepared[name] = stmt + // Create all the tables + tx, _, err := s.startTxn(false, dbNodes, dbServices, dbServiceIndex) + if err != nil { + tx.Abort() + return err } - return nil + return tx.Commit() } -func (s *StateStore) checkSet(res sql.Result, err error) error { - if err != nil { - return err +// startTxn is used to start a transaction and open all the associated sub-databases +func (s *StateStore) startTxn(readonly bool, open ...string) (*mdb.Txn, []mdb.DBI, error) { + var txFlags uint = 0 + var dbFlags uint = 0 + if readonly { + txFlags |= mdb.RDONLY + } else { + dbFlags |= mdb.CREATE } - n, err := res.RowsAffected() + + tx, err := s.env.BeginTxn(nil, txFlags) if err != nil { - return err + return nil, nil, err } - if n != 1 { - return fmt.Errorf("Failed to set row") + + var dbs []mdb.DBI + for _, name := range open { + dbi, err := tx.DBIOpen(name, dbFlags) + if err != nil { + tx.Abort() + return nil, nil, err + } + dbs = append(dbs, dbi) } - return nil + + return tx, dbs, nil } -func (s *StateStore) checkDelete(res sql.Result, err error) error { +// EnsureNode is used to ensure a given node exists, with the provided address +func (s *StateStore) EnsureNode(name string, address string) error { + tx, dbis, err := s.startTxn(false, dbNodes) if err != nil { return err } - _, err = res.RowsAffected() - if err != nil { + if err := tx.Put(dbis[0], []byte(name), []byte(address), 0); err != nil { + tx.Abort() return err } - return nil -} - -// EnsureNode is used to ensure a given node exists, with the provided address -func (s *StateStore) EnsureNode(name string, address string) error { - stmt := s.prepared[queryEnsureNode] - return s.checkSet(stmt.Exec(name, address)) + return tx.Commit() } // GetNode returns all the address of the known and if it was found func (s *StateStore) GetNode(name string) (bool, string) { - stmt := s.prepared[queryNode] - row := stmt.QueryRow(name) + tx, dbis, err := s.startTxn(true, dbNodes) + if err != nil { + panic(fmt.Errorf("Failed to get node: %v", err)) + } + defer tx.Abort() - var addr string - if err := row.Scan(&addr); err != nil { - if err == sql.ErrNoRows { - return false, addr - } else { - panic(fmt.Errorf("Failed to get node: %v", err)) - } + val, err := tx.Get(dbis[0], []byte(name)) + if err == mdb.NotFound { + return false, "" + } else if err != nil { + panic(fmt.Errorf("Failed to get node: %v", err)) } - return true, addr + + return true, string(val) } // GetNodes returns all the known nodes, the slice alternates between // the node name and address func (s *StateStore) Nodes() []string { - stmt := s.prepared[queryNodes] - return parseNodes(stmt.Query()) -} + tx, dbis, err := s.startTxn(true, dbNodes) + if err != nil { + panic(fmt.Errorf("Failed to get nodes: %v", err)) + } + defer tx.Abort() -// parseNodes parses the result of a queryNodes statement -func parseNodes(rows *sql.Rows, err error) []string { + cursor, err := tx.CursorOpen(dbis[0]) if err != nil { panic(fmt.Errorf("Failed to get nodes: %v", err)) } - data := make([]string, 0, 32) - var name, address string - for rows.Next() { - if err := rows.Scan(&name, &address); err != nil { + + var nodes []string + for { + key, val, err := cursor.Get(nil, mdb.NEXT) + if err == mdb.NotFound { + break + } else if err != nil { panic(fmt.Errorf("Failed to get nodes: %v", err)) } - data = append(data, name, address) + nodes = append(nodes, string(key), string(val)) } - return data + return nodes } // EnsureService is used to ensure a given node exposes a service func (s *StateStore) EnsureService(name, service, tag string, port int) error { - stmt := s.prepared[queryEnsureService] - return s.checkSet(stmt.Exec(name, service, tag, port)) + // Start a txn + tx, dbis, err := s.startTxn(false, dbNodes, dbServices, dbServiceIndex) + if err != nil { + return err + } + nodes := dbis[0] + services := dbis[1] + index := dbis[2] + + // Get the existing services + existing := filterNodeServices(tx, services, name) + + // Get the node + addr, err := tx.Get(nodes, []byte(name)) + if err != nil { + tx.Abort() + return err + } + + // Update the service entry + key := []byte(fmt.Sprintf("%s||%s", name, service)) + nService := rpc.NodeService{ + Tag: tag, + Port: port, + } + val, err := rpc.Encode(255, &nService) + if err != nil { + tx.Abort() + return err + } + if err := tx.Put(services, key, val, 0); err != nil { + tx.Abort() + return err + } + + // Remove previous entry if any + if exist, ok := existing[service]; ok { + key := []byte(fmt.Sprintf("%s||%s||%s", service, exist.Tag, name)) + if err := tx.Del(index, key, nil); err != nil { + tx.Abort() + return err + } + } + + // Update the index entry + key = []byte(fmt.Sprintf("%s||%s||%s", service, tag, name)) + node := rpc.ServiceNode{ + Node: name, + Address: string(addr), + ServiceTag: tag, + ServicePort: port, + } + val, err = rpc.Encode(255, &node) + if err != nil { + tx.Abort() + return err + } + if err := tx.Put(index, key, val, 0); err != nil { + tx.Abort() + return err + } + + return tx.Commit() } // NodeServices is used to return all the services of a given node func (s *StateStore) NodeServices(name string) rpc.NodeServices { - stmt := s.prepared[queryNodeServices] - return parseNodeServices(stmt.Query(name)) + tx, dbis, err := s.startTxn(true, dbServices) + if err != nil { + panic(fmt.Errorf("Failed to get node servicess: %v", err)) + } + defer tx.Abort() + return filterNodeServices(tx, dbis[0], name) +} + +// filterNodeServices is used to filter the services to a specific node +func filterNodeServices(tx *mdb.Txn, services mdb.DBI, name string) rpc.NodeServices { + keyPrefix := []byte(fmt.Sprintf("%s||", name)) + return parseNodeServices(tx, services, keyPrefix) } // parseNodeServices is used to parse the results of a queryNodeServices -func parseNodeServices(rows *sql.Rows, err error) rpc.NodeServices { +func parseNodeServices(tx *mdb.Txn, dbi mdb.DBI, prefix []byte) rpc.NodeServices { + // Create the cursor + cursor, err := tx.CursorOpen(dbi) if err != nil { - panic(fmt.Errorf("Failed to get node services: %v", err)) + panic(fmt.Errorf("Failed to get nodes: %v", err)) } services := rpc.NodeServices(make(map[string]rpc.NodeService)) var service string var entry rpc.NodeService - for rows.Next() { - if err := rows.Scan(&service, &entry.Tag, &entry.Port); err != nil { + var key, val []byte + first := true + + for { + if first { + first = false + key, val, err = cursor.Get(prefix, mdb.SET_RANGE) + } else { + key, val, err = cursor.Get(nil, mdb.NEXT) + } + if err == mdb.NotFound { + break + } else if err != nil { panic(fmt.Errorf("Failed to get node services: %v", err)) } + + // Bail if this does not match our filter + if !bytes.HasPrefix(key, prefix) { + break + } + + // Split to get service name + parts := bytes.SplitN(key, []byte("||"), 2) + service = string(parts[1]) + + // Setup the entry + if val[0] != 255 { + panic(fmt.Errorf("Bad service value: %v", val)) + } + if err := rpc.Decode(val[1:], &entry); err != nil { + panic(fmt.Errorf("Failed to get node services: %v", err)) + } + + // Add to the map services[service] = entry } return services @@ -226,119 +307,174 @@ func parseNodeServices(rows *sql.Rows, err error) rpc.NodeServices { // DeleteNodeService is used to delete a node service func (s *StateStore) DeleteNodeService(node, service string) error { - stmt := s.prepared[queryDeleteNodeService] - return s.checkDelete(stmt.Exec(node, service)) + tx, dbis, err := s.startTxn(false, dbServices, dbServiceIndex) + if err != nil { + panic(fmt.Errorf("Failed to get node servicess: %v", err)) + } + services := dbis[0] + index := dbis[1] + + // Get the existing services + existing := filterNodeServices(tx, services, node) + exist, ok := existing[service] + + // Bail if no existing entry + if !ok { + tx.Abort() + return nil + } + + // Delete the node service entry + key := []byte(fmt.Sprintf("%s||%s", node, service)) + if err = tx.Del(services, key, nil); err != nil { + tx.Abort() + return err + } + + // Delete the sevice index entry + key = []byte(fmt.Sprintf("%s||%s||%s", service, exist.Tag, node)) + if err := tx.Del(index, key, nil); err != nil { + tx.Abort() + return err + } + + return tx.Commit() } // DeleteNode is used to delete a node and all it's services func (s *StateStore) DeleteNode(node string) error { - stmt := s.prepared[queryDeleteNode] - return s.checkDelete(stmt.Exec(node)) + tx, dbis, err := s.startTxn(false, dbNodes, dbServices, dbServiceIndex) + if err != nil { + panic(fmt.Errorf("Failed to get node servicess: %v", err)) + } + nodes := dbis[0] + services := dbis[1] + index := dbis[2] + + // Delete the node + err = tx.Del(nodes, []byte(node), nil) + if err == mdb.NotFound { + err = nil + } else if err != nil { + tx.Abort() + return err + } + + // Get the existing services + existing := filterNodeServices(tx, services, node) + + // Nuke all the services + for service, entry := range existing { + // Delete the node service entry + key := []byte(fmt.Sprintf("%s||%s", node, service)) + if err = tx.Del(services, key, nil); err != nil { + tx.Abort() + return err + } + + // Delete the sevice index entry + key = []byte(fmt.Sprintf("%s||%s||%s", service, entry.Tag, node)) + if err := tx.Del(index, key, nil); err != nil { + tx.Abort() + return err + } + } + + return tx.Commit() } // Services is used to return all the services with a list of associated tags func (s *StateStore) Services() map[string][]string { - stmt := s.prepared[queryServices] - rows, err := stmt.Query() + tx, dbis, err := s.startTxn(false, dbServiceIndex) + if err != nil { + panic(fmt.Errorf("Failed to get node servicess: %v", err)) + } + index := dbis[0] + + cursor, err := tx.CursorOpen(index) if err != nil { panic(fmt.Errorf("Failed to get services: %v", err)) } services := make(map[string][]string) - var service, tag string - for rows.Next() { - if err := rows.Scan(&service, &tag); err != nil { + for { + key, _, err := cursor.Get(nil, mdb.NEXT) + if err == mdb.NotFound { + break + } else if err != nil { panic(fmt.Errorf("Failed to get services: %v", err)) } + parts := bytes.SplitN(key, []byte("||"), 3) + service := string(parts[0]) + tag := string(parts[1]) tags := services[service] - tags = append(tags, tag) - services[service] = tags + if !strContains(tags, tag) { + tags = append(tags, tag) + services[service] = tags + } } - return services } // ServiceNodes returns the nodes associated with a given service func (s *StateStore) ServiceNodes(service string) rpc.ServiceNodes { - stmt := s.prepared[queryServiceNodes] - return parseServiceNodes(stmt.Query(service)) + tx, dbis, err := s.startTxn(false, dbServiceIndex) + if err != nil { + panic(fmt.Errorf("Failed to get node servicess: %v", err)) + } + defer tx.Abort() + prefix := []byte(fmt.Sprintf("%s||", service)) + return parseServiceNodes(tx, dbis[0], prefix) } // ServiceTagNodes returns the nodes associated with a given service matching a tag func (s *StateStore) ServiceTagNodes(service, tag string) rpc.ServiceNodes { - stmt := s.prepared[queryServiceTagNodes] - return parseServiceNodes(stmt.Query(service, tag)) -} - -// parseServiceNodes parses results from the queryServiceNodes / queryServiceTagNodes query -func parseServiceNodes(rows *sql.Rows, err error) rpc.ServiceNodes { + tx, dbis, err := s.startTxn(false, dbServiceIndex) if err != nil { - panic(fmt.Errorf("Failed to get service nodes: %v", err)) + panic(fmt.Errorf("Failed to get node servicess: %v", err)) } - var nodes rpc.ServiceNodes - var node rpc.ServiceNode - for rows.Next() { - if err := rows.Scan(&node.Node, &node.Address, &node.ServiceTag, &node.ServicePort); err != nil { - panic(fmt.Errorf("Failed to get services: %v", err)) - } - nodes = append(nodes, node) - } - return nodes + defer tx.Abort() + prefix := []byte(fmt.Sprintf("%s||%s||", service, tag)) + return parseServiceNodes(tx, dbis[0], prefix) } -// Snapshot is used to create a point in time snapshot -func (s *StateStore) Snapshot() (*StateStore, error) { - defer func(start time.Time) { - log.Printf("[INFO] StateStore Snapshot created in %v", time.Now().Sub(start)) - }(time.Now()) - - // Create a new state store - state, err := NewStateStore() +// parseServiceNodes parses results ServiceNodes and ServiceTagNodes +func parseServiceNodes(tx *mdb.Txn, index mdb.DBI, prefix []byte) rpc.ServiceNodes { + cursor, err := tx.CursorOpen(index) if err != nil { - return nil, err - } - - // Start a Tx on the new DB - tx, err := state.db.Begin() - if err != nil { - state.Close() - return nil, err + panic(fmt.Errorf("Failed to get node services: %v", err)) } - // Create the new statements we need - ensureNode := tx.Stmt(state.prepared[queryEnsureNode]) - ensureService := tx.Stmt(state.prepared[queryEnsureService]) + var nodes rpc.ServiceNodes + var node rpc.ServiceNode + for { + key, val, err := cursor.Get(nil, mdb.NEXT) + if err == mdb.NotFound { + break + } else if err != nil { + panic(fmt.Errorf("Failed to get node services: %v", err)) + } - // Copy all the nodes - nodes := s.Nodes() - for i := 0; i < len(nodes); i += 2 { - if _, err := ensureNode.Exec(nodes[i], nodes[i+1]); err != nil { - state.Close() - return nil, err + // Bail if this does not match our filter + if !bytes.HasPrefix(key, prefix) { + break } - } - // Copy all the services - var node, service, tag string - var port int - rows, err := s.prepared[queryAllServices].Query() - for rows.Next() { - if err := rows.Scan(&node, &service, &tag, &port); err != nil { - state.Close() - return nil, err + // Setup the node + if val[0] != 255 { + panic(fmt.Errorf("Bad service value: %v", val)) } - if _, err := ensureService.Exec(node, service, tag, port); err != nil { - state.Close() - return nil, err + if err := rpc.Decode(val[1:], &node); err != nil { + panic(fmt.Errorf("Failed to get node services: %v", err)) } - } - // Commit the Txn - if err := tx.Commit(); err != nil { - state.Close() - return nil, err + nodes = append(nodes, node) } + return nodes +} - return state, nil +// Snapshot is used to create a point in time snapshot +func (s *StateStore) Snapshot() (*StateStore, error) { + return s, nil } diff --git a/consul/state_store_test.go b/consul/state_store_test.go index ecdc19df58..bfb2dcde83 100644 --- a/consul/state_store_test.go +++ b/consul/state_store_test.go @@ -50,7 +50,7 @@ func TestGetNodes(t *testing.T) { if len(nodes) != 4 { t.Fatalf("Bad: %v", nodes) } - if nodes[0] != "foo" && nodes[2] != "bar" { + if nodes[2] != "foo" && nodes[0] != "bar" { t.Fatalf("Bad: %v", nodes) } } @@ -63,19 +63,19 @@ func TestEnsureService(t *testing.T) { defer store.Close() if err := store.EnsureNode("foo", "127.0.0.1"); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } if err := store.EnsureService("foo", "api", "", 5000); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } if err := store.EnsureService("foo", "api", "", 5001); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } if err := store.EnsureService("foo", "db", "master", 8000); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } services := store.NodeServices("foo") @@ -105,15 +105,15 @@ func TestDeleteNodeService(t *testing.T) { defer store.Close() if err := store.EnsureNode("foo", "127.0.0.1"); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } if err := store.EnsureService("foo", "api", "", 5000); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } if err := store.DeleteNodeService("foo", "api"); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } services := store.NodeServices("foo")