mirror of https://github.com/hashicorp/consul
133 lines
2.7 KiB
Go
133 lines
2.7 KiB
Go
|
package jws
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"encoding/json"
|
||
|
)
|
||
|
|
||
|
// Flat serializes the JWS to its "flattened" form per
|
||
|
// https://tools.ietf.org/html/rfc7515#section-7.2.2
|
||
|
func (j *jws) Flat(key interface{}) ([]byte, error) {
|
||
|
if len(j.sb) < 1 {
|
||
|
return nil, ErrNotEnoughMethods
|
||
|
}
|
||
|
if err := j.sign(key); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return json.Marshal(struct {
|
||
|
Payload rawBase64 `json:"payload"`
|
||
|
sigHead
|
||
|
}{
|
||
|
Payload: j.plcache,
|
||
|
sigHead: j.sb[0],
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// General serializes the JWS into its "general" form per
|
||
|
// https://tools.ietf.org/html/rfc7515#section-7.2.1
|
||
|
//
|
||
|
// If only one key is passed it's used for all the provided
|
||
|
// crypto.SigningMethods. Otherwise, len(keys) must equal the number
|
||
|
// of crypto.SigningMethods added.
|
||
|
func (j *jws) General(keys ...interface{}) ([]byte, error) {
|
||
|
if err := j.sign(keys...); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return json.Marshal(struct {
|
||
|
Payload rawBase64 `json:"payload"`
|
||
|
Signatures []sigHead `json:"signatures"`
|
||
|
}{
|
||
|
Payload: j.plcache,
|
||
|
Signatures: j.sb,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// Compact serializes the JWS into its "compact" form per
|
||
|
// https://tools.ietf.org/html/rfc7515#section-7.1
|
||
|
func (j *jws) Compact(key interface{}) ([]byte, error) {
|
||
|
if len(j.sb) < 1 {
|
||
|
return nil, ErrNotEnoughMethods
|
||
|
}
|
||
|
|
||
|
if err := j.sign(key); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
sig, err := j.sb[0].Signature.Base64()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return format(
|
||
|
j.sb[0].Protected,
|
||
|
j.plcache,
|
||
|
sig,
|
||
|
), nil
|
||
|
}
|
||
|
|
||
|
// sign signs each index of j's sb member.
|
||
|
func (j *jws) sign(keys ...interface{}) error {
|
||
|
if err := j.cache(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if len(keys) < 1 ||
|
||
|
len(keys) > 1 && len(keys) != len(j.sb) {
|
||
|
return ErrNotEnoughKeys
|
||
|
}
|
||
|
|
||
|
if len(keys) == 1 {
|
||
|
k := keys[0]
|
||
|
keys = make([]interface{}, len(j.sb))
|
||
|
for i := range keys {
|
||
|
keys[i] = k
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for i := range j.sb {
|
||
|
if err := j.sb[i].cache(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
raw := format(j.sb[i].Protected, j.plcache)
|
||
|
sig, err := j.sb[i].method.Sign(raw, keys[i])
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
j.sb[i].Signature = sig
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// cache marshals the payload, but only if it's changed since the last cache.
|
||
|
func (j *jws) cache() (err error) {
|
||
|
if !j.clean {
|
||
|
j.plcache, err = j.payload.Base64()
|
||
|
j.clean = err == nil
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// cache marshals the protected and unprotected headers, but only if
|
||
|
// they've changed since their last cache.
|
||
|
func (s *sigHead) cache() (err error) {
|
||
|
if !s.clean {
|
||
|
s.Protected, err = s.protected.Base64()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
s.Unprotected, err = s.unprotected.Base64()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
s.clean = true
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// format concatenates the given byte slices in order, placing a single '.'
// between each adjacent pair. For example, format(x, y, z) yields "x.y.z".
func format(parts ...[]byte) []byte {
	sep := []byte(".")
	return bytes.Join(parts, sep)
}
|