You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
cronsun/web/base.go

258 lines
4.9 KiB

package web
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/shunfei/cronsun"
"github.com/shunfei/cronsun/conf"
"github.com/shunfei/cronsun/log"
"github.com/shunfei/cronsun/web/session"
)
// sessManager is the package-wide session manager. It is assigned by
// InitServer and used by authHandler to load and persist request sessions.
var sessManager session.SessionManager
// InitServer prepares the web layer and returns the HTTP server to run.
//
// It installs the etcd-backed session store into sessManager, runs
// checkAuthBasicData, and finally delegates to initRouters. An error from
// checkAuthBasicData is returned as-is together with a nil server.
func InitServer() (*http.Server, error) {
	sessManager = session.NewEtcdStore(cronsun.DefalutClient, conf.Config.Web.Session)

	if err := checkAuthBasicData(); err != nil {
		return nil, err
	}

	return initRouters()
}
// Context bundles the per-request state that is passed to BeforeHandle and
// Handle callbacks.
type Context struct {
// Data is a per-request key/value map; newContext allocates it empty.
Data map[interface{}]interface{}
// Session is set by the hook built in authHandler from sessManager.Get;
// it is nil until that hook has run.
Session *session.Session
// todos holds callbacks registered through Todo; Done runs them in
// reverse registration order.
todos []func()
// R is the incoming request.
R *http.Request
// W is the response writer for the request.
W http.ResponseWriter
}
// newContext returns a Context bound to the given response writer and
// request. Data and todos start empty with room for two entries each;
// Session is left nil.
func newContext(w http.ResponseWriter, r *http.Request) *Context {
	ctx := new(Context)
	ctx.W, ctx.R = w, r
	ctx.Data = make(map[interface{}]interface{}, 2)
	ctx.todos = make([]func(), 0, 2)
	return ctx
}
// Todo registers f to be invoked later by Done. Callbacks are kept in
// registration order.
func (ctx *Context) Todo(f func()) {
	pending := append(ctx.todos, f)
	ctx.todos = pending
}
// Done invokes every callback registered with Todo, last registered first
// (the same ordering as deferred calls).
func (ctx *Context) Done() {
	for i := len(ctx.todos) - 1; i >= 0; i-- {
		ctx.todos[i]()
	}
}
// BaseHandler is an http.Handler (see ServeHTTP) that runs an optional
// pre-check followed by the actual request handler.
type BaseHandler struct {
// Ctx is not referenced by the code visible in this file.
// NOTE(review): possibly unused — confirm against the rest of the package.
Ctx map[string]interface{}
// BeforeHandle, when non-nil, runs before Handle; returning true stops
// the request and Handle is not called.
BeforeHandle func(ctx *Context) (abort bool)
// Handle processes the request.
Handle func(ctx *Context)
}
// NewBaseHandler wraps f in a BaseHandler whose pre-check loads the session
// but does not require the caller to be authenticated.
func NewBaseHandler(f func(ctx *Context)) BaseHandler {
	var h BaseHandler
	h.BeforeHandle = authHandler(false, 0)
	h.Handle = f
	return h
}
// NewAuthHandler wraps f in a BaseHandler whose pre-check requires an
// authenticated session holding at least the reqRole permission level.
func NewAuthHandler(f func(ctx *Context), reqRole cronsun.Role) BaseHandler {
	var h BaseHandler
	h.BeforeHandle = authHandler(true, reqRole)
	h.Handle = f
	return h
}
// authHandler builds the BeforeHandle hook used by NewBaseHandler and
// NewAuthHandler.
//
// The returned hook always loads the session via sessManager.Get and, when a
// session was obtained, schedules sessManager.Store through ctx.Todo so the
// session is persisted when the request finishes. A session load error is
// answered with 500 and aborts the request.
//
// Permission checks only run when web auth is enabled in the configuration
// and needAuth is true. In that case the request is aborted with 401 when the
// session has no email, when the session's "role" entry is missing or is not
// a cronsun.Role, or when the role is numerically greater than reqRole
// (a lower number means more power).
//
// The hook returns true when the request must be aborted.
func authHandler(needAuth bool, reqRole cronsun.Role) func(*Context) bool {
	return func(ctx *Context) (abort bool) {
		// deny writes a JSON error response and reports that the request
		// must stop.
		deny := func(status int, msg string) bool {
			outJSONWithCode(ctx.W, status, msg)
			return true
		}

		var err error
		ctx.Session, err = sessManager.Get(ctx.W, ctx.R)
		if ctx.Session != nil {
			// Persist the session once the request has been handled.
			ctx.Todo(func() {
				if storeErr := sessManager.Store(ctx.Session); storeErr != nil {
					log.Errorf("Failed to store session: %s.", storeErr.Error())
				}
			})
		}
		if err != nil {
			return deny(http.StatusInternalServerError, err.Error())
		}

		// Nothing more to verify unless auth is both enabled and required.
		if !(conf.Config.Web.Auth.Enabled && needAuth) {
			return false
		}

		if len(ctx.Session.Email) == 0 {
			return deny(http.StatusUnauthorized, "please login.")
		}

		rawRole, found := ctx.Session.Data["role"]
		if !found {
			return deny(http.StatusUnauthorized, "role unknow.")
		}

		role, isRole := rawRole.(cronsun.Role)
		if !isRole {
			return deny(http.StatusUnauthorized, "role unknow.")
		}

		// Lower number has more power.
		if role > reqRole {
			return deny(http.StatusUnauthorized, "higher role permission is required.")
		}

		return false
	}
}
// ServeHTTP implements http.Handler.
//
// It creates a fresh Context, runs BeforeHandle (if set) and, unless that
// asked to abort, runs Handle. Callbacks registered with ctx.Todo are run via
// the deferred ctx.Done on every exit path. Any panic raised along the way is
// recovered, answered with a 500 JSON response and logged with a stack trace.
func (b BaseHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Registered first so it runs last: it also catches panics raised by the
	// deferred ctx.Done below.
	defer func() {
		recovered := recover()
		if recovered != nil {
			trace := bytes.NewBuffer(debug.Stack()).String()
			outJSONWithCode(w, http.StatusInternalServerError, "Internal Server Error")
			log.Errorf("%v\n\n%s\n", recovered, trace)
		}
	}()

	ctx := newContext(w, r)
	defer ctx.Done()

	if b.BeforeHandle != nil && b.BeforeHandle(ctx) {
		return
	}

	b.Handle(ctx)
}
// outJSONWithCode writes data to w as a JSON response with status httpCode.
//
// The Content-Type header is always set to application/json. If data cannot
// be marshalled, the status becomes 500 and a fixed JSON error body is sent
// instead.
func outJSONWithCode(w http.ResponseWriter, httpCode int, data interface{}) {
	w.Header().Set("Content-Type", "application/json")

	status, body := httpCode, ""
	if encoded, err := json.Marshal(data); err == nil {
		body = string(encoded)
	} else {
		status = http.StatusInternalServerError
		body = `{"error":"json.Marshal error"}`
	}

	w.WriteHeader(status)
	fmt.Fprint(w, body)
}
// outJSON writes data to w as a JSON response with status 200 OK.
func outJSON(w http.ResponseWriter, data interface{}) {
	const status = http.StatusOK
	outJSONWithCode(w, status, data)
}
// getStringArrayFromQuery reads the form value called name from r, trims
// surrounding whitespace and splits it on sep. It returns nil when the
// trimmed value is empty.
func getStringArrayFromQuery(name, sep string, r *http.Request) (arr []string) {
	if raw := strings.TrimSpace(r.FormValue(name)); raw != "" {
		arr = strings.Split(raw, sep)
	}
	return arr
}
// getPage parses page as a 1-based page number. Anything that is not an
// integer >= 1 yields page 1.
func getPage(page string) int {
	if n, err := strconv.Atoi(page); err == nil && n >= 1 {
		return n
	}
	return 1
}
// getPageSize parses ps as a page size. Unparsable or non-positive input
// falls back to 50, and values above 200 are capped at 200.
func getPageSize(ps string) int {
	size, err := strconv.Atoi(ps)
	switch {
	case err != nil, size < 1:
		return 50
	case size > 200:
		return 200
	default:
		return size
	}
}
// getTime parses t, after trimming whitespace, as a "2006-01-02" date in the
// local time zone. The parse error is deliberately discarded: input that
// does not match the layout yields the zero time.Time.
func getTime(t string) time.Time {
	parsed, _ := time.ParseInLocation("2006-01-02", strings.TrimSpace(t), time.Local)
	return parsed
}
// getStringVal returns the form value called n from r with surrounding
// whitespace removed.
func getStringVal(n string, r *http.Request) string {
	raw := r.FormValue(n)
	return strings.TrimSpace(raw)
}
// InStringArray reports whether k is an element of ss.
func InStringArray(k string, ss []string) bool {
	for _, candidate := range ss {
		if candidate == k {
			return true
		}
	}
	return false
}
// UniqueStringArray returns the distinct elements of a, keeping the order in
// which each value first appears. The input slice is not modified.
//
// An empty or nil input is returned unchanged. Duplicates are tracked in a
// set, so the work is linear in len(a) rather than quadratic.
func UniqueStringArray(a []string) []string {
	if len(a) == 0 {
		return a
	}

	seen := make(map[string]struct{}, len(a))
	ret := make([]string, 0, len(a))
	for _, s := range a {
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		ret = append(ret, s)
	}
	return ret
}
// SubtractStringArray returns the elements that are present in a but not in
// b, in the order they appear in a. Duplicates within a are kept. The result
// is never nil, even when it is empty.
func SubtractStringArray(a, b []string) (c []string) {
	c = make([]string, 0)
	for _, item := range a {
		excluded := false
		for _, other := range b {
			if other == item {
				excluded = true
				break
			}
		}
		if !excluded {
			c = append(c, item)
		}
	}
	return c
}