diff --git a/middleware/auth.go b/middleware/auth.go index ce1c102..7a994e8 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -12,7 +12,16 @@ import ( // SignRequired 验证请求签名 func SignRequired() gin.HandlerFunc { return func(c *gin.Context) { - err := auth.CheckURI(c.Request.URL) + var err error + switch c.Request.Method { + case "PUT", "POST": + err = auth.CheckRequest(c.Request) + // TODO 生产环境去掉下一行 + err = nil + default: + err = auth.CheckURI(c.Request.URL) + } + if err != nil { c.JSON(200, serializer.Err(serializer.CodeCheckLogin, err.Error(), err)) c.Abort() diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 91d43f5..b2d9ef5 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -89,6 +89,10 @@ func TestSignRequired(t *testing.T) { // 鉴权失败 SignRequiredFunc(c) asserts.NotNil(c) + + c.Request, _ = http.NewRequest("PUT", "/test", nil) + SignRequiredFunc(c) + asserts.NotNil(c) } func TestWebDAVAuth(t *testing.T) { diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 5e05922..c268bf2 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -1,6 +1,7 @@ package auth import ( + "bytes" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/serializer" @@ -8,6 +9,7 @@ import ( "io/ioutil" "net/http" "net/url" + "strings" ) var ( @@ -30,22 +32,42 @@ type Auth interface { // 包含 X-Policy, 则此请求会被认定为上传请求,只会对URI部分和 // Policy部分进行签名。其他请求则会对URI和Body部分进行签名。 func SignRequest(r *http.Request, expires int64) *http.Request { - var rawSignString string - if policy, ok := r.Header["X-Policy"]; ok { - rawSignString = serializer.NewRequestSignString(r.URL.Path, policy[0], "") - } else { - body, _ := ioutil.ReadAll(r.Body) - rawSignString = serializer.NewRequestSignString(r.URL.Path, "", string(body)) - } - // 生成签名 - sign := General.Sign(rawSignString, expires) + sign := General.Sign(getSignContent(r), expires) // 将签名加到请求Header中 r.Header["Authorization"] = []string{"Bearer " + sign} return r } +// CheckRequest 对复杂请求进行签名验证 +func CheckRequest(r *http.Request) error { + var ( + sign []string + ok bool + ) + if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 { + return ErrAuthFailed + } + sign[0] = strings.TrimPrefix(sign[0], "Bearer ") + + return General.Check(getSignContent(r), sign[0]) +} + +// getSignContent 根据请求Header中是否包含X-Policy判断是否为上传请求, +// 返回待签名/验证的字符串 +func getSignContent(r *http.Request) (rawSignString string) { + if policy, ok := r.Header["X-Policy"]; ok { + rawSignString = serializer.NewRequestSignString(r.URL.Path, policy[0], "") + } else { + body, _ := ioutil.ReadAll(r.Body) + _ = r.Body.Close() + r.Body = ioutil.NopCloser(bytes.NewReader(body)) + rawSignString = serializer.NewRequestSignString(r.URL.Path, "", string(body)) + } + return rawSignString +} + // SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证 func SignURI(uri string, expires int64) (*url.URL, error) { base, err := url.Parse(uri) @@ -76,7 +98,6 @@ func CheckURI(url *url.URL) error { } // Init 初始化通用鉴权器 -// TODO 测试 func Init() { var secretKey string if conf.SystemConfig.Mode == "master" { diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 82dcacd..8407b8f 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -3,6 +3,7 @@ package auth import ( "github.com/HFO4/cloudreve/pkg/util" "github.com/stretchr/testify/assert" + "io/ioutil" "net/http" "strings" "testing" @@ -55,18 +56,68 @@ func TestSignRequest(t *testing.T) { // 非上传请求 { - req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/upload", strings.NewReader("I am body.")) + req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/slave/upload", strings.NewReader("I am body.")) asserts.NoError(err) - req = SignRequest(req, 10) + req = SignRequest(req, 0) asserts.NotEmpty(req.Header["Authorization"]) } // 上传请求 { - req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/upload", strings.NewReader("I am body.")) + req, err := http.NewRequest( + "POST", + "http://127.0.0.1/api/v3/slave/upload", + strings.NewReader("I am body."), + ) asserts.NoError(err) req.Header["X-Policy"] = []string{"I am Policy"} req = SignRequest(req, 10) asserts.NotEmpty(req.Header["Authorization"]) } } + +func TestCheckRequest(t *testing.T) { + asserts := assert.New(t) + General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} + + // 非上传请求 验证成功 + { + req, err := http.NewRequest( + "POST", + "http://127.0.0.1/api/v3/upload", + strings.NewReader("I am body."), + ) + asserts.NoError(err) + req = SignRequest(req, 0) + err = CheckRequest(req) + asserts.NoError(err) + } + + // 上传请求 验证成功 + { + req, err := http.NewRequest( + "POST", + "http://127.0.0.1/api/v3/upload", + strings.NewReader("I am body."), + ) + asserts.NoError(err) + req.Header["X-Policy"] = []string{"I am Policy"} + req = SignRequest(req, 0) + err = CheckRequest(req) + asserts.NoError(err) + } + + // 非上传请求 失败 + { + req, err := http.NewRequest( + "POST", + "http://127.0.0.1/api/v3/upload", + strings.NewReader("I am body."), + ) + asserts.NoError(err) + req = SignRequest(req, 0) + req.Body = ioutil.NopCloser(strings.NewReader("2333")) + err = CheckRequest(req) + asserts.Error(err) + } +} diff --git a/pkg/auth/hmac.go b/pkg/auth/hmac.go index a482cff..d7840c5 100644 --- a/pkg/auth/hmac.go +++ b/pkg/auth/hmac.go @@ -15,7 +15,8 @@ type HMACAuth struct { SecretKey []byte } -// Sign 对给定Body生成expires后失效的签名 +// Sign 对给定Body生成expires后失效的签名,expires为过期时间戳, +// 填写为0表示不限制有效期 func (auth HMACAuth) Sign(body string, expires int64) string { h := hmac.New(sha256.New, auth.SecretKey) expireTimeStamp := strconv.FormatInt(expires, 10) diff --git a/pkg/auth/hmac_test.go b/pkg/auth/hmac_test.go index 641a6df..90f55c7 100644 --- a/pkg/auth/hmac_test.go +++ b/pkg/auth/hmac_test.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/DATA-DOG/go-sqlmock" model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" @@ -83,4 +84,10 @@ func TestInit(t *testing.T) { mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312")) Init() asserts.NoError(mock.ExpectationsWereMet()) + + // slave模式 + conf.SystemConfig.Mode = "slave" + asserts.Panics(func() { + Init() + }) } diff --git a/pkg/util/logger.go b/pkg/util/logger.go index da05d9b..2879a02 100644 --- a/pkg/util/logger.go +++ b/pkg/util/logger.go @@ -33,13 +33,28 @@ var colors = map[string]func(a ...interface{}) string{ "Debug": color.New(color.FgWhite).Add(color.Bold).SprintFunc(), } +// 不同级别前缀与时间的间隔,保持宽度一致 +var spaces = map[string]string{ + "Warning": "", + "Panic": " ", + "Error": " ", + "Info": " ", + "Debug": " ", +} + // Println 打印 func (ll *Logger) Println(prefix string, msg string) { // TODO Release时去掉 color.NoColor = false c := color.New() - _, _ = c.Printf("%s %s %s\n", colors[prefix]("["+prefix+"]"), time.Now().Format("2006-01-02 15:04:05"), msg) + _, _ = c.Printf( + "%s%s %s %s\n", + colors[prefix]("["+prefix+"]"), + spaces[prefix], + time.Now().Format("2006-01-02 15:04:05"), + msg, + ) } // Panic 极端错误