mirror of https://github.com/hashicorp/consul
143 lines
3.7 KiB
Go
143 lines
3.7 KiB
Go
package sasl
|
|
|
|
// #include "sasl_windows.h"
|
|
import "C"
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"unsafe"
|
|
)
|
|
|
|
// saslStepper is the client side of a multi-step SASL exchange. New in this
// file returns a *saslSession through this interface.
type saslStepper interface {
	// Close releases whatever the implementation holds for the exchange.
	Close()

	// Step takes the server's latest payload and returns the next client
	// payload. done reports whether the client side considers the exchange
	// finished; err is non-nil if the step failed.
	Step(serverData []byte) (clientData []byte, done bool, err error)
}
|
|
|
|
type saslSession struct {
|
|
// Credentials
|
|
mech string
|
|
service string
|
|
host string
|
|
userPlusRealm string
|
|
target string
|
|
domain string
|
|
|
|
// Internal state
|
|
authComplete bool
|
|
errored bool
|
|
step int
|
|
|
|
// C internal state
|
|
credHandle C.CredHandle
|
|
context C.CtxtHandle
|
|
hasContext C.int
|
|
|
|
// Keep track of pointers we need to explicitly free
|
|
stringsToFree []*C.char
|
|
}
|
|
|
|
var (
	// initOnce ensures initSSPI runs at most once per process.
	initOnce sync.Once
	// initError holds the failure, if any, recorded by initSSPI.
	initError error
)
|
|
|
|
func initSSPI() {
|
|
rc := C.load_secur32_dll()
|
|
if rc != 0 {
|
|
initError = fmt.Errorf("Error loading libraries: %v", rc)
|
|
}
|
|
}
|
|
|
|
func New(username, password, mechanism, service, host string) (saslStepper, error) {
|
|
initOnce.Do(initSSPI)
|
|
ss := &saslSession{mech: mechanism, hasContext: 0, userPlusRealm: username}
|
|
if service == "" {
|
|
service = "mongodb"
|
|
}
|
|
if i := strings.Index(host, ":"); i >= 0 {
|
|
host = host[:i]
|
|
}
|
|
ss.service = service
|
|
ss.host = host
|
|
|
|
usernameComponents := strings.Split(username, "@")
|
|
if len(usernameComponents) < 2 {
|
|
return nil, fmt.Errorf("Username '%v' doesn't contain a realm!", username)
|
|
}
|
|
user := usernameComponents[0]
|
|
ss.domain = usernameComponents[1]
|
|
ss.target = fmt.Sprintf("%s/%s", ss.service, ss.host)
|
|
|
|
var status C.SECURITY_STATUS
|
|
// Step 0: call AcquireCredentialsHandle to get a nice SSPI CredHandle
|
|
if len(password) > 0 {
|
|
status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), ss.cstr(password), ss.cstr(ss.domain))
|
|
} else {
|
|
status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), nil, ss.cstr(ss.domain))
|
|
}
|
|
if status != C.SEC_E_OK {
|
|
ss.errored = true
|
|
return nil, fmt.Errorf("Couldn't create new SSPI client, error code %v", status)
|
|
}
|
|
return ss, nil
|
|
}
|
|
|
|
func (ss *saslSession) cstr(s string) *C.char {
|
|
cstr := C.CString(s)
|
|
ss.stringsToFree = append(ss.stringsToFree, cstr)
|
|
return cstr
|
|
}
|
|
|
|
func (ss *saslSession) Close() {
|
|
for _, cstr := range ss.stringsToFree {
|
|
C.free(unsafe.Pointer(cstr))
|
|
}
|
|
}
|
|
|
|
func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
|
|
ss.step++
|
|
if ss.step > 10 {
|
|
return nil, false, fmt.Errorf("too many SSPI steps without authentication")
|
|
}
|
|
var buffer C.PVOID
|
|
var bufferLength C.ULONG
|
|
var outBuffer C.PVOID
|
|
var outBufferLength C.ULONG
|
|
if len(serverData) > 0 {
|
|
buffer = (C.PVOID)(unsafe.Pointer(&serverData[0]))
|
|
bufferLength = C.ULONG(len(serverData))
|
|
}
|
|
var status C.int
|
|
if ss.authComplete {
|
|
// Step 3: last bit of magic to use the correct server credentials
|
|
status = C.sspi_send_client_authz_id(&ss.context, &outBuffer, &outBufferLength, ss.cstr(ss.userPlusRealm))
|
|
} else {
|
|
// Step 1 + Step 2: set up security context with the server and TGT
|
|
status = C.sspi_step(&ss.credHandle, ss.hasContext, &ss.context, buffer, bufferLength, &outBuffer, &outBufferLength, ss.cstr(ss.target))
|
|
}
|
|
if outBuffer != C.PVOID(nil) {
|
|
defer C.free(unsafe.Pointer(outBuffer))
|
|
}
|
|
if status != C.SEC_E_OK && status != C.SEC_I_CONTINUE_NEEDED {
|
|
ss.errored = true
|
|
return nil, false, ss.handleSSPIErrorCode(status)
|
|
}
|
|
|
|
clientData = C.GoBytes(unsafe.Pointer(outBuffer), C.int(outBufferLength))
|
|
if status == C.SEC_E_OK {
|
|
ss.authComplete = true
|
|
return clientData, true, nil
|
|
} else {
|
|
ss.hasContext = 1
|
|
return clientData, false, nil
|
|
}
|
|
}
|
|
|
|
func (ss *saslSession) handleSSPIErrorCode(code C.int) error {
|
|
switch {
|
|
case code == C.SEC_E_TARGET_UNKNOWN:
|
|
return fmt.Errorf("Target %v@%v not found", ss.target, ss.domain)
|
|
}
|
|
return fmt.Errorf("Unknown error doing step %v, error code %v", ss.step, code)
|
|
}
|