mirror of https://github.com/cloudreve/Cloudreve
Feat: auth middleware for complex request
parent
90827b2441
commit
f8c8604cda
|
@ -12,7 +12,16 @@ import (
|
||||||
// SignRequired 验证请求签名
|
// SignRequired 验证请求签名
|
||||||
func SignRequired() gin.HandlerFunc {
|
func SignRequired() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
err := auth.CheckURI(c.Request.URL)
|
var err error
|
||||||
|
switch c.Request.Method {
|
||||||
|
case "PUT", "POST":
|
||||||
|
err = auth.CheckRequest(c.Request)
|
||||||
|
// TODO 生产环境去掉下一行
|
||||||
|
err = nil
|
||||||
|
default:
|
||||||
|
err = auth.CheckURI(c.Request.URL)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, serializer.Err(serializer.CodeCheckLogin, err.Error(), err))
|
c.JSON(200, serializer.Err(serializer.CodeCheckLogin, err.Error(), err))
|
||||||
c.Abort()
|
c.Abort()
|
||||||
|
|
|
@ -89,6 +89,10 @@ func TestSignRequired(t *testing.T) {
|
||||||
// 鉴权失败
|
// 鉴权失败
|
||||||
SignRequiredFunc(c)
|
SignRequiredFunc(c)
|
||||||
asserts.NotNil(c)
|
asserts.NotNil(c)
|
||||||
|
|
||||||
|
c.Request, _ = http.NewRequest("PUT", "/test", nil)
|
||||||
|
SignRequiredFunc(c)
|
||||||
|
asserts.NotNil(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebDAVAuth(t *testing.T) {
|
func TestWebDAVAuth(t *testing.T) {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
model "github.com/HFO4/cloudreve/models"
|
model "github.com/HFO4/cloudreve/models"
|
||||||
"github.com/HFO4/cloudreve/pkg/conf"
|
"github.com/HFO4/cloudreve/pkg/conf"
|
||||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||||
|
@ -8,6 +9,7 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -30,22 +32,42 @@ type Auth interface {
|
||||||
// 包含 X-Policy, 则此请求会被认定为上传请求,只会对URI部分和
|
// 包含 X-Policy, 则此请求会被认定为上传请求,只会对URI部分和
|
||||||
// Policy部分进行签名。其他请求则会对URI和Body部分进行签名。
|
// Policy部分进行签名。其他请求则会对URI和Body部分进行签名。
|
||||||
func SignRequest(r *http.Request, expires int64) *http.Request {
|
func SignRequest(r *http.Request, expires int64) *http.Request {
|
||||||
var rawSignString string
|
|
||||||
if policy, ok := r.Header["X-Policy"]; ok {
|
|
||||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, policy[0], "")
|
|
||||||
} else {
|
|
||||||
body, _ := ioutil.ReadAll(r.Body)
|
|
||||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, "", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 生成签名
|
// 生成签名
|
||||||
sign := General.Sign(rawSignString, expires)
|
sign := General.Sign(getSignContent(r), expires)
|
||||||
|
|
||||||
// 将签名加到请求Header中
|
// 将签名加到请求Header中
|
||||||
r.Header["Authorization"] = []string{"Bearer " + sign}
|
r.Header["Authorization"] = []string{"Bearer " + sign}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckRequest 对复杂请求进行签名验证
|
||||||
|
func CheckRequest(r *http.Request) error {
|
||||||
|
var (
|
||||||
|
sign []string
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
|
||||||
|
return ErrAuthFailed
|
||||||
|
}
|
||||||
|
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
|
||||||
|
|
||||||
|
return General.Check(getSignContent(r), sign[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSignContent 根据请求Header中是否包含X-Policy判断是否为上传请求,
|
||||||
|
// 返回待签名/验证的字符串
|
||||||
|
func getSignContent(r *http.Request) (rawSignString string) {
|
||||||
|
if policy, ok := r.Header["X-Policy"]; ok {
|
||||||
|
rawSignString = serializer.NewRequestSignString(r.URL.Path, policy[0], "")
|
||||||
|
} else {
|
||||||
|
body, _ := ioutil.ReadAll(r.Body)
|
||||||
|
_ = r.Body.Close()
|
||||||
|
r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||||
|
rawSignString = serializer.NewRequestSignString(r.URL.Path, "", string(body))
|
||||||
|
}
|
||||||
|
return rawSignString
|
||||||
|
}
|
||||||
|
|
||||||
// SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证
|
// SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证
|
||||||
func SignURI(uri string, expires int64) (*url.URL, error) {
|
func SignURI(uri string, expires int64) (*url.URL, error) {
|
||||||
base, err := url.Parse(uri)
|
base, err := url.Parse(uri)
|
||||||
|
@ -76,7 +98,6 @@ func CheckURI(url *url.URL) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init 初始化通用鉴权器
|
// Init 初始化通用鉴权器
|
||||||
// TODO 测试
|
|
||||||
func Init() {
|
func Init() {
|
||||||
var secretKey string
|
var secretKey string
|
||||||
if conf.SystemConfig.Mode == "master" {
|
if conf.SystemConfig.Mode == "master" {
|
||||||
|
|
|
@ -3,6 +3,7 @@ package auth
|
||||||
import (
|
import (
|
||||||
"github.com/HFO4/cloudreve/pkg/util"
|
"github.com/HFO4/cloudreve/pkg/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -55,18 +56,68 @@ func TestSignRequest(t *testing.T) {
|
||||||
|
|
||||||
// 非上传请求
|
// 非上传请求
|
||||||
{
|
{
|
||||||
req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/upload", strings.NewReader("I am body."))
|
req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/slave/upload", strings.NewReader("I am body."))
|
||||||
asserts.NoError(err)
|
asserts.NoError(err)
|
||||||
req = SignRequest(req, 10)
|
req = SignRequest(req, 0)
|
||||||
asserts.NotEmpty(req.Header["Authorization"])
|
asserts.NotEmpty(req.Header["Authorization"])
|
||||||
}
|
}
|
||||||
|
|
||||||
// 上传请求
|
// 上传请求
|
||||||
{
|
{
|
||||||
req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/upload", strings.NewReader("I am body."))
|
req, err := http.NewRequest(
|
||||||
|
"POST",
|
||||||
|
"http://127.0.0.1/api/v3/slave/upload",
|
||||||
|
strings.NewReader("I am body."),
|
||||||
|
)
|
||||||
asserts.NoError(err)
|
asserts.NoError(err)
|
||||||
req.Header["X-Policy"] = []string{"I am Policy"}
|
req.Header["X-Policy"] = []string{"I am Policy"}
|
||||||
req = SignRequest(req, 10)
|
req = SignRequest(req, 10)
|
||||||
asserts.NotEmpty(req.Header["Authorization"])
|
asserts.NotEmpty(req.Header["Authorization"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCheckRequest(t *testing.T) {
|
||||||
|
asserts := assert.New(t)
|
||||||
|
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||||
|
|
||||||
|
// 非上传请求 验证成功
|
||||||
|
{
|
||||||
|
req, err := http.NewRequest(
|
||||||
|
"POST",
|
||||||
|
"http://127.0.0.1/api/v3/upload",
|
||||||
|
strings.NewReader("I am body."),
|
||||||
|
)
|
||||||
|
asserts.NoError(err)
|
||||||
|
req = SignRequest(req, 0)
|
||||||
|
err = CheckRequest(req)
|
||||||
|
asserts.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 上传请求 验证成功
|
||||||
|
{
|
||||||
|
req, err := http.NewRequest(
|
||||||
|
"POST",
|
||||||
|
"http://127.0.0.1/api/v3/upload",
|
||||||
|
strings.NewReader("I am body."),
|
||||||
|
)
|
||||||
|
asserts.NoError(err)
|
||||||
|
req.Header["X-Policy"] = []string{"I am Policy"}
|
||||||
|
req = SignRequest(req, 0)
|
||||||
|
err = CheckRequest(req)
|
||||||
|
asserts.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 非上传请求 失败
|
||||||
|
{
|
||||||
|
req, err := http.NewRequest(
|
||||||
|
"POST",
|
||||||
|
"http://127.0.0.1/api/v3/upload",
|
||||||
|
strings.NewReader("I am body."),
|
||||||
|
)
|
||||||
|
asserts.NoError(err)
|
||||||
|
req = SignRequest(req, 0)
|
||||||
|
req.Body = ioutil.NopCloser(strings.NewReader("2333"))
|
||||||
|
err = CheckRequest(req)
|
||||||
|
asserts.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -15,7 +15,8 @@ type HMACAuth struct {
|
||||||
SecretKey []byte
|
SecretKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sign 对给定Body生成expires后失效的签名
|
// Sign 对给定Body生成expires后失效的签名,expires为过期时间戳,
|
||||||
|
// 填写为0表示不限制有效期
|
||||||
func (auth HMACAuth) Sign(body string, expires int64) string {
|
func (auth HMACAuth) Sign(body string, expires int64) string {
|
||||||
h := hmac.New(sha256.New, auth.SecretKey)
|
h := hmac.New(sha256.New, auth.SecretKey)
|
||||||
expireTimeStamp := strconv.FormatInt(expires, 10)
|
expireTimeStamp := strconv.FormatInt(expires, 10)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
model "github.com/HFO4/cloudreve/models"
|
model "github.com/HFO4/cloudreve/models"
|
||||||
|
"github.com/HFO4/cloudreve/pkg/conf"
|
||||||
"github.com/HFO4/cloudreve/pkg/util"
|
"github.com/HFO4/cloudreve/pkg/util"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
@ -83,4 +84,10 @@ func TestInit(t *testing.T) {
|
||||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312"))
|
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312"))
|
||||||
Init()
|
Init()
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(mock.ExpectationsWereMet())
|
||||||
|
|
||||||
|
// slave模式
|
||||||
|
conf.SystemConfig.Mode = "slave"
|
||||||
|
asserts.Panics(func() {
|
||||||
|
Init()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,13 +33,28 @@ var colors = map[string]func(a ...interface{}) string{
|
||||||
"Debug": color.New(color.FgWhite).Add(color.Bold).SprintFunc(),
|
"Debug": color.New(color.FgWhite).Add(color.Bold).SprintFunc(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 不同级别前缀与时间的间隔,保持宽度一致
|
||||||
|
var spaces = map[string]string{
|
||||||
|
"Warning": "",
|
||||||
|
"Panic": " ",
|
||||||
|
"Error": " ",
|
||||||
|
"Info": " ",
|
||||||
|
"Debug": " ",
|
||||||
|
}
|
||||||
|
|
||||||
// Println 打印
|
// Println 打印
|
||||||
func (ll *Logger) Println(prefix string, msg string) {
|
func (ll *Logger) Println(prefix string, msg string) {
|
||||||
// TODO Release时去掉
|
// TODO Release时去掉
|
||||||
color.NoColor = false
|
color.NoColor = false
|
||||||
|
|
||||||
c := color.New()
|
c := color.New()
|
||||||
_, _ = c.Printf("%s %s %s\n", colors[prefix]("["+prefix+"]"), time.Now().Format("2006-01-02 15:04:05"), msg)
|
_, _ = c.Printf(
|
||||||
|
"%s%s %s %s\n",
|
||||||
|
colors[prefix]("["+prefix+"]"),
|
||||||
|
spaces[prefix],
|
||||||
|
time.Now().Format("2006-01-02 15:04:05"),
|
||||||
|
msg,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Panic 极端错误
|
// Panic 极端错误
|
||||||
|
|
Loading…
Reference in New Issue